Commit b0ec956f authored by M. François's avatar M. François

tuple hidden

parent b315607e
......@@ -99,32 +99,30 @@ class NeuralFilter2CC(torch.nn.Module):
"hidden has inconsistent hidden_size: got {}, expected {}".format(
hx.size(1), self.hidden_size))
def step(self, input_var, delayed, delayed2, a=None, b=None):
def step(self, input_var, hidden, a=None, b=None):
if a is None or b is None:
modulus = F.sigmoid(self.bias_modulus)
cosangle = F.tanh(self.bias_theta)
a = 2 * cosangle * modulus
b = - modulus.pow(2)
next_state = input_var + a * delayed + b * delayed2
next_state = input_var + a * hidden[0] + b * hidden[1]
return next_state
def forward(self, input_var, delayed=None, delayed2=None):
if delayed is None:
delayed = torch.autograd.Variable(input_var.data.new(input_var.size(1), self.hidden_size).zero_(),
def forward(self, input_var, hidden=(None, None)):
h0, h1 = hidden
if h0 is None:
h0 = torch.autograd.Variable(input_var.data.new(input_var.size(1), self.hidden_size).zero_(),
requires_grad=False)
if delayed2 is None:
delayed2 = torch.autograd.Variable(input_var.data.new(input_var.size(1), self.hidden_size).zero_(),
if h1 is None:
h1 = torch.autograd.Variable(input_var.data.new(input_var.size(1), self.hidden_size).zero_(),
requires_grad=False)
self.check_forward_input(input_var)
self.check_forward_hidden(input_var, delayed)
self.check_forward_hidden(input_var, delayed2)
d1 = delayed
d2 = delayed2
self.check_forward_hidden(input_var, h0)
self.check_forward_hidden(input_var, h1)
# do not recompute this at each step to gain efficiency
modulus = F.sigmoid(self.bias_modulus)
......@@ -135,13 +133,13 @@ class NeuralFilter2CC(torch.nn.Module):
output = []
steps = range(input_var.size(0))
for i in steps:
next_state = self.step(input_var[i], d1, d2, a=a, b=b)
next_state = self.step(input_var[i], (h0, h1), a=a, b=b)
output.append(next_state)
d2, d1 = d1, next_state
h1, h0 = h0, next_state
output = torch.cat(output, 0).view(input_var.size(0), *output[0].size())
return output, d1, d2
return output, (h0, h1)
def print_param(self):
modulus = F.sigmoid(self.bias_modulus)
......
......@@ -52,16 +52,11 @@ class NeuralFilter2CD(torch.nn.Module):
s = '{name}({hidden_size})'
return s.format(name=self.__class__.__name__, **self.__dict__)
def forward(self, input_var, hx=None):
if hx is None:
hx = torch.autograd.Variable(input_var.data.new(input_var.size(1),
self.hidden_size
).zero_(), requires_grad=False)
def forward(self, input_var, hx=(None, None)):
inter, inter_hidden = self.cell(input_var, hx[0])
output, hidden = self.cell(inter, hx[1])
inter, inter_hidden = self.cell(input_var, hx)
output, hidden = self.cell(inter)
return output, hidden
return output, (inter_hidden, hidden)
@property
def denominator(self):
......
......@@ -64,16 +64,11 @@ class NeuralFilter2R(torch.nn.Module):
s = '{name}({hidden_size})'
return s.format(name=self.__class__.__name__, **self.__dict__)
def forward(self, input_var, hx=None):
if hx is None:
hx = torch.autograd.Variable(input_var.data.new(input_var.size(1),
self.hidden_size
).zero_(), requires_grad=False)
def forward(self, input_var, hx=(None, None)):
interm, interm_hidden = self.first_cell(input_var, hx[0])
output, hidden = self.second_cell(interm, hx[1])
interm, interm_hidden = self.first_cell(input_var, hx)
output, hidden = self.second_cell(interm)
return output, hidden
return output, (interm_hidden, hidden)
@property
def denominator(self):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment