diff --git a/neural_filters/zero_pole_filter.py b/neural_filters/zero_pole_filter.py index 69499576a58113c5b5a5592416d04a4ed4ccd2ca..3f177142ba7ffb4416944d5e766853c5701f0fde 100644 --- a/neural_filters/zero_pole_filter.py +++ b/neural_filters/zero_pole_filter.py @@ -17,6 +17,7 @@ # SPDX-License-Identifier: BSD-3-Clause +import math import torch from torch import nn from .filter_base import FilterBase @@ -24,7 +25,7 @@ from .filter_base import FilterBase class ZeroPoleFilter(FilterBase): - def __init__(self, f=None, batch_first=False): + def __init__(self, theta=None, r=None, h=None, batch_first=False): super().__init__(batch_first) # The hidden values of our parameters @@ -32,17 +33,40 @@ class ZeroPoleFilter(FilterBase): self.r_hid = nn.Parameter(torch.empty(1)) self.h_hid = nn.Parameter(torch.empty(1)) - self.reset_parameters(f) + self.reset_parameters(theta, r, h) - def reset_parameters(self, f=None): - if f is not None: - # Use f to set initial values of a0, r and h - # asig and atanh should be used to get back to hidden values - raise NotImplementedError() + def reset_parameters(self, theta=None, r=None, h=None): + if theta is not None: + if not isinstance(theta, torch.Tensor): + theta = torch.tensor(theta) + a0 = torch.cos(theta) + + if h is None: + h = torch.sin(theta) + + a0 = torch.atanh_(a0) + self.a0_hid.data.copy_(a0) else: nn.init.uniform_(self.a0_hid) + + if r is not None: + if not isinstance(r, torch.Tensor): + r = torch.tensor(r) + + r = -torch.log((1 / r) - 1) + self.r_hid.data.copy_(r) + else: nn.init.uniform_(self.r_hid) + + if h is not None: + if not isinstance(h, torch.Tensor): + h = torch.tensor(h) + + h /= 2 + h = -torch.log((1 / h) - 1) + self.h_hid.data.copy_(h) + else: nn.init.uniform_(self.h_hid) def coeffs(self): @@ -66,23 +90,62 @@ class ZeroPoleFilter(FilterBase): return a_coef, b_coef + @property + def a0(self): + return torch.tanh(self.a0_hid).item() + + @property + def c0(self): + a0 = torch.tanh(self.a0_hid) + return torch.sqrt(1 - a0).item() + + @property + def h(self): + return torch.sigmoid(self.h_hid).item() * 2 + + @property + def r(self): + return torch.sigmoid(self.r_hid).item() + + @property + def f(self): + a0 = torch.tanh(self.a0_hid) + return torch.acos(a0).item() / math.pi + + def __repr__(self): + return 'ZeroPoleFilter (f:{}, r:{}, h:{})'.format(self.f, self.r, self.h) + class ZeroPoleLayer(nn.Module): - def __init__(self, n_filters, greenwood_init=False, batch_first=False): + def __init__(self, n_filters, greenwood_init=True, fs=16e3, r=0.95, h=0.5, batch_first=False): super().__init__() # Fancy init: if greenwood_init: x = torch.linspace(0.1, 0.9, n_filters) freqs = 165.4 * (torch.pow(10, 2.1 * x) - 1) + thetas = freqs / fs * math.pi else: - freqs = (None, ) * n_filters + thetas = (None, ) * n_filters + r = None + h = None # Create a list of all filters self.filters = nn.ModuleList( - [ZeroPoleFilter(f, batch_first) for f in freqs]) + [ZeroPoleFilter(theta, r, h, batch_first) for theta in thetas]) def forward(self, x): outputs = [zpfilt(x) for zpfilt in self.filters] return torch.stack(outputs, -1) + + +class CascadeZPLayer(ZeroPoleLayer): + def forward(self, x): + outputs = [] + + for filt in self.filters[::-1]: + x = filt(x) + outputs.append(x) + + return torch.stack(outputs, -1) diff --git a/setup.py b/setup.py index ecd8e86640ab3150d6f45fbc56aed70dcb9cb3c9..5da52fa16a9369338d06300fee00b74478e3c0c1 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,6 @@ from setuptools import setup, find_packages -__version__ = '0.2' +__version__ = '1.0' setup( name='neural_filters', @@ -11,7 +11,7 @@ setup( license='BSD-3', packages=find_packages(), install_requires=[ - 'torch>=1.8', + 'torch>=1.6', 'torchaudio>=0.8' ], zip_safe=True,