"""
NeuralFilter2CC
***************

This module implements a trainable all-pole second-order filter with complex conjugate poles using PyTorch.


Copyright (c) 2018 Idiap Research Institute, http://www.idiap.ch/

Written by Francois Marelli <Francois.Marelli@idiap.ch>

This file is part of neural_filters.

neural_filters is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License version 3 as
published by the Free Software Foundation.

neural_filters is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with neural_filters. If not, see <http://www.gnu.org/licenses/>.

"""

import numpy as np
import torch
from torch.nn import Parameter
from torch.nn import functional as F
from torch.nn._functions.rnn import Recurrent, VariableRecurrent
from torch.nn.utils.rnn import PackedSequence

from . import MIN_ANGLE, MAX_ANGLE, INIT_MODULUS, asig, atanh


class NeuralFilter2CC(torch.nn.Module):
    """
    A trainable second-order all-pole filter :math:`\\frac{1}{1 - 2 P \\cos(\\theta) z^{-1} + P^{2} z^{-2}}`

    * **hidden_size** (int) - the size of the data vector
    """

    def __init__(self, hidden_size):
        super(NeuralFilter2CC, self).__init__()

        self.hidden_size = hidden_size

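        # unconstrained parameters: the pole modulus is obtained as sigmoid(bias_modulus)
        # and cos(theta) as tanh(bias_theta), which keeps the poles strictly inside the unit circle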
        self.bias_theta = Parameter(torch.Tensor(hidden_size))
        self.bias_modulus = Parameter(torch.Tensor(hidden_size))

        self.reset_parameters()

    def reset_parameters(self, init_modulus=None, init_theta=None,
                         min_angle=MIN_ANGLE, max_angle=MAX_ANGLE, modulus=INIT_MODULUS):

        if init_modulus is None:
            init_modulus = modulus

        if not isinstance(init_modulus, np.ndarray):
            init_modulus = np.array(init_modulus, ndmin=1)

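        # asig is assumed to invert the sigmoid applied in step()/forward(),
        # so that sigmoid(bias_modulus) equals the requested initial modulus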
        init_mod = asig(init_modulus)
        ten_init = torch.from_numpy(init_mod)
        self.bias_modulus.data.copy_(ten_init)

        if init_theta is None:
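            # spread the initial pole angles evenly over (min_angle, max_angle),
            # one per channel, at the centres of hidden_size equal sub-intervals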
            parts = self.hidden_size * 2
            ranges = np.arange(1, parts, 2)

            init_theta = ranges * (max_angle - min_angle) / parts + min_angle

        if not isinstance(init_theta, np.ndarray):
            init_theta = np.array(init_theta, ndmin=1)

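        # atanh inverts the tanh of the forward pass, so that
        # tanh(bias_theta) equals cos(init_theta) after initialization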
        cosangle = np.cos(init_theta)
        init_angle = atanh(cosangle)

        ten_init = torch.from_numpy(init_angle)
        self.bias_theta.data.copy_(ten_init)

    def __repr__(self):
        s = '{name}({hidden_size})'
        return s.format(name=self.__class__.__name__, **self.__dict__)

    def check_forward_args(self, input_var, hidden, batch_sizes):
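        # validate shapes: packed input data is 2-D (total_steps, hidden_size),
        # padded input is 3-D (seq_len, batch, hidden_size)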
        is_input_packed = batch_sizes is not None
        expected_input_dim = 2 if is_input_packed else 3
        if input_var.dim() != expected_input_dim:
            raise RuntimeError(
                'input must have {} dimensions, got {}'.format(
                    expected_input_dim, input_var.dim()))
        if self.hidden_size != input_var.size(-1):
            raise RuntimeError(
                'input.size(-1) must be equal to hidden_size. Expected {}, got {}'.format(
                    self.hidden_size, input_var.size(-1)))

        if is_input_packed:
            mini_batch = int(batch_sizes[0])
        else:
            mini_batch = input_var.size(1)

        expected_hidden_size = (mini_batch, self.hidden_size)

        def check_hidden_size(hx, expected_hidden_size, msg='Expected hidden size {}, got {}'):
            if tuple(hx.size()) != expected_hidden_size:
                raise RuntimeError(msg.format(expected_hidden_size, tuple(hx.size())))

        check_hidden_size(hidden[0], expected_hidden_size,
                          'Expected hidden[0] size {}, got {}')
        check_hidden_size(hidden[1], expected_hidden_size,
                          'Expected hidden[1] size {}, got {}')

    def step(self, input_var, hidden, a=None, b=None):
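        # hidden holds the two previous outputs (y[n-1], y[n-2]); a and b are normally
        # precomputed by forward() and passed in to avoid redundant work at every step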
        if a is None or b is None:
            modulus = F.sigmoid(self.bias_modulus)
            cosangle = F.tanh(self.bias_theta)
            a = 2 * cosangle * modulus
            b = - modulus.pow(2)

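        # second-order recurrence y[n] = x[n] + a*y[n-1] + b*y[n-2],
        # with a = 2*P*cos(theta) and b = -P**2 (P = pole modulus)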
        next_state = input_var + a * hidden[0] + b * hidden[1]

        return next_state, hidden[0]

    def forward(self, input_var, hidden=None):
        is_packed = isinstance(input_var, PackedSequence)
        if is_packed:
            input_var, batch_sizes = input_var
            max_batch_size = int(batch_sizes[0])
        else:
            batch_sizes = None
            max_batch_size = input_var.size(1)

        if hidden is None:
            h = input_var.new_zeros(max_batch_size, self.hidden_size, requires_grad=False)

            hidden = (h, h)

        self.check_forward_args(input_var, hidden, batch_sizes)

        # compute the filter coefficients once here rather than at every step, for efficiency
        modulus = F.sigmoid(self.bias_modulus)
        cosangle = F.tanh(self.bias_theta)
        a = 2 * cosangle * modulus
        b = - modulus.pow(2)

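        # Recurrent iterates over padded (seq_len, batch, hidden) input,
        # VariableRecurrent over packed sequences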
        func = Recurrent(self.step) if batch_sizes is None else VariableRecurrent(self.step)
        nexth, output = func(input_var, hidden, (a, b), batch_sizes)

        if is_packed:
            output = PackedSequence(output, batch_sizes)

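        # return the filtered sequence, the final hidden state and the learned pole modulus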
        return output, nexth, modulus

    def print_param(self):
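        # print the z^-1 and z^-2 denominator coefficients of the first channel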
        modulus = F.sigmoid(self.bias_modulus)
        cosangle = F.tanh(self.bias_theta)
        p1 = -2 * cosangle * modulus
        p2 = modulus.pow(2)
        print('{}\t{}'.format(p1.data[0], p2.data[0]))

    @property
    def denominator(self):
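        # denominator coefficients [1, -2*P*cos(theta), P**2] of each filter,
        # as a numpy array with one row per hidden channel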
        modulus = F.sigmoid(self.bias_modulus)
        cosangle = F.tanh(self.bias_theta)
        p1 = -2 * cosangle * modulus
        p2 = modulus.pow(2)
        p1 = p1.detach().cpu().numpy()
        p2 = p2.detach().cpu().numpy()
        p1 = p1.reshape(p1.size, 1)
        p2 = p2.reshape(p2.size, 1)
        one = np.ones(p1.shape)

        denom = np.concatenate((one, p1, p2), axis=1)
        return denom

    @property
    def gradients(self):
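        # gradients w.r.t. (bias_modulus, bias_theta), one row per hidden channel;
        # returns zeros if backward() has not been called yet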
        mod_grad = self.bias_modulus.grad
        if mod_grad is not None:
            mod_grad = mod_grad.data.numpy()
            mod_grad = mod_grad.reshape(mod_grad.size, 1)
            cos_grad = self.bias_theta.grad.data.numpy()
            cos_grad = cos_grad.reshape(cos_grad.size, 1)
            return np.concatenate((mod_grad, cos_grad), axis=1)
        else:
            return np.zeros((self.hidden_size, 2))
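

if __name__ == '__main__':
    # Minimal usage sketch, assuming the PyTorch 0.4-era recurrent helpers imported
    # above are available; run as a module (e.g. `python -m neural_filters.neural_filter_2CC`)
    # because of the relative import.
    seq_len, batch, hidden = 100, 4, 3
    cell = NeuralFilter2CC(hidden)

    # filter a random padded sequence of shape (seq_len, batch, hidden_size)
    x = torch.randn(seq_len, batch, hidden)
    output, state, modulus = cell(x)

    print(output.shape)      # -> torch.Size([100, 4, 3])
    print(cell.denominator)  # rows of [1, -2*P*cos(theta), P**2]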