"""
NeuralFilter2CC
***************

This module implements a trainable all-pole second-order filter with complex conjugate poles using PyTorch.


Copyright (c) 2018 Idiap Research Institute, http://www.idiap.ch/

Written by Francois Marelli <Francois.Marelli@idiap.ch>

This file is part of neural_filters.

neural_filters is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License version 3 as
published by the Free Software Foundation.

neural_filters is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with neural_filters. If not, see <http://www.gnu.org/licenses/>.

"""

import torch
from torch.nn import Parameter
from torch.nn import functional as F
import numpy as np


class NeuralFilter2CC(torch.nn.Module):
    """
    A trainable second-order all-pole filter :math:`\\frac{1}{1 - 2 P \\cos(\\theta) z^{-1} + P^{2} z^{-2}}`

    * **hidden_size** (int) - the size of the data vector
    """

    def __init__(self, hidden_size):
        super(NeuralFilter2CC, self).__init__()

        self.hidden_size = hidden_size

        self.bias_theta = Parameter(torch.Tensor(hidden_size))
        self.bias_modulus = Parameter(torch.Tensor(hidden_size))

        self.reset_parameters()

    def reset_parameters(self, init=None):
        if init is None:
            self.bias_modulus.data.uniform_(-0.2, 0.2)
            self.bias_theta.data.uniform_(-0.2, 0.2)
        else:
            if isinstance(init, tuple):
                self.bias_modulus.data.fill_(init[0])
                self.bias_theta.data.fill_(init[1])
            else:
                self.bias_theta.data.fill_(init)
                self.bias_modulus.data.fill_(init)

    def __repr__(self):
        s = '{name}({hidden_size})'
        return s.format(name=self.__class__.__name__, **self.__dict__)

    def check_forward_input(self, input_var):
        if input_var.size(-1) != self.hidden_size:
            raise RuntimeError(
                "input has inconsistent input_size(-1): got {}, expected {}".format(
71
                    input_var.size(1), self.hidden_size))

73 74
    def check_forward_hidden(self, input_var, hx):
        if input_var.size(1) != hx.size(0):
            raise RuntimeError(
                "Input batch size {} doesn't match hidden batch size {}".format(
                    input_var.size(1), hx.size(0)))

        if hx.size(1) != self.hidden_size:
            raise RuntimeError(
                "hidden has inconsistent hidden_size: got {}, expected {}".format(
                    hx.size(1), self.hidden_size))

    def step(self, input_var, delayed, delayed2, a=None, b=None):
        if a is None or b is None:
            modulus = F.sigmoid(self.bias_modulus)
            cosangle = F.tanh(self.bias_theta)
            a = 2 * cosangle * modulus
            b = - modulus.pow(2)

        # difference equation: y[n] = x[n] + 2 P cos(theta) y[n-1] - P^2 y[n-2]
        next_state = input_var + a * delayed + b * delayed2

        return next_state

    def forward(self, input_var, delayed=None, delayed2=None):
        if delayed is None:
            delayed = torch.autograd.Variable(input_var.data.new(input_var.size(1), self.hidden_size).zero_(),
                                              requires_grad=False)

        if delayed2 is None:
            delayed2 = torch.autograd.Variable(input_var.data.new(input_var.size(1), self.hidden_size).zero_(),
                                               requires_grad=False)

        self.check_forward_input(input_var)
        self.check_forward_hidden(input_var, delayed)
        self.check_forward_hidden(input_var, delayed2)

        d1 = delayed
        d2 = delayed2

        # compute the recurrence coefficients once instead of at every step, for efficiency
        modulus = F.sigmoid(self.bias_modulus)
        cosangle = F.tanh(self.bias_theta)
        a = 2 * cosangle * modulus
        b = - modulus.pow(2)

        output = []
        steps = range(input_var.size(0))
        for i in steps:
            next_state = self.step(input_var[i], d1, d2, a=a, b=b)
            output.append(next_state)
            d2, d1 = d1, next_state

        output = torch.cat(output, 0).view(input_var.size(0), *output[0].size())

        return output, d1, d2

    def print_param(self):
        modulus = F.sigmoid(self.bias_modulus)
        cosangle = F.tanh(self.bias_theta)
        p1 = -2 * cosangle * modulus
        p2 = modulus.pow(2)
        print('{}\t{}'.format(p1.data[0], p2.data[0]))

    @property
    def denominator(self):
        modulus = F.sigmoid(self.bias_modulus)
        cosangle = F.tanh(self.bias_theta)
        p1 = -2 * cosangle * modulus
        p2 = modulus.pow(2)
        p1 = p1.data.numpy()
        p2 = p2.data.numpy()
        p1 = p1.reshape(p1.size, 1)
        p2 = p2.reshape(p2.size, 1)
        one = np.ones(p1.shape)

        denom = np.concatenate((one, p1, p2), axis=1)
        return denom

    @property
    def gradients(self):
        mod_grad = self.bias_modulus.grad
        if mod_grad is not None:
            mod_grad = mod_grad.data.numpy()
            mod_grad = mod_grad.reshape(mod_grad.size, 1)
            cos_grad = self.bias_theta.grad.data.numpy()
            cos_grad = cos_grad.reshape(cos_grad.size, 1)
            return np.concatenate((mod_grad, cos_grad), axis=1)
        else:
            return np.zeros((self.hidden_size, 2))
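

# ---------------------------------------------------------------------------
# Minimal usage sketch: run a random sequence through the filter and inspect
# the learned denominator. The shapes follow the conventions of forward()
# (time, batch, hidden_size); the concrete sizes are arbitrary illustration
# values.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    seq_len, batch, hidden = 100, 4, 3
    filt = NeuralFilter2CC(hidden)

    # input sequence of shape (time, batch, hidden_size)
    x = torch.autograd.Variable(torch.randn(seq_len, batch, hidden))

    # output has the same shape; d1 and d2 are the last two filter states
    y, d1, d2 = filt(x)

    print(filt)              # NeuralFilter2CC(3)
    print(y.size())          # (seq_len, batch, hidden_size)
    print(filt.denominator)  # one row [1, -2*P*cos(theta), P**2] per channel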