Commit b0ec956f authored by M. François

tuple hidden

parent b315607e
@@ -99,32 +99,30 @@ class NeuralFilter2CC(torch.nn.Module):
                 "hidden has inconsistent hidden_size: got {}, expected {}".format(
                     hx.size(1), self.hidden_size))
 
-    def step(self, input_var, delayed, delayed2, a=None, b=None):
+    def step(self, input_var, hidden, a=None, b=None):
         if a is None or b is None:
             modulus = F.sigmoid(self.bias_modulus)
             cosangle = F.tanh(self.bias_theta)
             a = 2 * cosangle * modulus
             b = - modulus.pow(2)
-        next_state = input_var + a * delayed + b * delayed2
+        next_state = input_var + a * hidden[0] + b * hidden[1]
         return next_state
 
-    def forward(self, input_var, delayed=None, delayed2=None):
-        if delayed is None:
-            delayed = torch.autograd.Variable(input_var.data.new(input_var.size(1), self.hidden_size).zero_(),
-                                              requires_grad=False)
-        if delayed2 is None:
-            delayed2 = torch.autograd.Variable(input_var.data.new(input_var.size(1), self.hidden_size).zero_(),
-                                               requires_grad=False)
+    def forward(self, input_var, hidden=(None, None)):
+        h0, h1 = hidden
+        if h0 is None:
+            h0 = torch.autograd.Variable(input_var.data.new(input_var.size(1), self.hidden_size).zero_(),
+                                         requires_grad=False)
+        if h1 is None:
+            h1 = torch.autograd.Variable(input_var.data.new(input_var.size(1), self.hidden_size).zero_(),
+                                         requires_grad=False)
 
         self.check_forward_input(input_var)
-        self.check_forward_hidden(input_var, delayed)
-        self.check_forward_hidden(input_var, delayed2)
-
-        d1 = delayed
-        d2 = delayed2
+        self.check_forward_hidden(input_var, h0)
+        self.check_forward_hidden(input_var, h1)
 
         # do not recompute this at each step to gain efficiency
         modulus = F.sigmoid(self.bias_modulus)
@@ -135,13 +133,13 @@ class NeuralFilter2CC(torch.nn.Module):
         output = []
         steps = range(input_var.size(0))
         for i in steps:
-            next_state = self.step(input_var[i], d1, d2, a=a, b=b)
+            next_state = self.step(input_var[i], (h0, h1), a=a, b=b)
             output.append(next_state)
-            d2, d1 = d1, next_state
+            h1, h0 = h0, next_state
 
         output = torch.cat(output, 0).view(input_var.size(0), *output[0].size())
 
-        return output, d1, d2
+        return output, (h0, h1)
 
     def print_param(self):
         modulus = F.sigmoid(self.bias_modulus)
...
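
With this change, NeuralFilter2CC.forward takes and returns the two delayed states as a
single tuple instead of two separate arguments and return values. A minimal usage sketch
of the new contract; the constructor call and tensor shapes below are assumptions, only
the forward signature comes from this commit:

import torch
from torch.autograd import Variable

cell = NeuralFilter2CC(4)              # assumed constructor: hidden_size=4

x = Variable(torch.randn(10, 2, 4))    # (seq_len, batch, hidden_size)
out, hidden = cell(x)                  # hidden defaults to (None, None) -> zero initial states
out2, hidden = cell(x, hidden)         # resume filtering from the carried (h0, h1) tuple
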
@@ -52,16 +52,11 @@ class NeuralFilter2CD(torch.nn.Module):
         s = '{name}({hidden_size})'
         return s.format(name=self.__class__.__name__, **self.__dict__)
 
-    def forward(self, input_var, hx=None):
-        if hx is None:
-            hx = torch.autograd.Variable(input_var.data.new(input_var.size(1),
-                                                            self.hidden_size
-                                                            ).zero_(), requires_grad=False)
-        inter, inter_hidden = self.cell(input_var, hx)
-        output, hidden = self.cell(inter)
-        return output, hidden
+    def forward(self, input_var, hx=(None, None)):
+        inter, inter_hidden = self.cell(input_var, hx[0])
+        output, hidden = self.cell(inter, hx[1])
+        return output, (inter_hidden, hidden)
 
     @property
     def denominator(self):
...
@@ -64,16 +64,11 @@ class NeuralFilter2R(torch.nn.Module):
         s = '{name}({hidden_size})'
         return s.format(name=self.__class__.__name__, **self.__dict__)
 
-    def forward(self, input_var, hx=None):
-        if hx is None:
-            hx = torch.autograd.Variable(input_var.data.new(input_var.size(1),
-                                                            self.hidden_size
-                                                            ).zero_(), requires_grad=False)
-        interm, interm_hidden = self.first_cell(input_var, hx)
-        output, hidden = self.second_cell(interm)
-        return output, hidden
+    def forward(self, input_var, hx=(None, None)):
+        interm, interm_hidden = self.first_cell(input_var, hx[0])
+        output, hidden = self.second_cell(interm, hx[1])
+        return output, (interm_hidden, hidden)
 
     @property
     def denominator(self):
...
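
In both cascaded filters the initial state is now a pair, one entry per stage, and both
stages' final states are returned. Before this commit the second cell always started from
zero and the first cell's final state was discarded; the tuple form lets a caller stream a
long signal in chunks and hand each stage its own carried state. A rough sketch, assuming a
NeuralFilter2R(hidden_size) constructor and cells that accept a None initial state (neither
is shown in this diff):

import torch
from torch.autograd import Variable

filt = NeuralFilter2R(1)                   # assumed constructor: hidden_size=1
signal = Variable(torch.randn(200, 1, 1))  # (seq_len, batch, hidden_size)

hx = (None, None)                          # one initial state per cascaded cell
for start in range(0, 200, 50):            # process the sequence in four chunks
    y, hx = filt(signal[start:start + 50], hx)  # each stage resumes from its own state
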