Commit eebe3e0c authored by Francois Marelli

Better initialization

parent dc3243af
...@@ -53,12 +53,12 @@ class NeuralFilter1P(NeuralFilterCell): ...@@ -53,12 +53,12 @@ class NeuralFilter1P(NeuralFilterCell):
s = '{name}({input_size},{hidden_size})' s = '{name}({input_size},{hidden_size})'
return s.format(name=self.__class__.__name__, **self.__dict__) return s.format(name=self.__class__.__name__, **self.__dict__)
def reset_parameters(self): def reset_parameters(self, init=None):
stdv = 1.0 / math.sqrt(self.hidden_size) stdv = 1.0 / math.sqrt(self.hidden_size)
for weight in self.parameters(): for weight in self.parameters():
weight.data.uniform_(-stdv, stdv) weight.data.uniform_(-stdv, stdv)
super(NeuralFilter1P, self).reset_parameters() super(NeuralFilter1P, self).reset_parameters(init)
def check_forward_input(self, input): def check_forward_input(self, input):
if input.size(-1) != self.input_size: if input.size(-1) != self.input_size:
......
...@@ -48,9 +48,18 @@ class NeuralFilter2CC(torch.nn.Module): ...@@ -48,9 +48,18 @@ class NeuralFilter2CC(torch.nn.Module):
self.reset_parameters() self.reset_parameters()
def reset_parameters(self): def reset_parameters(self, init=None):
self.bias_modulus.data.zero_() if init is None:
self.bias_theta.data.zero_() self.bias_modulus.data.zero_()
self.bias_theta.data.zero_()
else:
if isinstance(init, tuple):
self.bias_modulus.data.fill_(init[0])
self.bias_theta.data.fill_(init[1])
else:
self.bias_theta.data.fill_(init)
self.bias_modulus.data.fill_(init)
def __repr__(self): def __repr__(self):
s = '{name}({hidden_size})' s = '{name}({hidden_size})'
......
...@@ -44,6 +44,9 @@ class NeuralFilter2CD (torch.nn.Module): ...@@ -44,6 +44,9 @@ class NeuralFilter2CD (torch.nn.Module):
self.cell = NeuralFilterCell(self.hidden_size) self.cell = NeuralFilterCell(self.hidden_size)
def reset_parameters(self, init=None):
self.cell.reset_parameters(init)
def __repr__(self): def __repr__(self):
s = '{name}({hidden_size})' s = '{name}({hidden_size})'
return s.format(name=self.__class__.__name__, **self.__dict__) return s.format(name=self.__class__.__name__, **self.__dict__)
......
...@@ -45,6 +45,16 @@ class NeuralFilter2R (torch.nn.Module): ...@@ -45,6 +45,16 @@ class NeuralFilter2R (torch.nn.Module):
self.first_cell = NeuralFilterCell(self.hidden_size) self.first_cell = NeuralFilterCell(self.hidden_size)
self.second_cell = NeuralFilterCell(self.hidden_size) self.second_cell = NeuralFilterCell(self.hidden_size)
self.reset_parameters((-0.5, 0.5))
def reset_parameters(self, init=None):
if isinstance(init, tuple):
self.first_cell.reset_parameters(init[0])
self.second_cell.reset_parameters(init[1])
else:
self.first_cell.reset_parameters(init)
self.second_cell.reset_parameters(init)
def __repr__(self): def __repr__(self):
s = '{name}({hidden_size})' s = '{name}({hidden_size})'
return s.format(name=self.__class__.__name__, **self.__dict__) return s.format(name=self.__class__.__name__, **self.__dict__)
......
...@@ -47,8 +47,11 @@ class NeuralFilterCell(torch.nn.Module): ...@@ -47,8 +47,11 @@ class NeuralFilterCell(torch.nn.Module):
self.reset_parameters() self.reset_parameters()
def reset_parameters(self): def reset_parameters(self, init=None):
self.bias_forget.data.zero_() if init is None:
self.bias_forget.data.zero_()
else:
self.bias_forget.data.fill_(init)
def __repr__(self): def __repr__(self):
s = '{name}({hidden_size})' s = '{name}({hidden_size})'
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment