Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
N
neural_filters
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
software
neural_filters
Commits
9bcebe31
Commit
9bcebe31
authored
4 years ago
by
M. François
Browse files
Options
Downloads
Patches
Plain Diff
Add initialization, V1
parent
eba283db
Branches
master
Tags
1.0
No related merge requests found
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
neural_filters/zero_pole_filter.py
+73
-10
73 additions, 10 deletions
neural_filters/zero_pole_filter.py
setup.py
+2
-2
2 additions, 2 deletions
setup.py
with
75 additions
and
12 deletions
neural_filters/zero_pole_filter.py
+
73
−
10
View file @
9bcebe31
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
# SPDX-License-Identifier: BSD-3-Clause
# SPDX-License-Identifier: BSD-3-Clause
import
math
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
.filter_base
import
FilterBase
from
.filter_base
import
FilterBase
...
@@ -24,7 +25,7 @@ from .filter_base import FilterBase
...
@@ -24,7 +25,7 @@ from .filter_base import FilterBase
class
ZeroPoleFilter
(
FilterBase
):
class
ZeroPoleFilter
(
FilterBase
):
def
__init__
(
self
,
f
=
None
,
batch_first
=
False
):
def
__init__
(
self
,
theta
=
None
,
r
=
None
,
h
=
None
,
batch_first
=
False
):
super
().
__init__
(
batch_first
)
super
().
__init__
(
batch_first
)
# The hidden values of our parameters
# The hidden values of our parameters
...
@@ -32,17 +33,40 @@ class ZeroPoleFilter(FilterBase):
...
@@ -32,17 +33,40 @@ class ZeroPoleFilter(FilterBase):
self
.
r_hid
=
nn
.
Parameter
(
torch
.
empty
(
1
))
self
.
r_hid
=
nn
.
Parameter
(
torch
.
empty
(
1
))
self
.
h_hid
=
nn
.
Parameter
(
torch
.
empty
(
1
))
self
.
h_hid
=
nn
.
Parameter
(
torch
.
empty
(
1
))
self
.
reset_parameters
(
f
)
self
.
reset_parameters
(
theta
,
r
,
h
)
def
reset_parameters
(
self
,
f
=
None
):
def
reset_parameters
(
self
,
theta
=
None
,
r
=
None
,
h
=
None
):
if
f
is
not
None
:
if
theta
is
not
None
:
# Use f to set initial values of a0, r and h
if
not
isinstance
(
theta
,
torch
.
Tensor
):
# asig and atanh should be used to get back to hidden values
theta
=
torch
.
tensor
(
theta
)
raise
NotImplementedError
()
a0
=
torch
.
cos
(
theta
)
if
h
is
None
:
h
=
torch
.
sin
(
theta
)
a0
=
torch
.
atanh_
(
a0
)
self
.
a0_hid
.
data
.
copy_
(
a0
)
else
:
else
:
nn
.
init
.
uniform_
(
self
.
a0_hid
)
nn
.
init
.
uniform_
(
self
.
a0_hid
)
if
r
is
not
None
:
if
not
isinstance
(
r
,
torch
.
Tensor
):
r
=
torch
.
tensor
(
r
)
r
=
-
torch
.
log
((
1
/
r
)
-
1
)
self
.
r_hid
.
data
.
copy_
(
r
)
else
:
nn
.
init
.
uniform_
(
self
.
r_hid
)
nn
.
init
.
uniform_
(
self
.
r_hid
)
if
h
is
not
None
:
if
not
isinstance
(
h
,
torch
.
Tensor
):
h
=
torch
.
tensor
(
h
)
h
/=
2
h
=
-
torch
.
log
((
1
/
h
)
-
1
)
self
.
h_hid
.
data
.
copy_
(
h
)
else
:
nn
.
init
.
uniform_
(
self
.
h_hid
)
nn
.
init
.
uniform_
(
self
.
h_hid
)
def
coeffs
(
self
):
def
coeffs
(
self
):
...
@@ -66,23 +90,62 @@ class ZeroPoleFilter(FilterBase):
...
@@ -66,23 +90,62 @@ class ZeroPoleFilter(FilterBase):
return
a_coef
,
b_coef
return
a_coef
,
b_coef
@property
def
a0
(
self
):
return
torch
.
tanh
(
self
.
a0_hid
).
item
()
@property
def
c0
(
self
):
a0
=
torch
.
tanh
(
self
.
a0_hid
)
return
torch
.
sqrt
(
1
-
a0
).
item
()
@property
def
h
(
self
):
return
torch
.
sigmoid
(
self
.
h_hid
).
item
()
*
2
@property
def
r
(
self
):
return
torch
.
sigmoid
(
self
.
r_hid
).
item
()
@property
def
f
(
self
):
a0
=
torch
.
tanh
(
self
.
a0_hid
)
return
torch
.
acos
(
a0
).
item
()
/
math
.
pi
def
__repr__
(
self
):
return
'
ZeroPoleFilter (f:{}, r:{}, h:{})
'
.
format
(
self
.
f
,
self
.
r
,
self
.
h
)
class
ZeroPoleLayer
(
nn
.
Module
):
class
ZeroPoleLayer
(
nn
.
Module
):
def
__init__
(
self
,
n_filters
,
greenwood_init
=
False
,
batch_first
=
False
):
def
__init__
(
self
,
n_filters
,
greenwood_init
=
True
,
fs
=
16e3
,
r
=
0.95
,
h
=
0.5
,
batch_first
=
False
):
super
().
__init__
()
super
().
__init__
()
# Fancy init:
# Fancy init:
if
greenwood_init
:
if
greenwood_init
:
x
=
torch
.
linspace
(
0.1
,
0.9
,
n_filters
)
x
=
torch
.
linspace
(
0.1
,
0.9
,
n_filters
)
freqs
=
165.4
*
(
torch
.
pow
(
10
,
2.1
*
x
)
-
1
)
freqs
=
165.4
*
(
torch
.
pow
(
10
,
2.1
*
x
)
-
1
)
thetas
=
freqs
/
fs
*
math
.
pi
else
:
else
:
freqs
=
(
None
,
)
*
n_filters
thetas
=
(
None
,
)
*
n_filters
r
=
None
h
=
None
# Create a list of all filters
# Create a list of all filters
self
.
filters
=
nn
.
ModuleList
(
self
.
filters
=
nn
.
ModuleList
(
[
ZeroPoleFilter
(
f
,
batch_first
)
for
f
in
freq
s
])
[
ZeroPoleFilter
(
theta
,
r
,
h
,
batch_first
)
for
theta
in
theta
s
])
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
outputs
=
[
zpfilt
(
x
)
for
zpfilt
in
self
.
filters
]
outputs
=
[
zpfilt
(
x
)
for
zpfilt
in
self
.
filters
]
return
torch
.
stack
(
outputs
,
-
1
)
return
torch
.
stack
(
outputs
,
-
1
)
class
CascadeZPLayer
(
ZeroPoleLayer
):
def
forward
(
self
,
x
):
outputs
=
[]
for
filt
in
self
.
filters
[::
-
1
]:
x
=
filt
(
x
)
outputs
.
append
(
x
)
return
torch
.
stack
(
outputs
,
-
1
)
This diff is collapsed.
Click to expand it.
setup.py
+
2
−
2
View file @
9bcebe31
from
setuptools
import
setup
,
find_packages
from
setuptools
import
setup
,
find_packages
__version__
=
'
0.2
'
__version__
=
'
1.0
'
setup
(
setup
(
name
=
'
neural_filters
'
,
name
=
'
neural_filters
'
,
...
@@ -11,7 +11,7 @@ setup(
...
@@ -11,7 +11,7 @@ setup(
license
=
'
BSD-3
'
,
license
=
'
BSD-3
'
,
packages
=
find_packages
(),
packages
=
find_packages
(),
install_requires
=
[
install_requires
=
[
'
torch>=1.
8
'
,
'
torch>=1.
6
'
,
'
torchaudio>=0.8
'
'
torchaudio>=0.8
'
],
],
zip_safe
=
True
,
zip_safe
=
True
,
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment