"""
NeuralFilter2CC
***************

This module implements a trainable all-pole second order filter with complex conjugate poles using pyTorch


Copyright (c) 2018 Idiap Research Institute, http://www.idiap.ch/

Written by Francois Marelli <Francois.Marelli@idiap.ch>

This file is part of neural_filters.

neural_filters is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License version 3 as
published by the Free Software Foundation.

neural_filters is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with neural_filters. If not, see <http://www.gnu.org/licenses/>.

"""

import numpy as np
import torch
from torch.nn import Parameter
from torch.nn import functional as F

from . import MIN_ANGLE, MAX_ANGLE, INIT_MODULUS, asig, atanh


class NeuralFilter2CC(torch.nn.Module):
    """
    A trainable second-order all-pole filter :math:`\\frac{1}{1 - 2 P \\cos(\\theta) z^{-1} + P^{2} z^{-2}}`

    * **hidden_size** (int) - the size of the data vector
    """

    def __init__(self, hidden_size):
        super(NeuralFilter2CC, self).__init__()

        self.hidden_size = hidden_size

        self.bias_theta = Parameter(torch.Tensor(hidden_size))
        self.bias_modulus = Parameter(torch.Tensor(hidden_size))

        self.reset_parameters()

    def reset_parameters(self, init_modulus=None, init_theta=None,
                         min_angle=MIN_ANGLE, max_angle=MAX_ANGLE, modulus=INIT_MODULUS):

        if init_modulus is None:
            init_modulus = asig(modulus)

        if not isinstance(init_modulus, np.ndarray):
            init_modulus = np.array(init_modulus, ndmin=1)

        ten_init = torch.from_numpy(init_modulus)
        self.bias_modulus.data.copy_(ten_init)

        if init_theta is None:
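            # Spread the initial pole angles evenly over (min_angle, max_angle):
            # one angle per hidden unit, at odd multiples of the half step, then
            # store cos(angle) through the atanh parameterisation used in forward().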
            parts = self.hidden_size * 2
            ranges = np.arange(1, parts, 2)

            init_angle = ranges * (max_angle - min_angle) / parts + min_angle

            cosangle = np.cos(init_angle)
            init_theta = atanh(cosangle)

        if not isinstance(init_theta, np.ndarray):
            init_theta = np.array(init_theta, ndmin=1)

        ten_init = torch.from_numpy(init_theta)
        self.bias_theta.data.copy_(ten_init)

    def __repr__(self):
        s = '{name}({hidden_size})'
        return s.format(name=self.__class__.__name__, **self.__dict__)

    def check_forward_input(self, input_var):
        if input_var.size(-1) != self.hidden_size:
            raise RuntimeError(
                "input has inconsistent input_size(-1): got {}, expected {}".format(
                    input_var.size(-1), self.hidden_size))

    def check_forward_hidden(self, input_var, hx):
        if input_var.size(1) != hx.size(0):
            raise RuntimeError(
                "Input batch size {} doesn't match hidden batch size {}".format(
                    input_var.size(1), hx.size(0)))

        if hx.size(1) != self.hidden_size:
            raise RuntimeError(
                "hidden has inconsistent hidden_size: got {}, expected {}".format(
                    hx.size(1), self.hidden_size))

    def step(self, input_var, delayed, delayed2, a=None, b=None):
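        # One recurrence step: y[n] = x[n] + a * y[n-1] + b * y[n-2], with
        # a = 2 * P * cos(theta) and b = -P**2, where P = sigmoid(bias_modulus)
        # and cos(theta) = tanh(bias_theta) keep the poles inside the unit circle.
        # Precomputed a, b can be passed in (as forward() does) to avoid
        # recomputing them at every time step.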
        if a is None or b is None:
            modulus = F.sigmoid(self.bias_modulus)
            cosangle = F.tanh(self.bias_theta)
            a = 2 * cosangle * modulus
            b = - modulus.pow(2)

        next_state = input_var + a * delayed + b * delayed2

        return next_state

    def forward(self, input_var, delayed=None, delayed2=None):
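        # input_var is sequence-first: (seq_len, batch, hidden_size).
        # delayed and delayed2 carry y[n-1] and y[n-2] across calls; they
        # default to zero initial conditions.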
        if delayed is None:
            delayed = torch.autograd.Variable(input_var.data.new(input_var.size(1), self.hidden_size).zero_(),
                                              requires_grad=False)

        if delayed2 is None:
            delayed2 = torch.autograd.Variable(input_var.data.new(input_var.size(1), self.hidden_size).zero_(),
                                               requires_grad=False)

        self.check_forward_input(input_var)
        self.check_forward_hidden(input_var, delayed)
        self.check_forward_hidden(input_var, delayed2)

        d1 = delayed
        d2 = delayed2

        # do not recompute this at each step to gain efficiency
        modulus = F.sigmoid(self.bias_modulus)
        cosangle = F.tanh(self.bias_theta)
        a = 2 * cosangle * modulus
        b = - modulus.pow(2)

        output = []
        steps = range(input_var.size(0))
        for i in steps:
            next_state = self.step(input_var[i], d1, d2, a=a, b=b)
            output.append(next_state)
            d2, d1 = d1, next_state

        output = torch.cat(output, 0).view(input_var.size(0), *output[0].size())

        return output, d1, d2

    def print_param(self):
        modulus = F.sigmoid(self.bias_modulus)
        cosangle = F.tanh(self.bias_theta)
        p1 = -2 * cosangle * modulus
        p2 = modulus.pow(2)
        print('{}\t{}'.format(p1.data[0], p2.data[0]))

    @property
    def denominator(self):
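        # Each row is [1, -2 * P * cos(theta), P**2]: the denominator
        # coefficients of one hidden channel's transfer function.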
        modulus = F.sigmoid(self.bias_modulus)
        cosangle = F.tanh(self.bias_theta)
        p1 = -2 * cosangle * modulus
        p2 = modulus.pow(2)
        p1 = p1.data.numpy()
        p2 = p2.data.numpy()
        p1 = p1.reshape(p1.size, 1)
        p2 = p2.reshape(p2.size, 1)
        one = np.ones(p1.shape)

        denom = np.concatenate((one, p1, p2), axis=1)
        return denom

    @property
    def gradients(self):
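        # (hidden_size, 2) array of the gradients w.r.t. bias_modulus and
        # bias_theta, or zeros if no backward pass has been run yet.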
        mod_grad = self.bias_modulus.grad
        if mod_grad is not None:
            mod_grad = mod_grad.data.numpy()
            mod_grad = mod_grad.reshape(mod_grad.size, 1)
            cos_grad = self.bias_theta.grad.data.numpy()
            cos_grad = cos_grad.reshape(cos_grad.size, 1)
            return np.concatenate((mod_grad, cos_grad), axis=1)
        else:
            return np.zeros((self.hidden_size, 2))
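

# Minimal usage sketch (illustrative only, not part of the original module):
# run a random sequence through a 4-channel NeuralFilter2CC and inspect the
# learned denominator coefficients. The shapes below are arbitrary assumptions.
if __name__ == '__main__':
    seq_len, batch, hidden = 100, 8, 4
    filt = NeuralFilter2CC(hidden)

    # Sequence-first input: (seq_len, batch, hidden_size)
    x = torch.autograd.Variable(torch.randn(seq_len, batch, hidden))

    output, last, second_last = filt(x)
    print(output.size())     # (seq_len, batch, hidden_size)
    print(filt.denominator)  # one row [1, -2*P*cos(theta), P**2] per channel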