Commit 86e3aa0c authored by M. François

PackedSequence compliance

parent b0ec956f
@@ -27,7 +27,6 @@ def atanh(x):
 from .log_loss import *
 from .neural_filter import *
-from .neural_filter_1L import *
 from .neural_filter_2CC import *
 from .neural_filter_2CD import *
 from .neural_filter_2R import *
@@ -25,12 +25,15 @@ along with neural_filters. If not, see <http://www.gnu.org/licenses/>.
 """

+import numpy as np
 import torch
 from torch.nn import Parameter
 from torch.nn import functional as F
-import numpy as np
+from torch.nn._functions.rnn import Recurrent, VariableRecurrent
+from torch.nn.utils.rnn import PackedSequence

-from . import INIT_MODULUS, asig
+from . import asig


 class NeuralFilter(torch.nn.Module):
     """
@@ -66,53 +69,63 @@ class NeuralFilter(torch.nn.Module):
         s = '{name}({hidden_size})'
         return s.format(name=self.__class__.__name__, **self.__dict__)

-    def check_forward_input(self, input_state):
-        if input_state.size(-1) != self.hidden_size:
+    def check_forward_args(self, input_var, hidden, batch_sizes):
+        is_input_packed = batch_sizes is not None
+        expected_input_dim = 2 if is_input_packed else 3
+        if input_var.dim() != expected_input_dim:
             raise RuntimeError(
-                "input has inconsistent input_size(-1): got {}, expected {}".format(
-                    input_state.size(1), self.hidden_size))
-
-    def check_forward_hidden(self, input_state, hx):
-        if input_state.size(1) != hx.size(0):
+                'input must have {} dimensions, got {}'.format(
+                    expected_input_dim, input_var.dim()))
+        if self.hidden_size != input_var.size(-1):
             raise RuntimeError(
-                "Input batch size {} doesn't match hidden batch size {}".format(
-                    input_state.size(1), hx.size(0)))
-        if hx.size(1) != self.hidden_size:
-            raise RuntimeError(
-                "hidden has inconsistent hidden_size: got {}, expected {}".format(
-                    hx.size(1), self.hidden_size))
+                'input.size(-1) must be equal to hidden_size. Expected {}, got {}'.format(
+                    self.input_size, input_var.size(-1)))
+
+        if is_input_packed:
+            mini_batch = int(batch_sizes[0])
+        else:
+            mini_batch = input_var.size(1)
+
+        expected_hidden_size = (mini_batch, self.hidden_size)
+
+        def check_hidden_size(hx, expected_hidden_size, msg='Expected hidden size {}, got {}'):
+            if tuple(hx.size()) != expected_hidden_size:
+                raise RuntimeError(msg.format(expected_hidden_size, tuple(hx.size())))
+
+        check_hidden_size(hidden, expected_hidden_size,
+                          'Expected hidden[0] size {}, got {}')

-    def step(self, input_state, hidden, a=None):
+    def step(self, input_var, hidden, a=None):
         if a is None:
             a = F.sigmoid(self.bias_forget)
-        next_state = (a * hidden) + input_state
+        next_state = (a * hidden) + input_var
         return next_state

-    def forward(self, input_state, hx=None):
-        if hx is None:
-            hx = torch.autograd.Variable(input_state.data.new(input_state.size(1),
-                                                              self.hidden_size
-                                                              ).zero_(), requires_grad=False)
-
-        self.check_forward_input(input_state)
-        self.check_forward_hidden(input_state, hx)
-
-        hidden = hx
+    def forward(self, input_var, hidden=None):
+        is_packed = isinstance(input_var, PackedSequence)
+        if is_packed:
+            input_var, batch_sizes = input_var
+            max_batch_size = int(batch_sizes[0])
+        else:
+            batch_sizes = None
+            max_batch_size = input_var.size(1)
+
+        if hidden is None:
+            hidden = input_var.data.new_zeros(max_batch_size, self.hidden_size, requires_grad=False)
+
+        self.check_forward_args(input_var, hidden, batch_sizes)

         # compute this once for all steps for efficiency
         a = F.sigmoid(self.bias_forget)

-        output = []
-        steps = range(input_state.size(0))
-        for i in steps:
-            hidden = self.step(input_state[i], hidden, a=a)
-            output.append(hidden)
-
-        output = torch.cat(output, 0).view(input_state.size(0), *output[0].size())
+        func = Recurrent(self.step) if batch_sizes is None else VariableRecurrent(self.step)
+        nexth, output = func(input_var, hidden, (a,), batch_sizes)
+
+        if is_packed:
+            output = PackedSequence(output, batch_sizes)

-        return output, hidden
+        return output, nexth

     @property
     def gradients(self):
...
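For context, step() above computes the first-order recurrence h[t] = sigmoid(bias_forget) * h[t-1] + x[t]; the rewritten forward() only changes how that recurrence is driven (Recurrent for padded tensors, VariableRecurrent for packed ones). Below is a minimal usage sketch of the packed path. It is not part of the commit and assumes the module is constructed as NeuralFilter(hidden_size) and a PyTorch version that still ships torch.nn._functions.rnn (around 0.4).

    import torch
    from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

    from neural_filters import NeuralFilter

    hidden_size = 4
    filt = NeuralFilter(hidden_size)

    # two sequences of lengths 5 and 3, zero-padded to shape (seq_len, batch, hidden_size)
    padded = torch.randn(5, 2, hidden_size)
    packed = pack_padded_sequence(padded, lengths=[5, 3])

    out_packed, last_hidden = filt(packed)           # out_packed is a PackedSequence
    out_padded, _ = pad_packed_sequence(out_packed)  # back to a (5, 2, hidden_size) tensor

    # the un-packed path keeps working as before
    out_plain, last_plain = filt(padded)

With packed input, VariableRecurrent stops updating a sequence's state once its length is exhausted, so the returned hidden state corresponds to each sequence's true last step rather than the padded end. The rest of the commit removes the NeuralFilter1L module (shown next) and applies the same packing changes to NeuralFilter2CC.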
"""
NeuralFilter1P
**************
This module implements a trainable all-pole first order with linear combination input filter using pyTorch
Copyright (c) 2018 Idiap Research Institute, http://www.idiap.ch/
Written by Francois Marelli <Francois.Marelli@idiap.ch>
This file is part of neural_filters.
neural_filters is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License version 3 as
published by the Free Software Foundation.
neural_filters is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with neural_filters. If not, see <http://www.gnu.org/licenses/>.
"""
import torch
from torch.nn import Parameter
from torch.nn import functional as F
from . import NeuralFilter
class NeuralFilter1L(NeuralFilter):
"""
A trainable first-order all-pole filter :math:`\\frac{K}{1 - P z^{-1}}` with bias on the input
* **input_size** (int) - the size of the input vector
* **hidden_size** (int) - the size of the output vector
"""
def __init__(self, input_size, hidden_size):
super(NeuralFilter1L, self).__init__(hidden_size)
self.input_size = input_size
self.weight_in = Parameter(torch.Tensor(hidden_size, input_size))
self.bias_in = Parameter(torch.Tensor(hidden_size))
self.reset_parameters()
def __repr__(self):
s = '{name}({input_size},{hidden_size})'
return s.format(name=self.__class__.__name__, **self.__dict__)
def reset_parameters(self, init=None):
super(NeuralFilter1L, self).reset_parameters(init)
def check_forward_input(self, input_state):
if input_state.size(-1) != self.input_size:
raise RuntimeError(
"input has inconsistent input_size(-1): got {}, expected {}".format(
input_state.size(1), self.input_size))
def step(self, input_state, hidden):
in_gate = F.linear(input_state, self.weight_in, self.bias_in)
next_state = super(NeuralFilter1L, self).step(in_gate, hidden)
return next_state
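The block above is the NeuralFilter1L module whose import is dropped from the package __init__ in this commit: a linear input layer feeding the parent's first-order all-pole update. For reference only, a sketch of the recurrence it implemented, using hypothetical stand-in names (W, bias, pole) for weight_in, bias_in and sigmoid(bias_forget); this helper is not part of the commit.

    import torch
    from torch.nn import functional as F

    def filter_1l_reference(x, W, bias, pole):
        """x: (seq_len, batch, input_size) -> (seq_len, batch, hidden_size)."""
        h = x.new_zeros(x.size(1), W.size(0))
        out = []
        for t in range(x.size(0)):
            # linear input combination, then the first-order all-pole update
            h = pole * h + F.linear(x[t], W, bias)
            out.append(h)
        return torch.stack(out)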
@@ -29,6 +29,8 @@ import numpy as np
 import torch
 from torch.nn import Parameter
 from torch.nn import functional as F
+from torch.nn._functions.rnn import Recurrent, VariableRecurrent
+from torch.nn.utils.rnn import PackedSequence

 from . import MIN_ANGLE, MAX_ANGLE, INIT_MODULUS, asig, atanh
@@ -82,22 +84,33 @@ class NeuralFilter2CC(torch.nn.Module):
         s = '{name}({hidden_size})'
         return s.format(name=self.__class__.__name__, **self.__dict__)

-    def check_forward_input(self, input_var):
-        if input_var.size(-1) != self.hidden_size:
+    def check_forward_args(self, input_var, hidden, batch_sizes):
+        is_input_packed = batch_sizes is not None
+        expected_input_dim = 2 if is_input_packed else 3
+        if input_var.dim() != expected_input_dim:
             raise RuntimeError(
-                "input has inconsistent input_size(-1): got {}, expected {}".format(
-                    input_var.size(1), self.hidden_size))
-
-    def check_forward_hidden(self, input_var, hx):
-        if input_var.size(1) != hx.size(0):
+                'input must have {} dimensions, got {}'.format(
+                    expected_input_dim, input_var.dim()))
+        if self.hidden_size != input_var.size(-1):
             raise RuntimeError(
-                "Input batch size {} doesn't match hidden batch size {}".format(
-                    input_var.size(1), hx.size(0)))
-        if hx.size(1) != self.hidden_size:
-            raise RuntimeError(
-                "hidden has inconsistent hidden_size: got {}, expected {}".format(
-                    hx.size(1), self.hidden_size))
+                'input.size(-1) must be equal to hidden_size. Expected {}, got {}'.format(
+                    self.input_size, input_var.size(-1)))
+
+        if is_input_packed:
+            mini_batch = int(batch_sizes[0])
+        else:
+            mini_batch = input_var.size(1)
+
+        expected_hidden_size = (mini_batch, self.hidden_size)
+
+        def check_hidden_size(hx, expected_hidden_size, msg='Expected hidden size {}, got {}'):
+            if tuple(hx.size()) != expected_hidden_size:
+                raise RuntimeError(msg.format(expected_hidden_size, tuple(hx.size())))
+
+        check_hidden_size(hidden[0], expected_hidden_size,
+                          'Expected hidden[0] size {}, got {}')
+        check_hidden_size(hidden[1], expected_hidden_size,
+                          'Expected hidden[1] size {}, got {}')

     def step(self, input_var, hidden, a=None, b=None):
         if a is None or b is None:
@@ -108,21 +121,23 @@ class NeuralFilter2CC(torch.nn.Module):
         next_state = input_var + a * hidden[0] + b * hidden[1]
-        return next_state
+        return next_state, hidden[0]

-    def forward(self, input_var, hidden=(None, None)):
-        h0, h1 = hidden
-        if h0 is None:
-            h0 = torch.autograd.Variable(input_var.data.new(input_var.size(1), self.hidden_size).zero_(),
-                                         requires_grad=False)
-        if h1 is None:
-            h1 = torch.autograd.Variable(input_var.data.new(input_var.size(1), self.hidden_size).zero_(),
-                                         requires_grad=False)
-
-        self.check_forward_input(input_var)
-        self.check_forward_hidden(input_var, h0)
-        self.check_forward_hidden(input_var, h1)
+    def forward(self, input_var, hidden=None):
+        is_packed = isinstance(input_var, PackedSequence)
+        if is_packed:
+            input_var, batch_sizes = input_var
+            max_batch_size = int(batch_sizes[0])
+        else:
+            batch_sizes = None
+            max_batch_size = input_var.size(1)
+
+        if hidden is None:
+            h = input_var.new_zeros(max_batch_size, self.hidden_size, requires_grad=False)
+            hidden = (h, h)
+
+        self.check_forward_args(input_var, hidden, batch_sizes)

         # do not recompute this at each step to gain efficiency
         modulus = F.sigmoid(self.bias_modulus)
@@ -130,16 +145,13 @@ class NeuralFilter2CC(torch.nn.Module):
         a = 2 * cosangle * modulus
         b = - modulus.pow(2)

-        output = []
-        steps = range(input_var.size(0))
-        for i in steps:
-            next_state = self.step(input_var[i], (h0, h1), a=a, b=b)
-            output.append(next_state)
-            h1, h0 = h0, next_state
-
-        output = torch.cat(output, 0).view(input_var.size(0), *output[0].size())
+        func = Recurrent(self.step) if batch_sizes is None else VariableRecurrent(self.step)
+        nexth, output = func(input_var, hidden, (a, b), batch_sizes)
+
+        if is_packed:
+            output = PackedSequence(output, batch_sizes)

-        return output, (h0, h1)
+        return output, nexth

     def print_param(self):
         modulus = F.sigmoid(self.bias_modulus)
...
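For reference, NeuralFilter2CC's step() unrolls a second-order recurrence with a complex-conjugate pole pair at modulus * exp(+/- i * angle): y[t] = x[t] + a*y[t-1] + b*y[t-2], with a = 2*cos(angle)*modulus and b = -modulus^2, which is exactly what the old Python loop computed and what Recurrent / VariableRecurrent now drive. A minimal sketch of that recurrence for fixed scalar modulus and angle (stand-ins for sigmoid(bias_modulus) and the learned angle parameter), not part of the commit:

    import math
    import torch

    def filter_2cc_reference(x, modulus, angle):
        """x: (seq_len, batch, hidden_size); y[t] = x[t] + a*y[t-1] + b*y[t-2]."""
        a = 2.0 * math.cos(angle) * modulus
        b = -modulus ** 2
        h0 = torch.zeros_like(x[0])  # y[t-1]
        h1 = torch.zeros_like(x[0])  # y[t-2]
        out = []
        for t in range(x.size(0)):
            nxt = x[t] + a * h0 + b * h1
            out.append(nxt)
            h1, h0 = h0, nxt
        return torch.stack(out)

For scalar parameters this should match an all-pole resonator such as scipy.signal.lfilter([1.0], [1.0, -a, -b], x, axis=0); the commit changes only how the recurrence is stepped over packed or padded batches, not the filter itself.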