Commit 7531241e authored by Francois Marelli

flexible initialization of weights

parent 7facd573
+RANDOM_STD = 0
 from .neural_filter import *
 from .neural_filter_1L import *
 from .neural_filter_2R import *
......
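The hunks below all reuse the same scalar-or-ndarray conversion before copying values into the bias parameters. A minimal standalone sketch of that pattern (illustration only; the helper name to_init_tensor is made up and not part of this commit):

import numpy as np
import torch

def to_init_tensor(value):
    # Accept a Python scalar, sequence, or numpy array and return a tensor
    # that Parameter.data.copy_() can broadcast over the parameter shape.
    if not isinstance(value, np.ndarray):
        value = np.array(value, ndmin=1)
    return torch.from_numpy(value)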
@@ -30,6 +30,7 @@ from torch.nn import Parameter
 from torch.nn import functional as F
 import numpy as np

+from . import RANDOM_STD

 class NeuralFilter(torch.nn.Module):
     """
@@ -49,9 +50,13 @@ class NeuralFilter(torch.nn.Module):

     def reset_parameters(self, init=None):
         if init is None:
-            self.bias_forget.data.uniform_(-0.2, 0.2)
+            self.bias_forget.data.uniform_(-RANDOM_STD, RANDOM_STD)
         else:
-            self.bias_forget.data.fill_(init)
+            if not isinstance(init, np.ndarray):
+                init = np.array(init, ndmin=1)
+
+            ten_init = torch.from_numpy(init)
+            self.bias_forget.data.copy_(ten_init)

     def __repr__(self):
         s = '{name}({hidden_size})'
......
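A hedged usage sketch of the new NeuralFilter.reset_parameters; the neural_filters import path and the NeuralFilter(hidden_size) constructor are assumptions, only the reset_parameters behaviour comes from this diff:

import numpy as np
from neural_filters import NeuralFilter   # assumed import path

cell = NeuralFilter(3)                     # assumed constructor (hidden size 3)
cell.reset_parameters()                    # bias_forget ~ uniform in [-RANDOM_STD, RANDOM_STD]
cell.reset_parameters(0.5)                 # scalar, broadcast to every unit by copy_()
cell.reset_parameters(np.array([0.1, 0.2, 0.3]))   # explicit per-unit values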
@@ -30,6 +30,8 @@ from torch.nn import Parameter
 from torch.nn import functional as F
 import numpy as np

+from . import RANDOM_STD
+

 class NeuralFilter2CC(torch.nn.Module):
     """
@@ -48,17 +50,26 @@ class NeuralFilter2CC(torch.nn.Module):
         self.reset_parameters()

-    def reset_parameters(self, init=None):
-        if init is None:
-            self.bias_modulus.data.uniform_(-0.2, 0.2)
-            self.bias_theta.data.uniform_(-0.2, 0.2)
+    def reset_parameters(self, init_modulus=None, init_theta=None):
+        if init_modulus is None:
+            self.bias_modulus.data.uniform_(-RANDOM_STD, RANDOM_STD)
         else:
-            if isinstance(init, tuple):
-                self.bias_modulus.data.fill_(init[0])
-                self.bias_theta.data.fill_(init[1])
-            else:
-                self.bias_theta.data.fill_(init)
-                self.bias_modulus.data.fill_(init)
+            if not isinstance(init_modulus, np.ndarray):
+                init_modulus = np.array(init_modulus, ndmin=1)
+
+            ten_init = torch.from_numpy(init_modulus)
+            self.bias_modulus.data.copy_(ten_init)
+
+        if init_theta is None:
+            self.bias_theta.data.uniform_(-RANDOM_STD, RANDOM_STD)
+        else:
+            if not isinstance(init_theta, np.ndarray):
+                init_theta = np.array(init_theta, ndmin=1)
+
+            ten_init = torch.from_numpy(init_theta)
+            self.bias_theta.data.copy_(ten_init)

     def __repr__(self):
         s = '{name}({hidden_size})'
......
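With the split arguments, the modulus and theta biases of NeuralFilter2CC can now be initialized independently; a sketch under the same assumptions (import path and hidden_size constructor are not confirmed by the diff):

import numpy as np
from neural_filters import NeuralFilter2CC   # assumed import path

cell = NeuralFilter2CC(2)                     # assumed constructor (hidden size 2)
cell.reset_parameters(init_modulus=0.9)       # fix the modulus bias, theta stays random
cell.reset_parameters(init_modulus=np.array([0.9, 0.8]),
                      init_theta=np.array([0.1, 0.2]))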
@@ -30,6 +30,8 @@ from . import NeuralFilter
 import torch
 import numpy as np

+from . import RANDOM_STD
+

 class NeuralFilter2R(torch.nn.Module):
     """
@@ -50,8 +52,8 @@ class NeuralFilter2R(torch.nn.Module):

     def reset_parameters(self, init=None):
         if init is None:
-            self.first_cell.bias_forget.data.uniform_(-0.7, -0.3)
-            self.second_cell.bias_forget.data.uniform_(0.3, 0.7)
+            self.first_cell.bias_forget.data.uniform_(-0.5 - RANDOM_STD, -0.5 + RANDOM_STD)
+            self.second_cell.bias_forget.data.uniform_(0.5 - RANDOM_STD, 0.5 + RANDOM_STD)
         elif isinstance(init, tuple):
             self.first_cell.reset_parameters(init[0])
             self.second_cell.reset_parameters(init[1])
......
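NeuralFilter2R still accepts a tuple and forwards one entry to each cascaded cell, so each cell can now receive a scalar or an array; a sketch under the same assumptions:

import numpy as np
from neural_filters import NeuralFilter2R    # assumed import path

filt = NeuralFilter2R(4)                      # assumed constructor (hidden size 4)
filt.reset_parameters()                       # cells centred on -0.5 and +0.5, half-width RANDOM_STD
filt.reset_parameters((-0.5, np.array([0.4, 0.5, 0.6, 0.7])))   # per-cell initialization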