Commit 7facd573 authored by Francois Marelli

Random init + efficiency computation

parent 6322cf95
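This commit does two things throughout the cells below: biases are initialised with small uniform noise instead of zeros, and the filter coefficients derived from them are computed once per forward pass rather than once per time step. As a rough sanity check (a plain PyTorch snippet, not part of the repository, and assuming a PyTorch version with torch.tensor), the new uniform_(-0.2, 0.2) initialisation maps through the sigmoid used in step() to forget-gate values of roughly 0.45 to 0.55, so the poles start near 0.5 without being identical across units:

    import torch

    # biases drawn in (-0.2, 0.2) land in roughly (0.45, 0.55) after the sigmoid
    print(torch.sigmoid(torch.tensor([-0.2, 0.0, 0.2])))
    # tensor([0.4502, 0.5000, 0.5498])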
@@ -49,7 +49,7 @@ class NeuralFilter(torch.nn.Module):
     def reset_parameters(self, init=None):
         if init is None:
-            self.bias_forget.data.zero_()
+            self.bias_forget.data.uniform_(-0.2, 0.2)
         else:
             self.bias_forget.data.fill_(init)

@@ -57,46 +57,51 @@ class NeuralFilter(torch.nn.Module):
         s = '{name}({hidden_size})'
         return s.format(name=self.__class__.__name__, **self.__dict__)

-    def check_forward_input(self, input):
-        if input.size(-1) != self.hidden_size:
+    def check_forward_input(self, input_state):
+        if input_state.size(-1) != self.hidden_size:
             raise RuntimeError(
                 "input has inconsistent input_size(-1): got {}, expected {}".format(
-                    input.size(1), self.hidden_size))
+                    input_state.size(1), self.hidden_size))

-    def check_forward_hidden(self, input, hx):
-        if input.size(1) != hx.size(0):
+    def check_forward_hidden(self, input_state, hx):
+        if input_state.size(1) != hx.size(0):
             raise RuntimeError(
                 "Input batch size {} doesn't match hidden batch size {}".format(
-                    input.size(1), hx.size(0)))
+                    input_state.size(1), hx.size(0)))

         if hx.size(1) != self.hidden_size:
             raise RuntimeError(
                 "hidden has inconsistent hidden_size: got {}, expected {}".format(
                     hx.size(1), self.hidden_size))

-    def step(self, input, hidden):
-        forgetgate = F.sigmoid(self.bias_forget)
-        next = (forgetgate * hidden) + input
-        return next
+    def step(self, input_state, hidden, a=None):
+        if a is None:
+            a = F.sigmoid(self.bias_forget)
+
+        next_state = (a * hidden) + input_state
+        return next_state

-    def forward(self, input, hx=None):
+    def forward(self, input_state, hx=None):
         if hx is None:
-            hx = torch.autograd.Variable(input.data.new(input.size(1),
-                                                        self.hidden_size
-                                                        ).zero_(), requires_grad=False)
+            hx = torch.autograd.Variable(input_state.data.new(input_state.size(1),
+                                                              self.hidden_size
+                                                              ).zero_(), requires_grad=False)

-        self.check_forward_input(input)
-        self.check_forward_hidden(input, hx)
+        self.check_forward_input(input_state)
+        self.check_forward_hidden(input_state, hx)

         hidden = hx

+        # compute this once for all steps for efficiency
+        a = F.sigmoid(self.bias_forget)
+
         output = []
-        steps = range(input.size(0))
+        steps = range(input_state.size(0))
         for i in steps:
-            hidden = self.step(input[i], hidden)
+            hidden = self.step(input_state[i], hidden, a=a)
             output.append(hidden)

-        output = torch.cat(output, 0).view(input.size(0), *output[0].size())
+        output = torch.cat(output, 0).view(input_state.size(0), *output[0].size())

         return output, hidden
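For context, here is a minimal usage sketch of the cell above (a hypothetical script: it assumes the package imports as neural_filters and the Variable-era PyTorch API used in this file). The input is laid out as (time, batch, hidden_size):

    import torch
    from torch.autograd import Variable

    from neural_filters import NeuralFilter  # assumed import path

    seq_len, batch, hidden_size = 50, 4, 8
    cell = NeuralFilter(hidden_size)

    x = Variable(torch.randn(seq_len, batch, hidden_size))  # (time, batch, features)
    output, last_hidden = cell(x)                           # hx defaults to zeros
    print(output.size(), last_hidden.size())                # (50, 4, 8) and (4, 8)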
@@ -28,9 +28,9 @@ along with neural_filters. If not, see <http://www.gnu.org/licenses/>.

 import torch
 from torch.nn import Parameter
 from torch.nn import functional as F
-import math

 from . import NeuralFilter


 class NeuralFilter1L(NeuralFilter):
     """
     A trainable first-order all-pole filter :math:`\\frac{K}{1 - P z^{-1}}` with bias on the input

@@ -54,19 +54,15 @@ class NeuralFilter1L(NeuralFilter):
         return s.format(name=self.__class__.__name__, **self.__dict__)

     def reset_parameters(self, init=None):
-        stdv = 1.0 / math.sqrt(self.hidden_size)
-        for weight in self.parameters():
-            weight.data.uniform_(-stdv, stdv)
         super(NeuralFilter1L, self).reset_parameters(init)

-    def check_forward_input(self, input):
-        if input.size(-1) != self.input_size:
+    def check_forward_input(self, input_state):
+        if input_state.size(-1) != self.input_size:
             raise RuntimeError(
                 "input has inconsistent input_size(-1): got {}, expected {}".format(
-                    input.size(1), self.input_size))
+                    input_state.size(1), self.input_size))

-    def step(self, input, hidden):
-        in_gate = F.linear(input, self.weight_in, self.bias_in)
-        next = super(NeuralFilter1L, self).step(in_gate, hidden)
-        return next
+    def step(self, input_state, hidden):
+        in_gate = F.linear(input_state, self.weight_in, self.bias_in)
+        next_state = super(NeuralFilter1L, self).step(in_gate, hidden)
+        return next_state
@@ -50,13 +50,12 @@ class NeuralFilter2CC(torch.nn.Module):
     def reset_parameters(self, init=None):
         if init is None:
-            self.bias_modulus.data.zero_()
-            self.bias_theta.data.zero_()
+            self.bias_modulus.data.uniform_(-0.2, 0.2)
+            self.bias_theta.data.uniform_(-0.2, 0.2)
         else:
             if isinstance(init, tuple):
                 self.bias_modulus.data.fill_(init[0])
                 self.bias_theta.data.fill_(init[1])
             else:
                 self.bias_theta.data.fill_(init)
                 self.bias_modulus.data.fill_(init)

@@ -65,58 +64,64 @@ class NeuralFilter2CC(torch.nn.Module):
         s = '{name}({hidden_size})'
         return s.format(name=self.__class__.__name__, **self.__dict__)

-    def check_forward_input(self, input):
-        if input.size(-1) != self.hidden_size:
+    def check_forward_input(self, input_var):
+        if input_var.size(-1) != self.hidden_size:
             raise RuntimeError(
                 "input has inconsistent input_size(-1): got {}, expected {}".format(
-                    input.size(1), self.hidden_size))
+                    input_var.size(1), self.hidden_size))

-    def check_forward_hidden(self, input, hx):
-        if input.size(1) != hx.size(0):
+    def check_forward_hidden(self, input_var, hx):
+        if input_var.size(1) != hx.size(0):
             raise RuntimeError(
                 "Input batch size {} doesn't match hidden batch size {}".format(
-                    input.size(1), hx.size(0)))
+                    input_var.size(1), hx.size(0)))

         if hx.size(1) != self.hidden_size:
             raise RuntimeError(
                 "hidden has inconsistent hidden_size: got {}, expected {}".format(
                     hx.size(1), self.hidden_size))

-    def step(self, input, delayed, delayed2):
-        modulus = F.sigmoid(self.bias_modulus)
-        cosangle = F.tanh(self.bias_theta)
-        next = input + 2 * cosangle * modulus * delayed - modulus.pow(2) * delayed2
-        return next
+    def step(self, input_var, delayed, delayed2, a=None, b=None):
+        if a is None or b is None:
+            modulus = F.sigmoid(self.bias_modulus)
+            cosangle = F.tanh(self.bias_theta)
+            a = 2 * cosangle * modulus
+            b = - modulus.pow(2)
+
+        next_state = input_var + a * delayed + b * delayed2
+        return next_state

-    def forward(self, input, delayed=None, delayed2=None):
+    def forward(self, input_var, delayed=None, delayed2=None):
         if delayed is None:
-            delayed = torch.autograd.Variable(input.data.new(input.size(1),
-                                                             self.hidden_size
-                                                             ).zero_(), requires_grad=False)
+            delayed = torch.autograd.Variable(input_var.data.new(input_var.size(1), self.hidden_size).zero_(),
+                                              requires_grad=False)

         if delayed2 is None:
-            delayed2 = torch.autograd.Variable(input.data.new(input.size(1),
-                                                              self.hidden_size
-                                                              ).zero_(), requires_grad=False)
+            delayed2 = torch.autograd.Variable(input_var.data.new(input_var.size(1), self.hidden_size).zero_(),
+                                               requires_grad=False)

-        self.check_forward_input(input)
-        self.check_forward_hidden(input, delayed)
-        self.check_forward_hidden(input, delayed2)
+        self.check_forward_input(input_var)
+        self.check_forward_hidden(input_var, delayed)
+        self.check_forward_hidden(input_var, delayed2)

         d1 = delayed
         d2 = delayed2

+        # do not recompute this at each step to gain efficiency
+        modulus = F.sigmoid(self.bias_modulus)
+        cosangle = F.tanh(self.bias_theta)
+        a = 2 * cosangle * modulus
+        b = - modulus.pow(2)
+
         output = []
-        steps = range(input.size(0))
+        steps = range(input_var.size(0))
         for i in steps:
-            next = self.step(input[i], d1, d2)
-            output.append(next)
-            d2, d1 = d1, next
+            next_state = self.step(input_var[i], d1, d2, a=a, b=b)
+            output.append(next_state)
+            d2, d1 = d1, next_state

-        output = torch.cat(output, 0).view(input.size(0), *output[0].size())
+        output = torch.cat(output, 0).view(input_var.size(0), *output[0].size())

         return output, d1, d2
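For reference, writing P = sigmoid(bias_modulus) and cos(theta) = tanh(bias_theta), the precomputed coefficients a = 2 P cos(theta) and b = -P^2 make each step compute the recursion

    y[t] = x[t] + 2 P cos(theta) * y[t-1] - P^2 * y[t-2]

i.e. the all-pole filter :math:`\frac{1}{1 - 2 P \cos(\theta) z^{-1} + P^{2} z^{-2}}` whose poles are the complex-conjugate pair :math:`P e^{\pm j \theta}`; this is consistent with the `denominator` property further down.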
@@ -135,7 +140,7 @@ class NeuralFilter2CC(torch.nn.Module):
         p2 = modulus.pow(2)

         p1 = p1.data.numpy()
         p2 = p2.data.numpy()

-        p1 = p1.reshape(p1.size,1)
+        p1 = p1.reshape(p1.size, 1)
         p2 = p2.reshape(p2.size, 1)

         one = np.ones(p1.shape)
@@ -30,7 +30,8 @@ from . import NeuralFilter

 import torch
 import numpy as np

-class NeuralFilter2CD (torch.nn.Module):
+
+class NeuralFilter2CD(torch.nn.Module):
     """
     A trainable second-order critically damped all-pole filter :math:`\\frac{1}{(1 - P z^{-1})^{2}}`

@@ -51,14 +52,14 @@ class NeuralFilter2CD (torch.nn.Module):
         s = '{name}({hidden_size})'
         return s.format(name=self.__class__.__name__, **self.__dict__)

-    def forward(self, input, hx=None):
+    def forward(self, input_var, hx=None):
         if hx is None:
-            hx = torch.autograd.Variable(input.data.new(input.size(1),
-                                                        self.hidden_size
-                                                        ).zero_(), requires_grad=False)
+            hx = torch.autograd.Variable(input_var.data.new(input_var.size(1),
+                                                            self.hidden_size
+                                                            ).zero_(), requires_grad=False)

-        interm, interm_hidden = self.cell(input, hx)
-        output, hidden = self.cell(interm)
+        inter, inter_hidden = self.cell(input_var, hx)
+        output, hidden = self.cell(inter)

         return output, hidden
@@ -30,7 +30,8 @@ from . import NeuralFilter

 import torch
 import numpy as np

-class NeuralFilter2R (torch.nn.Module):
+
+class NeuralFilter2R(torch.nn.Module):
     """
     A trainable second-order all-(real)pole filter :math:`\\frac{1}{1 - P_{1} z^{-1}} \\frac{1}{1 - P_{2} z^{-1}}`

@@ -45,10 +46,13 @@ class NeuralFilter2R (torch.nn.Module):
         self.first_cell = NeuralFilter(self.hidden_size)
         self.second_cell = NeuralFilter(self.hidden_size)

-        self.reset_parameters((-0.5, 0.5))
+        self.reset_parameters()

     def reset_parameters(self, init=None):
-        if isinstance(init, tuple):
+        if init is None:
+            self.first_cell.bias_forget.data.uniform_(-0.7, -0.3)
+            self.second_cell.bias_forget.data.uniform_(0.3, 0.7)
+        elif isinstance(init, tuple):
             self.first_cell.reset_parameters(init[0])
             self.second_cell.reset_parameters(init[1])
         else:

@@ -59,13 +63,13 @@ class NeuralFilter2R (torch.nn.Module):
         s = '{name}({hidden_size})'
         return s.format(name=self.__class__.__name__, **self.__dict__)

-    def forward(self, input, hx=None):
+    def forward(self, input_var, hx=None):
         if hx is None:
-            hx = torch.autograd.Variable(input.data.new(input.size(1),
-                                                        self.hidden_size
-                                                        ).zero_(), requires_grad=False)
+            hx = torch.autograd.Variable(input_var.data.new(input_var.size(1),
+                                                            self.hidden_size
+                                                            ).zero_(), requires_grad=False)

-        interm, interm_hidden = self.first_cell(input, hx)
+        interm, interm_hidden = self.first_cell(input_var, hx)
         output, hidden = self.second_cell(interm)

         return output, hidden

@@ -74,7 +78,7 @@ class NeuralFilter2R (torch.nn.Module):
     def denominator(self):
         first = self.first_cell.denominator
         second = self.second_cell.denominator

-        denom = np.zeros((first.shape[0],3))
+        denom = np.zeros((first.shape[0], 3))
         for i in range(self.hidden_size):
             denom[i] = np.polymul(first[i], second[i])

         return denom
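To illustrate the denominator property just above, a standalone NumPy sketch (made-up pole values; it assumes each first-order cell reports its denominator as [1, -P]): np.polymul combines the two first-order polynomials into the second-order denominator stored in denom:

    import numpy as np

    first = np.array([1.0, -0.4])     # 1 - 0.4 z^-1
    second = np.array([1.0, -0.6])    # 1 - 0.6 z^-1
    print(np.polymul(first, second))  # [ 1.   -1.    0.24]  i.e. 1 - z^-1 + 0.24 z^-2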
@@ -83,4 +87,4 @@ class NeuralFilter2R (torch.nn.Module):
     def gradients(self):
         first = self.first_cell.gradients
         second = self.second_cell.gradients

-        return np.concatenate((first, second), axis=1)
\ No newline at end of file
+        return np.concatenate((first, second), axis=1)