Skip to content
Snippets Groups Projects
Commit eba283db authored by M. Francois's avatar M. Francois
Browse files

Fix circular imports

parent dee6a3c1
Branches
Tags 0.2
No related merge requests found
...@@ -26,73 +26,8 @@ This package implements a trainable all-pole second order filter with complex co ...@@ -26,73 +26,8 @@ This package implements a trainable all-pole second order filter with complex co
import pkg_resources import pkg_resources
from abc import ABC, abstractmethod
from torch import nn
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, PackedSequence
from .lfilter_grad import lfilter
from .neural_filter_2CC import * from .neural_filter_2CC import *
from .zero_pole_filter import * from .zero_pole_filter import *
__version__ = pkg_resources.get_distribution('neural_filters').version __version__ = pkg_resources.get_distribution('neural_filters').version
class FilterBase(nn.Module, ABC):
"""
A trainable filter
"""
def __init__(self, batch_first=False):
super().__init__()
# Transpose dimensions to bring time last
self.tdims = (-1, 0)
# Is the batch dimension the first?
if batch_first:
self.tdims = (-1, 1)
def forward(self, input_var):
"""
Apply the filter
Parameters
----------
input_var : Tensor
dimensions should be:
(time, batch, *features) if batch_first=False
(batch, time, *features) if batch_first=True
Returns
-------
filtered : Tensor
dimensions match the input_var
"""
is_packed = isinstance(input_var, PackedSequence)
if is_packed:
input_var, batch_sizes = pad_packed_sequence(input_var)
# Compute coefficients for numerator and denominator
a_coef, b_coef = self.coeffs()
# Transpose time to last dimension
input_var = input_var.transpose(*self.tdims)
# Apply autograd lfilter
output = lfilter(input_var, a_coef, b_coef)
# Transpose time back in original position
output = output.transpose(*self.tdims)
if is_packed:
output = pack_padded_sequence(output, batch_sizes)
return output
@abstractmethod
def coeffs(self):
pass
from abc import ABC, abstractmethod
from torch import nn
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, PackedSequence
from .lfilter_grad import lfilter
class FilterBase(nn.Module, ABC):
"""
A trainable filter
"""
def __init__(self, batch_first=False):
super().__init__()
# Transpose dimensions to bring time last
self.tdims = (-1, 0)
# Is the batch dimension the first?
if batch_first:
self.tdims = (-1, 1)
def forward(self, input_var):
"""
Apply the filter
Parameters
----------
input_var : Tensor
dimensions should be:
(time, batch, *features) if batch_first=False
(batch, time, *features) if batch_first=True
Returns
-------
filtered : Tensor
dimensions match the input_var
"""
is_packed = isinstance(input_var, PackedSequence)
if is_packed:
input_var, batch_sizes = pad_packed_sequence(input_var)
# Compute coefficients for numerator and denominator
a_coef, b_coef = self.coeffs()
# Transpose time to last dimension
input_var = input_var.transpose(*self.tdims)
# Apply autograd lfilter
output = lfilter(input_var, a_coef, b_coef)
# Transpose time back in original position
output = output.transpose(*self.tdims)
if is_packed:
output = pack_padded_sequence(output, batch_sizes)
return output
@abstractmethod
def coeffs(self):
pass
...@@ -21,7 +21,7 @@ import math ...@@ -21,7 +21,7 @@ import math
import torch import torch
from torch import nn from torch import nn
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, PackedSequence from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, PackedSequence
from . import FilterBase from .filter_base import FilterBase
class _NeuralFilter2CC(FilterBase): class _NeuralFilter2CC(FilterBase):
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
import torch import torch
from torch import nn from torch import nn
from . import FilterBase from .filter_base import FilterBase
class ZeroPoleFilter(FilterBase): class ZeroPoleFilter(FilterBase):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment