Commit 9bcebe31 authored by M. François

Add initialization, V1

parent eba283db
Branch: master
Tag: 1.0
@@ -17,6 +17,7 @@
 # SPDX-License-Identifier: BSD-3-Clause

+import math
 import torch
 from torch import nn
 from .filter_base import FilterBase
@@ -24,7 +25,7 @@ from .filter_base import FilterBase
 class ZeroPoleFilter(FilterBase):
-    def __init__(self, f=None, batch_first=False):
+    def __init__(self, theta=None, r=None, h=None, batch_first=False):
         super().__init__(batch_first)
         # The hidden values of our parameters
@@ -32,17 +33,40 @@ class ZeroPoleFilter(FilterBase):
         self.r_hid = nn.Parameter(torch.empty(1))
         self.h_hid = nn.Parameter(torch.empty(1))
-        self.reset_parameters(f)
+        self.reset_parameters(theta, r, h)

-    def reset_parameters(self, f=None):
-        if f is not None:
-            # Use f to set initial values of a0, r and h
-            # asig and atanh should be used to get back to hidden values
-            raise NotImplementedError()
+    def reset_parameters(self, theta=None, r=None, h=None):
+        if theta is not None:
+            if not isinstance(theta, torch.Tensor):
+                theta = torch.tensor(theta)
+            a0 = torch.cos(theta)
+            if h is None:
+                h = torch.sin(theta)
+            a0 = torch.atanh_(a0)
+            self.a0_hid.data.copy_(a0)
         else:
             nn.init.uniform_(self.a0_hid)
+        if r is not None:
+            if not isinstance(r, torch.Tensor):
+                r = torch.tensor(r)
+            r = -torch.log((1 / r) - 1)
+            self.r_hid.data.copy_(r)
+        else:
             nn.init.uniform_(self.r_hid)
+        if h is not None:
+            if not isinstance(h, torch.Tensor):
+                h = torch.tensor(h)
+            h /= 2
+            h = -torch.log((1 / h) - 1)
+            self.h_hid.data.copy_(h)
+        else:
             nn.init.uniform_(self.h_hid)

     def coeffs(self):
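The filter stores unconstrained hidden tensors (a0_hid, r_hid, h_hid) and exposes the constrained values through tanh and sigmoid (see the properties added further down: a0 = tanh(a0_hid), r = sigmoid(r_hid), h = 2 * sigmoid(h_hid)). reset_parameters therefore pushes user-supplied values through the inverse maps, atanh for a0 = cos(theta) and the logit for r and h/2, which is what the removed placeholder comment ("asig and atanh should be used to get back to hidden values") anticipated. A minimal round-trip sketch in plain torch, independent of the package; the variable names only mirror the module's attributes:

import torch

# Target initial values (illustrative): pole angle, radius, gain
theta = torch.tensor(0.3)   # radians
r = torch.tensor(0.95)
h = torch.tensor(0.5)

# Forward maps used by the filter's properties:
#   a0 = tanh(a0_hid), r = sigmoid(r_hid), h = 2 * sigmoid(h_hid)
# Inverse maps applied by reset_parameters:
a0_hid = torch.atanh(torch.cos(theta))   # atanh inverts tanh
r_hid = -torch.log((1 / r) - 1)          # logit inverts sigmoid
h_hid = -torch.log((1 / (h / 2)) - 1)    # account for the factor of 2

# Round trip: the exposed values match the requested ones
assert torch.isclose(torch.tanh(a0_hid), torch.cos(theta))
assert torch.isclose(torch.sigmoid(r_hid), r)
assert torch.isclose(2 * torch.sigmoid(h_hid), h)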
@@ -66,23 +90,62 @@ class ZeroPoleFilter(FilterBase):
         return a_coef, b_coef

+    @property
+    def a0(self):
+        return torch.tanh(self.a0_hid).item()
+
+    @property
+    def c0(self):
+        a0 = torch.tanh(self.a0_hid)
+        return torch.sqrt(1 - a0).item()
+
+    @property
+    def h(self):
+        return torch.sigmoid(self.h_hid).item() * 2
+
+    @property
+    def r(self):
+        return torch.sigmoid(self.r_hid).item()
+
+    @property
+    def f(self):
+        a0 = torch.tanh(self.a0_hid)
+        return torch.acos(a0).item() / math.pi
+
+    def __repr__(self):
+        return 'ZeroPoleFilter (f:{}, r:{}, h:{})'.format(self.f, self.r, self.h)
+

 class ZeroPoleLayer(nn.Module):
-    def __init__(self, n_filters, greenwood_init=False, batch_first=False):
+    def __init__(self, n_filters, greenwood_init=True, fs=16e3, r=0.95, h=0.5, batch_first=False):
         super().__init__()
         # Fancy init:
         if greenwood_init:
             x = torch.linspace(0.1, 0.9, n_filters)
             freqs = 165.4 * (torch.pow(10, 2.1 * x) - 1)
+            thetas = freqs / fs * math.pi
         else:
-            freqs = (None, ) * n_filters
+            thetas = (None, ) * n_filters
+            r = None
+            h = None
         # Create a list of all filters
         self.filters = nn.ModuleList(
-            [ZeroPoleFilter(f, batch_first) for f in freqs])
+            [ZeroPoleFilter(theta, r, h, batch_first) for theta in thetas])

     def forward(self, x):
         outputs = [zpfilt(x) for zpfilt in self.filters]
         return torch.stack(outputs, -1)
+
+
+class CascadeZPLayer(ZeroPoleLayer):
+    def forward(self, x):
+        outputs = []
+        for filt in self.filters[::-1]:
+            x = filt(x)
+            outputs.append(x)
+        return torch.stack(outputs, -1)
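With greenwood_init=True, ZeroPoleLayer spaces its filters along the Greenwood cochlear frequency map: x sampled uniformly on [0.1, 0.9] gives centre frequencies f = 165.4 * (10^(2.1 x) - 1) Hz, which the constructor converts to pole angles with theta = pi * f / fs. A small standalone check of the resulting range, assuming the default fs = 16 kHz (plain torch, no package needed):

import math
import torch

n_filters = 8
fs = 16e3

x = torch.linspace(0.1, 0.9, n_filters)
freqs = 165.4 * (torch.pow(10, 2.1 * x) - 1)   # Greenwood map, in Hz
thetas = freqs / fs * math.pi                  # pole angles, as in ZeroPoleLayer

print(freqs.round())   # roughly 103 Hz up to about 12.7 kHz
print(thetas)          # angles in (0, pi), increasing with frequency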
setup.py:

 from setuptools import setup, find_packages

-__version__ = '0.2'
+__version__ = '1.0'

 setup(
     name='neural_filters',
@@ -11,7 +11,7 @@ setup(
     license='BSD-3',
     packages=find_packages(),
     install_requires=[
-        'torch>=1.8',
+        'torch>=1.6',
         'torchaudio>=0.8'
     ],
     zip_safe=True,
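ZeroPoleLayer applies its filters in parallel (every filter sees the raw input), while the new CascadeZPLayer chains them: each filter processes the previous filter's output, iterating the list in reverse, and every intermediate signal is kept as a tap. Both stack their outputs along a new last dimension. A toy sketch of the two wirings, with simple gains standing in for ZeroPoleFilter instances so it runs without the package:

import torch

x = torch.randn(100)                  # toy input signal
gains = [0.5, 1.0, 2.0]               # stand-ins for three filters

# Parallel (ZeroPoleLayer-style): every "filter" sees x directly
parallel = torch.stack([g * x for g in gains], -1)

# Cascade (CascadeZPLayer-style): each stage feeds the next, taps are kept
outputs, y = [], x
for g in gains[::-1]:                 # reversed order, mirroring the layer's forward
    y = g * y
    outputs.append(y)
cascade = torch.stack(outputs, -1)

print(parallel.shape, cascade.shape)  # both: torch.Size([100, 3])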