Commit eba283db authored by M. Francois's avatar M. Francois

Fix circular imports

parent dee6a3c1
......@@ -26,73 +26,8 @@ This package implements a trainable all-pole second order filter with complex co
import pkg_resources
from abc import ABC, abstractmethod
from torch import nn
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, PackedSequence
from .lfilter_grad import lfilter
from .neural_filter_2CC import *
from .zero_pole_filter import *
__version__ = pkg_resources.get_distribution('neural_filters').version
class FilterBase(nn.Module, ABC):
"""
A trainable filter
"""
def __init__(self, batch_first=False):
super().__init__()
# Transpose dimensions to bring time last
self.tdims = (-1, 0)
# Is the batch dimension the first?
if batch_first:
self.tdims = (-1, 1)
def forward(self, input_var):
"""
Apply the filter
Parameters
----------
input_var : Tensor
dimensions should be:
(time, batch, *features) if batch_first=False
(batch, time, *features) if batch_first=True
Returns
-------
filtered : Tensor
dimensions match the input_var
"""
is_packed = isinstance(input_var, PackedSequence)
if is_packed:
input_var, batch_sizes = pad_packed_sequence(input_var)
# Compute coefficients for numerator and denominator
a_coef, b_coef = self.coeffs()
# Transpose time to last dimension
input_var = input_var.transpose(*self.tdims)
# Apply autograd lfilter
output = lfilter(input_var, a_coef, b_coef)
# Transpose time back in original position
output = output.transpose(*self.tdims)
if is_packed:
output = pack_padded_sequence(output, batch_sizes)
return output
@abstractmethod
def coeffs(self):
pass
from abc import ABC, abstractmethod
from torch import nn
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, PackedSequence
from .lfilter_grad import lfilter
class FilterBase(nn.Module, ABC):
"""
A trainable filter
"""
def __init__(self, batch_first=False):
super().__init__()
# Transpose dimensions to bring time last
self.tdims = (-1, 0)
# Is the batch dimension the first?
if batch_first:
self.tdims = (-1, 1)
def forward(self, input_var):
"""
Apply the filter
Parameters
----------
input_var : Tensor
dimensions should be:
(time, batch, *features) if batch_first=False
(batch, time, *features) if batch_first=True
Returns
-------
filtered : Tensor
dimensions match the input_var
"""
is_packed = isinstance(input_var, PackedSequence)
if is_packed:
input_var, batch_sizes = pad_packed_sequence(input_var)
# Compute coefficients for numerator and denominator
a_coef, b_coef = self.coeffs()
# Transpose time to last dimension
input_var = input_var.transpose(*self.tdims)
# Apply autograd lfilter
output = lfilter(input_var, a_coef, b_coef)
# Transpose time back in original position
output = output.transpose(*self.tdims)
if is_packed:
output = pack_padded_sequence(output, batch_sizes)
return output
@abstractmethod
def coeffs(self):
pass
......@@ -21,7 +21,7 @@ import math
import torch
from torch import nn
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, PackedSequence
from . import FilterBase
from .filter_base import FilterBase
class _NeuralFilter2CC(FilterBase):
......
......@@ -19,7 +19,7 @@
import torch
from torch import nn
from . import FilterBase
from .filter_base import FilterBase
class ZeroPoleFilter(FilterBase):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment