Skip to content
Snippets Groups Projects
Commit 20aa0846 authored by Francois Marelli's avatar Francois Marelli
Browse files

Dimension bug fix

parent 8ccef496
Branches
Tags
No related merge requests found
from .neural_filters import *
\ No newline at end of file
...@@ -44,8 +44,10 @@ class NeuralFilter1P(torch.nn.Module): ...@@ -44,8 +44,10 @@ class NeuralFilter1P(torch.nn.Module):
self.input_size = input_size self.input_size = input_size
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.weight_ih = Parameter(torch.Tensor(hidden_size, input_size)) self.weight_in = Parameter(torch.Tensor(hidden_size, input_size))
self.bias_ih = Parameter(torch.Tensor(2 * hidden_size)) self.bias_in = Parameter(torch.Tensor(hidden_size))
self.bias_forget = Parameter(torch.Tensor(hidden_size))
self.reset_parameters() self.reset_parameters()
...@@ -54,29 +56,6 @@ class NeuralFilter1P(torch.nn.Module): ...@@ -54,29 +56,6 @@ class NeuralFilter1P(torch.nn.Module):
for weight in self.parameters(): for weight in self.parameters():
weight.data.uniform_(-stdv, stdv) weight.data.uniform_(-stdv, stdv)
def forward(self, input, hx=None):
if hx is None:
vhx = torch.autograd.Variable(input.data.new(input.size(1),
self.hidden_size
).zero_(), requires_grad=False)
hx = (vhx, vhx)
self.check_forward_input(input)
self.check_forward_hidden(input, hx[0], '[0]')
self.check_forward_hidden(input, hx[1], '[1]')
hidden = hx
output = []
steps = range(input.size(0))
for i in steps:
hidden = self.step(input[i], hidden)
output.append(hidden[0])
output = torch.cat(output, 0).view(input.size(0), *output[0].size())
return output, hidden
def __repr__(self): def __repr__(self):
s = '{name}({input_size}, {hidden_size}' s = '{name}({input_size}, {hidden_size}'
return s.format(name=self.__class__.__name__, **self.__dict__) return s.format(name=self.__class__.__name__, **self.__dict__)
...@@ -87,29 +66,41 @@ class NeuralFilter1P(torch.nn.Module): ...@@ -87,29 +66,41 @@ class NeuralFilter1P(torch.nn.Module):
"input has inconsistent input_size(-1): got {}, expected {}".format( "input has inconsistent input_size(-1): got {}, expected {}".format(
input.size(1), self.input_size)) input.size(1), self.input_size))
def check_forward_hidden(self, input, hx, hidden_label=''): def check_forward_hidden(self, input, hx):
if input.size(1) != hx.size(0): if input.size(1) != hx.size(0):
raise RuntimeError( raise RuntimeError(
"Input batch size {} doesn't match hidden{} batch size {}".format( "Input batch size {} doesn't match hidden batch size {}".format(
input.size(1), hidden_label, hx.size(0))) input.size(1), hx.size(0)))
if hx.size(1) != self.hidden_size: if hx.size(1) != self.hidden_size:
raise RuntimeError( raise RuntimeError(
"hidden{} has inconsistent hidden_size: got {}, expected {}".format( "hidden has inconsistent hidden_size: got {}, expected {}".format(
hidden_label, hx.size(1), self.hidden_size)) hx.size(1), self.hidden_size))
def step(self, input, hidden): def step(self, input, hidden):
hx, cx = hidden in_gate = F.linear(input, self.weight_in, self.bias_in)
forgetgate = F.sigmoid(self.bias_forget)
next = (forgetgate * hidden) + in_gate
return next
gates = F.linear(input, self.weight_ih, self.bias_ih) def forward(self, input, hx=None):
forgetgate, cellgate = gates.chunk(2, 1) if hx is None:
hx = torch.autograd.Variable(input.data.new(input.size(1),
self.hidden_size
).zero_(), requires_grad=False)
forgetgate = F.sigmoid(forgetgate) self.check_forward_input(input)
self.check_forward_hidden(input, hx)
cy = (forgetgate * cx) + cellgate hidden = hx
hy = cy
return hy, cy output = []
steps = range(input.size(0))
for i in steps:
hidden = self.step(input[i], hidden)
output.append(hidden)
output = torch.cat(output, 0).view(input.size(0), *output[0].size())
test = NeuralFilter1P(2, 2) return output, hidden
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment