Skip to content
Snippets Groups Projects
Commit 7531241e authored by Francois Marelli's avatar Francois Marelli
Browse files

flexible initialization of weights

parent 7facd573
No related branches found
No related tags found
No related merge requests found
# Half-width of the uniform distribution used when weights are initialized
# randomly (the cells call `uniform_(-RANDOM_STD, RANDOM_STD)`).
# NOTE(review): with the default of 0 the uniform init is degenerate and
# every randomly-initialized bias starts at exactly 0 — confirm intended.
RANDOM_STD = 0
# Re-export the filter implementations at package level.
from .neural_filter import *
from .neural_filter_1L import *
from .neural_filter_2R import *
......
......@@ -30,6 +30,7 @@ from torch.nn import Parameter
from torch.nn import functional as F
import numpy as np
from . import RANDOM_STD
class NeuralFilter(torch.nn.Module):
"""
......@@ -49,9 +50,13 @@ class NeuralFilter(torch.nn.Module):
def reset_parameters(self, init=None):
    """Initialize the forget-gate bias.

    Parameters
    ----------
    init : None, scalar, sequence or np.ndarray, optional
        If None, the bias is drawn uniformly from
        ``[-RANDOM_STD, RANDOM_STD]`` (package-level constant).
        Otherwise the given value(s) are copied into ``bias_forget``:
        non-array inputs are promoted to a 1-D numpy array first, and
        ``Tensor.copy_`` handles dtype casting and broadcasting, so a
        scalar fills the whole bias vector.
    """
    if init is None:
        self.bias_forget.data.uniform_(-RANDOM_STD, RANDOM_STD)
    else:
        if not isinstance(init, np.ndarray):
            # ndmin=1 promotes scalars to 1-D so torch.from_numpy accepts them
            init = np.array(init, ndmin=1)
        ten_init = torch.from_numpy(init)
        self.bias_forget.data.copy_(ten_init)
def __repr__(self):
s = '{name}({hidden_size})'
......
......@@ -30,6 +30,8 @@ from torch.nn import Parameter
from torch.nn import functional as F
import numpy as np
from . import RANDOM_STD
class NeuralFilter2CC(torch.nn.Module):
"""
......@@ -48,17 +50,26 @@ class NeuralFilter2CC(torch.nn.Module):
self.reset_parameters()
def reset_parameters(self, init_modulus=None, init_theta=None):
    """Initialize the modulus and theta biases independently.

    Parameters
    ----------
    init_modulus : None, scalar, sequence or np.ndarray, optional
        If None, ``bias_modulus`` is drawn uniformly from
        ``[-RANDOM_STD, RANDOM_STD]``; otherwise the value(s) are
        promoted to a 1-D numpy array (if needed) and copied in.
        ``Tensor.copy_`` casts dtype and broadcasts, so a scalar fills
        the whole bias vector.
    init_theta : None, scalar, sequence or np.ndarray, optional
        Same contract as ``init_modulus``, applied to ``bias_theta``.
    """
    if init_modulus is None:
        self.bias_modulus.data.uniform_(-RANDOM_STD, RANDOM_STD)
    else:
        if not isinstance(init_modulus, np.ndarray):
            # ndmin=1 promotes scalars to 1-D for torch.from_numpy
            init_modulus = np.array(init_modulus, ndmin=1)
        ten_init = torch.from_numpy(init_modulus)
        self.bias_modulus.data.copy_(ten_init)

    if init_theta is None:
        self.bias_theta.data.uniform_(-RANDOM_STD, RANDOM_STD)
    else:
        if not isinstance(init_theta, np.ndarray):
            init_theta = np.array(init_theta, ndmin=1)
        ten_init = torch.from_numpy(init_theta)
        self.bias_theta.data.copy_(ten_init)
def __repr__(self):
s = '{name}({hidden_size})'
......
......@@ -30,6 +30,8 @@ from . import NeuralFilter
import torch
import numpy as np
from . import RANDOM_STD
class NeuralFilter2R(torch.nn.Module):
"""
......@@ -50,8 +52,8 @@ class NeuralFilter2R(torch.nn.Module):
def reset_parameters(self, init=None):
if init is None:
self.first_cell.bias_forget.data.uniform_(-0.7, -0.3)
self.second_cell.bias_forget.data.uniform_(0.3, 0.7)
self.first_cell.bias_forget.data.uniform_(-0.5 - RANDOM_STD, -0.5 + RANDOM_STD)
self.second_cell.bias_forget.data.uniform_(0.5 - RANDOM_STD, 0.5 + RANDOM_STD)
elif isinstance(init, tuple):
self.first_cell.reset_parameters(init[0])
self.second_cell.reset_parameters(init[1])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment