Commit eebe3e0c authored by Francois Marelli's avatar Francois Marelli

Better initialization

parent dc3243af
......@@ -53,12 +53,12 @@ class NeuralFilter1P(NeuralFilterCell):
s = '{name}({input_size},{hidden_size})'
return s.format(name=self.__class__.__name__, **self.__dict__)
def reset_parameters(self):
def reset_parameters(self, init=None):
stdv = 1.0 / math.sqrt(self.hidden_size)
for weight in self.parameters():
weight.data.uniform_(-stdv, stdv)
super(NeuralFilter1P, self).reset_parameters()
super(NeuralFilter1P, self).reset_parameters(init)
def check_forward_input(self, input):
if input.size(-1) != self.input_size:
......
......@@ -48,9 +48,18 @@ class NeuralFilter2CC(torch.nn.Module):
self.reset_parameters()
def reset_parameters(self):
def reset_parameters(self, init=None):
if init is None:
self.bias_modulus.data.zero_()
self.bias_theta.data.zero_()
else:
if isinstance(init, tuple):
self.bias_modulus.data.fill_(init[0])
self.bias_theta.data.fill_(init[1])
else:
self.bias_theta.data.fill_(init)
self.bias_modulus.data.fill_(init)
def __repr__(self):
s = '{name}({hidden_size})'
......
......@@ -44,6 +44,9 @@ class NeuralFilter2CD (torch.nn.Module):
self.cell = NeuralFilterCell(self.hidden_size)
def reset_parameters(self, init=None):
self.cell.reset_parameters(init)
def __repr__(self):
s = '{name}({hidden_size})'
return s.format(name=self.__class__.__name__, **self.__dict__)
......
......@@ -45,6 +45,16 @@ class NeuralFilter2R (torch.nn.Module):
self.first_cell = NeuralFilterCell(self.hidden_size)
self.second_cell = NeuralFilterCell(self.hidden_size)
self.reset_parameters((-0.5, 0.5))
def reset_parameters(self, init=None):
if isinstance(init, tuple):
self.first_cell.reset_parameters(init[0])
self.second_cell.reset_parameters(init[1])
else:
self.first_cell.reset_parameters(init)
self.second_cell.reset_parameters(init)
def __repr__(self):
s = '{name}({hidden_size})'
return s.format(name=self.__class__.__name__, **self.__dict__)
......
......@@ -47,8 +47,11 @@ class NeuralFilterCell(torch.nn.Module):
self.reset_parameters()
def reset_parameters(self):
def reset_parameters(self, init=None):
if init is None:
self.bias_forget.data.zero_()
else:
self.bias_forget.data.fill_(init)
def __repr__(self):
s = '{name}({hidden_size})'
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment