Commit eba283db authored by M. Francois

Fix circular imports

parent dee6a3c1
Tags: 0.2
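For orientation: judging from the hunks below, the package __init__ used to both define FilterBase and star-import the submodules, while the submodules pulled FilterBase back out of the package with `from . import FilterBase`. A minimal sketch of that cycle, reconstructed from the diff rather than from the actual repository files (file names are inferred and should be treated as assumptions):

# Pre-commit layout (reconstructed from the diff below, not copied from the repository).
#
# neural_filters/__init__.py
#     from .neural_filter_2CC import *     # executes neural_filter_2CC first
#     from .zero_pole_filter import *
#     class FilterBase(nn.Module, ABC):    # only defined after the star imports
#         ...
#
# neural_filters/neural_filter_2CC.py
#     from . import FilterBase             # asks the package for FilterBase
#
# Importing the package therefore re-enters a half-initialised __init__:
# neural_filter_2CC looks up FilterBase before its class body has run, which
# would typically raise an ImportError. The commit breaks the cycle by moving
# FilterBase into its own module (filter_base, per the import changes below).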
@@ -26,73 +26,8 @@ This package implements a trainable all-pole second order filter with complex co…
import pkg_resources

from abc import ABC, abstractmethod
from torch import nn
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, PackedSequence

from .lfilter_grad import lfilter
from .neural_filter_2CC import *
from .zero_pole_filter import *

__version__ = pkg_resources.get_distribution('neural_filters').version


class FilterBase(nn.Module, ABC):
    """
    A trainable filter
    """

    def __init__(self, batch_first=False):
        super().__init__()

        # Transpose dimensions to bring time last
        self.tdims = (-1, 0)

        # Is the batch dimension the first?
        if batch_first:
            self.tdims = (-1, 1)

    def forward(self, input_var):
        """
        Apply the filter

        Parameters
        ----------
        input_var : Tensor
            dimensions should be:
            (time, batch, *features) if batch_first=False
            (batch, time, *features) if batch_first=True

        Returns
        -------
        filtered : Tensor
            dimensions match the input_var
        """
        is_packed = isinstance(input_var, PackedSequence)
        if is_packed:
            input_var, batch_sizes = pad_packed_sequence(input_var)

        # Compute coefficients for numerator and denominator
        a_coef, b_coef = self.coeffs()

        # Transpose time to last dimension
        input_var = input_var.transpose(*self.tdims)

        # Apply autograd lfilter
        output = lfilter(input_var, a_coef, b_coef)

        # Transpose time back in original position
        output = output.transpose(*self.tdims)

        if is_packed:
            output = pack_padded_sequence(output, batch_sizes)

        return output

    @abstractmethod
    def coeffs(self):
        pass
from abc import ABC, abstractmethod
from torch import nn
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, PackedSequence

from .lfilter_grad import lfilter


class FilterBase(nn.Module, ABC):
    """
    A trainable filter
    """

    def __init__(self, batch_first=False):
        super().__init__()

        # Transpose dimensions to bring time last
        self.tdims = (-1, 0)

        # Is the batch dimension the first?
        if batch_first:
            self.tdims = (-1, 1)

    def forward(self, input_var):
        """
        Apply the filter

        Parameters
        ----------
        input_var : Tensor
            dimensions should be:
            (time, batch, *features) if batch_first=False
            (batch, time, *features) if batch_first=True

        Returns
        -------
        filtered : Tensor
            dimensions match the input_var
        """
        is_packed = isinstance(input_var, PackedSequence)
        if is_packed:
            input_var, batch_sizes = pad_packed_sequence(input_var)

        # Compute coefficients for numerator and denominator
        a_coef, b_coef = self.coeffs()

        # Transpose time to last dimension
        input_var = input_var.transpose(*self.tdims)

        # Apply autograd lfilter
        output = lfilter(input_var, a_coef, b_coef)

        # Transpose time back in original position
        output = output.transpose(*self.tdims)

        if is_packed:
            output = pack_padded_sequence(output, batch_sizes)

        return output

    @abstractmethod
    def coeffs(self):
        pass
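As a usage illustration only: the sketch below subclasses FilterBase with fixed first-order coefficients and runs it on a random tensor. The subclass name, the coefficient values, the assumption that coeffs() returns the (a, b) pair in the order lfilter expects, and the import path are all mine, not taken from this commit.

import torch
from torch import nn

from neural_filters import FilterBase  # assumed import path; FilterBase may also live in neural_filters.filter_base


class FixedFilter(FilterBase):
    """Hypothetical subclass: trainable first-order filter coefficients."""

    def __init__(self, batch_first=False):
        super().__init__(batch_first=batch_first)
        # Assumed convention: a = denominator, b = numerator coefficients,
        # matching the order passed to lfilter in FilterBase.forward.
        self.a = nn.Parameter(torch.tensor([1.0, -0.5]))
        self.b = nn.Parameter(torch.tensor([1.0, 0.0]))

    def coeffs(self):
        return self.a, self.b


x = torch.randn(100, 4, 1)   # (time, batch, features), batch_first=False
y = FixedFilter()(x)         # filtered output with the same shape as x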
@@ -21,7 +21,7 @@ import math
import torch
from torch import nn
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, PackedSequence

-from . import FilterBase
+from .filter_base import FilterBase


class _NeuralFilter2CC(FilterBase):
@@ -19,7 +19,7 @@
import torch
from torch import nn

-from . import FilterBase
+from .filter_base import FilterBase


class ZeroPoleFilter(FilterBase):
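After the change the submodules no longer depend on the package __init__ having finished executing. A sketch of the resulting layout; the file names are inferred from the imports above and are therefore assumptions:

# Post-commit layout (inferred from the imports in this diff, not copied from the repository).
#
# neural_filters/filter_base.py
#     class FilterBase(nn.Module, ABC):    # now lives in its own module
#         ...
#
# neural_filters/neural_filter_2CC.py
# neural_filters/zero_pole_filter.py
#     from .filter_base import FilterBase  # no longer goes through the package __init__
#
# neural_filters/__init__.py
#     from .neural_filter_2CC import *     # safe: submodules resolve FilterBase directly
#     from .zero_pole_filter import *
#
# The imports now flow one way (submodules -> filter_base), so importing
# neural_filters no longer re-enters a partially initialised __init__.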