diff --git a/neural_filters/__init__.py b/neural_filters/__init__.py
index adde438c3a95737a669c4071364363d04c219418..380a5a1db5324eca4ff59fefd3dac966bc556cc8 100644
--- a/neural_filters/__init__.py
+++ b/neural_filters/__init__.py
@@ -26,73 +26,8 @@ This package implements a trainable all-pole second order filter with complex co
 
 
 import pkg_resources
-from abc import ABC, abstractmethod
-
-from torch import nn
-from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, PackedSequence
-
-from .lfilter_grad import lfilter
 from .neural_filter_2CC import *
 from .zero_pole_filter import *
 
 
 __version__ = pkg_resources.get_distribution('neural_filters').version
-
-
-class FilterBase(nn.Module, ABC):
-    """
-    A trainable filter
-    """
-
-    def __init__(self, batch_first=False):
-        super().__init__()
-
-        # Transpose dimensions to bring time last
-        self.tdims = (-1, 0)
-
-        # Is the batch dimension the first?
-        if batch_first:
-            self.tdims = (-1, 1)
-
-    def forward(self, input_var):
-        """
-        Apply the filter
-
-        Parameters
-        ----------
-        input_var : Tensor
-            dimensions should be:
-                (time, batch, *features) if batch_first=False
-                (batch, time, *features) if batch_first=True
-
-
-        Returns
-        -------
-        filtered : Tensor
-            dimensions match the input_var
-        """
-
-        is_packed = isinstance(input_var, PackedSequence)
-        if is_packed:
-            input_var, batch_sizes = pad_packed_sequence(input_var)
-
-        # Compute coefficients for numerator and denominator
-        a_coef, b_coef = self.coeffs()
-
-        # Transpose time to last dimension
-        input_var = input_var.transpose(*self.tdims)
-
-        # Apply autograd lfilter
-        output = lfilter(input_var, a_coef, b_coef)
-
-        # Transpose time back in original position
-        output = output.transpose(*self.tdims)
-
-        if is_packed:
-            output = pack_padded_sequence(output, batch_sizes)
-
-        return output
-
-    @abstractmethod
-    def coeffs(self):
-        pass
diff --git a/neural_filters/filter_base.py b/neural_filters/filter_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..365b5356014e991aa5061a48c9ad355bd50cfa7d
--- /dev/null
+++ b/neural_filters/filter_base.py
@@ -0,0 +1,66 @@
+
+from abc import ABC, abstractmethod
+
+from torch import nn
+from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, PackedSequence
+
+from .lfilter_grad import lfilter
+
+
+class FilterBase(nn.Module, ABC):
+    """
+    A trainable filter
+    """
+
+    def __init__(self, batch_first=False):
+        super().__init__()
+
+        # Transpose dimensions to bring time last
+        self.tdims = (-1, 0)
+
+        # Is the batch dimension the first?
+        if batch_first:
+            self.tdims = (-1, 1)
+
+    def forward(self, input_var):
+        """
+        Apply the filter
+
+        Parameters
+        ----------
+        input_var : Tensor
+            dimensions should be:
+                (time, batch, *features) if batch_first=False
+                (batch, time, *features) if batch_first=True
+
+
+        Returns
+        -------
+        filtered : Tensor
+            dimensions match the input_var
+        """
+
+        is_packed = isinstance(input_var, PackedSequence)
+        if is_packed:
+            input_var, batch_sizes = pad_packed_sequence(input_var)
+
+        # Compute coefficients for numerator and denominator
+        a_coef, b_coef = self.coeffs()
+
+        # Transpose time to last dimension
+        input_var = input_var.transpose(*self.tdims)
+
+        # Apply autograd lfilter
+        output = lfilter(input_var, a_coef, b_coef)
+
+        # Transpose time back in original position
+        output = output.transpose(*self.tdims)
+
+        if is_packed:
+            output = pack_padded_sequence(output, batch_sizes)
+
+        return output
+
+    @abstractmethod
+    def coeffs(self):
+        pass
diff --git a/neural_filters/neural_filter_2CC.py b/neural_filters/neural_filter_2CC.py
index 0cd79f95678b55610b3924b3eae9aba2330e5506..17aefc4e5e52af28cf078a2a017f8548805355c9 100644
--- a/neural_filters/neural_filter_2CC.py
+++ b/neural_filters/neural_filter_2CC.py
@@ -21,7 +21,7 @@ import math
 import torch
 from torch import nn
 from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, PackedSequence
-from . import FilterBase
+from .filter_base import FilterBase
 
 
 class _NeuralFilter2CC(FilterBase):
diff --git a/neural_filters/zero_pole_filter.py b/neural_filters/zero_pole_filter.py
index e120f92651eef67087eeb09b732e43e7e2cb8655..69499576a58113c5b5a5592416d04a4ed4ccd2ca 100644
--- a/neural_filters/zero_pole_filter.py
+++ b/neural_filters/zero_pole_filter.py
@@ -19,7 +19,7 @@
 
 import torch
 from torch import nn
-from . import FilterBase
+from .filter_base import FilterBase
 
 
 class ZeroPoleFilter(FilterBase):