diff --git a/neural_filters/__init__.py b/neural_filters/__init__.py index 0f1a90f336961903332ae1f6c64d5e4e702e2eb0..b62c3857e1da144bd8dee9c4abbe4fac945d6f24 100644 --- a/neural_filters/__init__.py +++ b/neural_filters/__init__.py @@ -1,3 +1,5 @@ +RANDOM_STD = 0 + from .neural_filter import * from .neural_filter_1L import * from .neural_filter_2R import * diff --git a/neural_filters/neural_filter.py b/neural_filters/neural_filter.py index d0c0dad870154f88387a0b30480acbc194e6bb85..e6d641c1a5a1c7cd08142f51cac84e228123df8c 100644 --- a/neural_filters/neural_filter.py +++ b/neural_filters/neural_filter.py @@ -30,6 +30,7 @@ from torch.nn import Parameter from torch.nn import functional as F import numpy as np +from . import RANDOM_STD class NeuralFilter(torch.nn.Module): """ @@ -49,9 +50,13 @@ class NeuralFilter(torch.nn.Module): def reset_parameters(self, init=None): if init is None: - self.bias_forget.data.uniform_(-0.2, 0.2) + self.bias_forget.data.uniform_(-RANDOM_STD, RANDOM_STD) else: - self.bias_forget.data.fill_(init) + if not isinstance(init, np.ndarray): + init = np.array(init, ndmin=1) + + ten_init = torch.from_numpy(init) + self.bias_forget.data.copy_(ten_init) def __repr__(self): s = '{name}({hidden_size})' diff --git a/neural_filters/neural_filter_2CC.py b/neural_filters/neural_filter_2CC.py index 13c8aa566ce62a94287487aee84223d770021c65..113929ed565442f2fdaae0123a712e00c18c2626 100644 --- a/neural_filters/neural_filter_2CC.py +++ b/neural_filters/neural_filter_2CC.py @@ -30,6 +30,8 @@ from torch.nn import Parameter from torch.nn import functional as F import numpy as np +from . import RANDOM_STD + class NeuralFilter2CC(torch.nn.Module): """ @@ -48,17 +50,26 @@ class NeuralFilter2CC(torch.nn.Module): self.reset_parameters() - def reset_parameters(self, init=None): - if init is None: - self.bias_modulus.data.uniform_(-0.2, 0.2) - self.bias_theta.data.uniform_(-0.2, 0.2) + def reset_parameters(self, init_modulus=None, init_theta=None): + if init_modulus is None: + self.bias_modulus.data.uniform_(-RANDOM_STD, RANDOM_STD) + else: - if isinstance(init, tuple): - self.bias_modulus.data.fill_(init[0]) - self.bias_theta.data.fill_(init[1]) - else: - self.bias_theta.data.fill_(init) - self.bias_modulus.data.fill_(init) + if not isinstance(init_modulus, np.ndarray): + init_modulus = np.array(init_modulus, ndmin=1) + + ten_init = torch.from_numpy(init_modulus) + self.bias_modulus.data.copy_(ten_init) + + if init_theta is None: + self.bias_theta.data.uniform_(-RANDOM_STD, RANDOM_STD) + + else: + if not isinstance(init_theta, np.ndarray): + init_theta = np.array(init_theta, ndmin=1) + + ten_init = torch.from_numpy(init_theta) + self.bias_theta.data.copy_(ten_init) def __repr__(self): s = '{name}({hidden_size})' diff --git a/neural_filters/neural_filter_2R.py b/neural_filters/neural_filter_2R.py index d485f5cfe9dae76ab50aa6cad18ebee8d9cc4a28..1779d28e54616c497e60b2d5d77ef7ae185a5dc4 100644 --- a/neural_filters/neural_filter_2R.py +++ b/neural_filters/neural_filter_2R.py @@ -30,6 +30,8 @@ from . import NeuralFilter import torch import numpy as np +from . import RANDOM_STD + class NeuralFilter2R(torch.nn.Module): """ @@ -50,8 +52,8 @@ class NeuralFilter2R(torch.nn.Module): def reset_parameters(self, init=None): if init is None: - self.first_cell.bias_forget.data.uniform_(-0.7, -0.3) - self.second_cell.bias_forget.data.uniform_(0.3, 0.7) + self.first_cell.bias_forget.data.uniform_(-0.5 - RANDOM_STD, -0.5 + RANDOM_STD) + self.second_cell.bias_forget.data.uniform_(0.5 - RANDOM_STD, 0.5 + RANDOM_STD) elif isinstance(init, tuple): self.first_cell.reset_parameters(init[0]) self.second_cell.reset_parameters(init[1])