diff --git a/neural_filters/__init__.py b/neural_filters/__init__.py index adde438c3a95737a669c4071364363d04c219418..380a5a1db5324eca4ff59fefd3dac966bc556cc8 100644 --- a/neural_filters/__init__.py +++ b/neural_filters/__init__.py @@ -26,73 +26,8 @@ This package implements a trainable all-pole second order filter with complex co import pkg_resources -from abc import ABC, abstractmethod - -from torch import nn -from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, PackedSequence - -from .lfilter_grad import lfilter from .neural_filter_2CC import * from .zero_pole_filter import * __version__ = pkg_resources.get_distribution('neural_filters').version - - -class FilterBase(nn.Module, ABC): - """ - A trainable filter - """ - - def __init__(self, batch_first=False): - super().__init__() - - # Transpose dimensions to bring time last - self.tdims = (-1, 0) - - # Is the batch dimension the first? - if batch_first: - self.tdims = (-1, 1) - - def forward(self, input_var): - """ - Apply the filter - - Parameters - ---------- - input_var : Tensor - dimensions should be: - (time, batch, *features) if batch_first=False - (batch, time, *features) if batch_first=True - - - Returns - ------- - filtered : Tensor - dimensions match the input_var - """ - - is_packed = isinstance(input_var, PackedSequence) - if is_packed: - input_var, batch_sizes = pad_packed_sequence(input_var) - - # Compute coefficients for numerator and denominator - a_coef, b_coef = self.coeffs() - - # Transpose time to last dimension - input_var = input_var.transpose(*self.tdims) - - # Apply autograd lfilter - output = lfilter(input_var, a_coef, b_coef) - - # Transpose time back in original position - output = output.transpose(*self.tdims) - - if is_packed: - output = pack_padded_sequence(output, batch_sizes) - - return output - - @abstractmethod - def coeffs(self): - pass diff --git a/neural_filters/filter_base.py b/neural_filters/filter_base.py new file mode 100644 index 0000000000000000000000000000000000000000..365b5356014e991aa5061a48c9ad355bd50cfa7d --- /dev/null +++ b/neural_filters/filter_base.py @@ -0,0 +1,66 @@ + +from abc import ABC, abstractmethod + +from torch import nn +from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, PackedSequence + +from .lfilter_grad import lfilter + + +class FilterBase(nn.Module, ABC): + """ + A trainable filter + """ + + def __init__(self, batch_first=False): + super().__init__() + + # Transpose dimensions to bring time last + self.tdims = (-1, 0) + + # Is the batch dimension the first? + if batch_first: + self.tdims = (-1, 1) + + def forward(self, input_var): + """ + Apply the filter + + Parameters + ---------- + input_var : Tensor + dimensions should be: + (time, batch, *features) if batch_first=False + (batch, time, *features) if batch_first=True + + + Returns + ------- + filtered : Tensor + dimensions match the input_var + """ + + is_packed = isinstance(input_var, PackedSequence) + if is_packed: + input_var, batch_sizes = pad_packed_sequence(input_var) + + # Compute coefficients for numerator and denominator + a_coef, b_coef = self.coeffs() + + # Transpose time to last dimension + input_var = input_var.transpose(*self.tdims) + + # Apply autograd lfilter + output = lfilter(input_var, a_coef, b_coef) + + # Transpose time back in original position + output = output.transpose(*self.tdims) + + if is_packed: + output = pack_padded_sequence(output, batch_sizes) + + return output + + @abstractmethod + def coeffs(self): + pass diff --git a/neural_filters/neural_filter_2CC.py b/neural_filters/neural_filter_2CC.py index 0cd79f95678b55610b3924b3eae9aba2330e5506..17aefc4e5e52af28cf078a2a017f8548805355c9 100644 --- a/neural_filters/neural_filter_2CC.py +++ b/neural_filters/neural_filter_2CC.py @@ -21,7 +21,7 @@ import math import torch from torch import nn from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, PackedSequence -from . import FilterBase +from .filter_base import FilterBase class _NeuralFilter2CC(FilterBase): diff --git a/neural_filters/zero_pole_filter.py b/neural_filters/zero_pole_filter.py index e120f92651eef67087eeb09b732e43e7e2cb8655..69499576a58113c5b5a5592416d04a4ed4ccd2ca 100644 --- a/neural_filters/zero_pole_filter.py +++ b/neural_filters/zero_pole_filter.py @@ -19,7 +19,7 @@ import torch from torch import nn -from . import FilterBase +from .filter_base import FilterBase class ZeroPoleFilter(FilterBase):