From eebe3e0c6f6164e6776584579fe46e25b9cc5b5a Mon Sep 17 00:00:00 2001 From: Francois Marelli Date: Wed, 28 Feb 2018 16:48:50 +0100 Subject: [PATCH] Better initialization --- neural_filters/NeuralFilter1P.py | 4 ++-- neural_filters/NeuralFilter2CC.py | 15 ++++++++++++--- neural_filters/NeuralFilter2CD.py | 3 +++ neural_filters/NeuralFilter2R.py | 10 ++++++++++ neural_filters/NeuralFilterCell.py | 7 +++++-- 5 files changed, 32 insertions(+), 7 deletions(-) diff --git a/neural_filters/NeuralFilter1P.py b/neural_filters/NeuralFilter1P.py index 403d1f7..658e4fb 100644 --- a/neural_filters/NeuralFilter1P.py +++ b/neural_filters/NeuralFilter1P.py @@ -53,12 +53,12 @@ class NeuralFilter1P(NeuralFilterCell): s = '{name}({input_size},{hidden_size})' return s.format(name=self.__class__.__name__, **self.__dict__) - def reset_parameters(self): + def reset_parameters(self, init=None): stdv = 1.0 / math.sqrt(self.hidden_size) for weight in self.parameters(): weight.data.uniform_(-stdv, stdv) - super(NeuralFilter1P, self).reset_parameters() + super(NeuralFilter1P, self).reset_parameters(init) def check_forward_input(self, input): if input.size(-1) != self.input_size: diff --git a/neural_filters/NeuralFilter2CC.py b/neural_filters/NeuralFilter2CC.py index b633ffc..26a66a3 100644 --- a/neural_filters/NeuralFilter2CC.py +++ b/neural_filters/NeuralFilter2CC.py @@ -48,9 +48,18 @@ class NeuralFilter2CC(torch.nn.Module): self.reset_parameters() - def reset_parameters(self): - self.bias_modulus.data.zero_() - self.bias_theta.data.zero_() + def reset_parameters(self, init=None): + if init is None: + self.bias_modulus.data.zero_() + self.bias_theta.data.zero_() + else: + if isinstance(init, tuple): + self.bias_modulus.data.fill_(init[0]) + self.bias_theta.data.fill_(init[1]) + + else: + self.bias_theta.data.fill_(init) + self.bias_modulus.data.fill_(init) def __repr__(self): s = '{name}({hidden_size})' diff --git a/neural_filters/NeuralFilter2CD.py b/neural_filters/NeuralFilter2CD.py index fd5f087..7605773 100644 --- a/neural_filters/NeuralFilter2CD.py +++ b/neural_filters/NeuralFilter2CD.py @@ -44,6 +44,9 @@ class NeuralFilter2CD (torch.nn.Module): self.cell = NeuralFilterCell(self.hidden_size) + def reset_parameters(self, init=None): + self.cell.reset_parameters(init) + def __repr__(self): s = '{name}({hidden_size})' return s.format(name=self.__class__.__name__, **self.__dict__) diff --git a/neural_filters/NeuralFilter2R.py b/neural_filters/NeuralFilter2R.py index 51e9269..3277bc0 100644 --- a/neural_filters/NeuralFilter2R.py +++ b/neural_filters/NeuralFilter2R.py @@ -45,6 +45,16 @@ class NeuralFilter2R (torch.nn.Module): self.first_cell = NeuralFilterCell(self.hidden_size) self.second_cell = NeuralFilterCell(self.hidden_size) + self.reset_parameters((-0.5, 0.5)) + + def reset_parameters(self, init=None): + if isinstance(init, tuple): + self.first_cell.reset_parameters(init[0]) + self.second_cell.reset_parameters(init[1]) + else: + self.first_cell.reset_parameters(init) + self.second_cell.reset_parameters(init) + def __repr__(self): s = '{name}({hidden_size})' return s.format(name=self.__class__.__name__, **self.__dict__) diff --git a/neural_filters/NeuralFilterCell.py b/neural_filters/NeuralFilterCell.py index adf8060..2510d96 100644 --- a/neural_filters/NeuralFilterCell.py +++ b/neural_filters/NeuralFilterCell.py @@ -47,8 +47,11 @@ class NeuralFilterCell(torch.nn.Module): self.reset_parameters() - def reset_parameters(self): - self.bias_forget.data.zero_() + def reset_parameters(self, init=None): + if init is None: + self.bias_forget.data.zero_() + else: + self.bias_forget.data.fill_(init) def __repr__(self): s = '{name}({hidden_size})' -- 2.21.0