Skip to content
Snippets Groups Projects
Commit 9bcebe31 authored by M. François's avatar M. François
Browse files

Add initialization, V1

parent eba283db
Branches master
Tags 1.0
No related merge requests found
......@@ -17,6 +17,7 @@
# SPDX-License-Identifier: BSD-3-Clause
import math
import torch
from torch import nn
from .filter_base import FilterBase
......@@ -24,7 +25,7 @@ from .filter_base import FilterBase
class ZeroPoleFilter(FilterBase):
def __init__(self, f=None, batch_first=False):
def __init__(self, theta=None, r=None, h=None, batch_first=False):
super().__init__(batch_first)
# The hidden values of our parameters
......@@ -32,17 +33,40 @@ class ZeroPoleFilter(FilterBase):
self.r_hid = nn.Parameter(torch.empty(1))
self.h_hid = nn.Parameter(torch.empty(1))
self.reset_parameters(f)
self.reset_parameters(theta, r, h)
def reset_parameters(self, f=None):
if f is not None:
# Use f to set initial values of a0, r and h
# asig and atanh should be used to get back to hidden values
raise NotImplementedError()
def reset_parameters(self, theta=None, r=None, h=None):
if theta is not None:
if not isinstance(theta, torch.Tensor):
theta = torch.tensor(theta)
a0 = torch.cos(theta)
if h is None:
h = torch.sin(theta)
a0 = torch.atanh_(a0)
self.a0_hid.data.copy_(a0)
else:
nn.init.uniform_(self.a0_hid)
if r is not None:
if not isinstance(r, torch.Tensor):
r = torch.tensor(r)
r = -torch.log((1 / r) - 1)
self.r_hid.data.copy_(r)
else:
nn.init.uniform_(self.r_hid)
if h is not None:
if not isinstance(h, torch.Tensor):
h = torch.tensor(h)
h /= 2
h = -torch.log((1 / h) - 1)
self.h_hid.data.copy_(h)
else:
nn.init.uniform_(self.h_hid)
def coeffs(self):
......@@ -66,23 +90,62 @@ class ZeroPoleFilter(FilterBase):
return a_coef, b_coef
@property
def a0(self):
return torch.tanh(self.a0_hid).item()
@property
def c0(self):
a0 = torch.tanh(self.a0_hid)
return torch.sqrt(1 - a0).item()
@property
def h(self):
return torch.sigmoid(self.h_hid).item() * 2
@property
def r(self):
return torch.sigmoid(self.r_hid).item()
@property
def f(self):
a0 = torch.tanh(self.a0_hid)
return torch.acos(a0).item() / math.pi
def __repr__(self):
return 'ZeroPoleFilter (f:{}, r:{}, h:{})'.format(self.f, self.r, self.h)
class ZeroPoleLayer(nn.Module):
def __init__(self, n_filters, greenwood_init=False, batch_first=False):
def __init__(self, n_filters, greenwood_init=True, fs=16e3, r=0.95, h=0.5, batch_first=False):
super().__init__()
# Fancy init:
if greenwood_init:
x = torch.linspace(0.1, 0.9, n_filters)
freqs = 165.4 * (torch.pow(10, 2.1 * x) - 1)
thetas = freqs / fs * math.pi
else:
freqs = (None, ) * n_filters
thetas = (None, ) * n_filters
r = None
h = None
# Create a list of all filters
self.filters = nn.ModuleList(
[ZeroPoleFilter(f, batch_first) for f in freqs])
[ZeroPoleFilter(theta, r, h, batch_first) for theta in thetas])
def forward(self, x):
outputs = [zpfilt(x) for zpfilt in self.filters]
return torch.stack(outputs, -1)
class CascadeZPLayer(ZeroPoleLayer):
def forward(self, x):
outputs = []
for filt in self.filters[::-1]:
x = filt(x)
outputs.append(x)
return torch.stack(outputs, -1)
from setuptools import setup, find_packages
__version__ = '0.2'
__version__ = '1.0'
setup(
name='neural_filters',
......@@ -11,7 +11,7 @@ setup(
license='BSD-3',
packages=find_packages(),
install_requires=[
'torch>=1.8',
'torch>=1.6',
'torchaudio>=0.8'
],
zip_safe=True,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment