Commit 9bcebe31 authored by M. François's avatar M. François

Add initialization, V1

parent eba283db
......@@ -17,6 +17,7 @@
# SPDX-License-Identifier: BSD-3-Clause
import math
import torch
from torch import nn
from .filter_base import FilterBase
......@@ -24,7 +25,7 @@ from .filter_base import FilterBase
class ZeroPoleFilter(FilterBase):
def __init__(self, f=None, batch_first=False):
def __init__(self, theta=None, r=None, h=None, batch_first=False):
super().__init__(batch_first)
# The hidden values of our parameters
......@@ -32,17 +33,40 @@ class ZeroPoleFilter(FilterBase):
self.r_hid = nn.Parameter(torch.empty(1))
self.h_hid = nn.Parameter(torch.empty(1))
self.reset_parameters(f)
self.reset_parameters(theta, r, h)
def reset_parameters(self, f=None):
if f is not None:
# Use f to set initial values of a0, r and h
# asig and atanh should be used to get back to hidden values
raise NotImplementedError()
def reset_parameters(self, theta=None, r=None, h=None):
if theta is not None:
if not isinstance(theta, torch.Tensor):
theta = torch.tensor(theta)
a0 = torch.cos(theta)
if h is None:
h = torch.sin(theta)
a0 = torch.atanh_(a0)
self.a0_hid.data.copy_(a0)
else:
nn.init.uniform_(self.a0_hid)
if r is not None:
if not isinstance(r, torch.Tensor):
r = torch.tensor(r)
r = -torch.log((1 / r) - 1)
self.r_hid.data.copy_(r)
else:
nn.init.uniform_(self.r_hid)
if h is not None:
if not isinstance(h, torch.Tensor):
h = torch.tensor(h)
h /= 2
h = -torch.log((1 / h) - 1)
self.h_hid.data.copy_(h)
else:
nn.init.uniform_(self.h_hid)
def coeffs(self):
......@@ -66,23 +90,62 @@ class ZeroPoleFilter(FilterBase):
return a_coef, b_coef
@property
def a0(self):
return torch.tanh(self.a0_hid).item()
@property
def c0(self):
a0 = torch.tanh(self.a0_hid)
return torch.sqrt(1 - a0).item()
@property
def h(self):
return torch.sigmoid(self.h_hid).item() * 2
@property
def r(self):
return torch.sigmoid(self.r_hid).item()
@property
def f(self):
a0 = torch.tanh(self.a0_hid)
return torch.acos(a0).item() / math.pi
def __repr__(self):
return 'ZeroPoleFilter (f:{}, r:{}, h:{})'.format(self.f, self.r, self.h)
class ZeroPoleLayer(nn.Module):
def __init__(self, n_filters, greenwood_init=False, batch_first=False):
def __init__(self, n_filters, greenwood_init=True, fs=16e3, r=0.95, h=0.5, batch_first=False):
super().__init__()
# Fancy init:
if greenwood_init:
x = torch.linspace(0.1, 0.9, n_filters)
freqs = 165.4 * (torch.pow(10, 2.1 * x) - 1)
thetas = freqs / fs * math.pi
else:
freqs = (None, ) * n_filters
thetas = (None, ) * n_filters
r = None
h = None
# Create a list of all filters
self.filters = nn.ModuleList(
[ZeroPoleFilter(f, batch_first) for f in freqs])
[ZeroPoleFilter(theta, r, h, batch_first) for theta in thetas])
def forward(self, x):
outputs = [zpfilt(x) for zpfilt in self.filters]
return torch.stack(outputs, -1)
class CascadeZPLayer(ZeroPoleLayer):
def forward(self, x):
outputs = []
for filt in self.filters[::-1]:
x = filt(x)
outputs.append(x)
return torch.stack(outputs, -1)
from setuptools import setup, find_packages
__version__ = '0.2'
__version__ = '1.0'
setup(
name='neural_filters',
......@@ -11,7 +11,7 @@ setup(
license='BSD-3',
packages=find_packages(),
install_requires=[
'torch>=1.8',
'torch>=1.6',
'torchaudio>=0.8'
],
zip_safe=True,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment