diff --git a/src/ptbench/data/__init__.py b/src/ptbench/data/__init__.py
index 2617ec0f2a5dfc7e294035a8debf153cfe83062b..516af66bcae176129d985555b6b7facd885ce5b4 100644
--- a/src/ptbench/data/__init__.py
+++ b/src/ptbench/data/__init__.py
@@ -1 +1,345 @@
 """Data manipulation and raw dataset definitions."""
+
+import random
+
+import torch
+
+from clapper.logging import setup
+
+logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+
+
+def make_subset(samples, transforms=[], prefixes=[], suffixes=[]):
+    """Creates a new data set, applying transforms.
+
+    .. note::
+
+       This is a convenience function for our own dataset definitions inside
+       this module, guaranteeing homogeneity between dataset definitions
+       provided in this package.  It assumes certain strategies for data
+       augmentation that may not be translatable to other applications.
+
+
+    Parameters
+    ----------
+
+    samples : list
+        List of delayed samples
+
+    transforms : list
+        A list of transforms that need to be applied to all samples in the set
+
+    prefixes : list
+        A list of data augmentation operations that need to be applied
+        **before** the transforms above
+
+    suffixes : list
+        A list of data augmentation operations that need to be applied
+        **after** the transforms above
+
+
+    Returns
+    -------
+
+    subset : :py:class:`ptbench.data.utils.SampleListDataset`
+        A pre-formatted dataset that can be fed to one of our engines
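+
+    Example
+    -------
+
+    A minimal sketch (``samples`` is assumed to be a list of delayed samples
+    loaded elsewhere; the resize transform is an illustrative placeholder):
+
+    .. code-block:: python
+
+       from torchvision import transforms
+
+       # wraps the samples, applying the transform to each one on access
+       subset = make_subset(samples, transforms=[transforms.Resize(512)])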
+    """
+    from .utils import SampleListDataset as wrapper
+
+    return wrapper(samples, prefixes + transforms + suffixes)
+
+
+def make_dataset(
+    subsets_groups, transforms=[], t_transforms=[], post_transforms=[]
+):
+    """Creates a new configuration dataset from a list of dictionaries and
+    transforms.
+
+    This function takes as input a list of dictionaries, such as those
+    returned by :py:meth:`ptbench.data.dataset.JSONDataset.subsets`, mapping
+    protocol names (such as ``train``, ``dev`` and ``test``) to
+    :py:class:`ptbench.data.sample.DelayedSample` lists, together with a set
+    of transforms.  It returns a dictionary that applies
+    :py:class:`ptbench.data.utils.SampleListDataset` to these lists, adding
+    our standard data augmentation whenever a ``train`` set exists.
+
+    For example, if the aggregated subsets comprise two sets named ``train``
+    and ``test``, this function yields a dictionary with the following
+    entries:
+
+    * ``__train__``: Wraps the ``train`` subset, includes data augmentation
+      (note: datasets with names starting with ``_`` (underscore) are excluded
+      from prediction and evaluation by default, as they contain data
+      augmentation transformations.)
+    * ``train``: Wraps the ``train`` subset, **without** data augmentation
+    * ``test``: Wraps the ``test`` subset, **without** data augmentation
+
+    .. note::
+
+       This is a convenience function for our own dataset definitions inside
+       this module, guaranteeing homogeneity between dataset definitions
+       provided in this package.  It assumes certain strategies for data
+       augmentation that may not be translatable to other applications.
+
+
+    Parameters
+    ----------
+
+    subsets_groups : list
+        A list of dictionaries mapping subset names to delayed sample lists.
+        If more than one dictionary is given, same-named subsets are
+        aggregated into one final subset.  If one of the keys is ``train``,
+        the training transforms (``t_transforms``) are appended to the
+        definition of that subset.  All other subsets remain un-augmented.
+
+    transforms : list
+        A list of transforms that need to be applied to all samples in the set
+
+    t_transforms : list
+        A list of transforms that need to be applied to the ``train`` samples
+        only
+
+    post_transforms : list
+        A list of transforms that need to be applied to all samples in the
+        set, after all the other transforms
+
+
+    Returns
+    -------
+
+    dataset : dict
+        A pre-formatted dataset that can be fed to one of our engines. It maps
+        string names to :py:class:`ptbench.data.utils.SampleListDataset`'s.
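+
+    Example
+    -------
+
+    A minimal sketch, assuming ``json_dataset`` is a configured
+    :py:class:`ptbench.data.dataset.JSONDataset` and ``resize`` and
+    ``augment`` are transform objects defined elsewhere (the protocol name
+    ``"default"`` is illustrative):
+
+    .. code-block:: python
+
+       subsets = json_dataset.subsets("default")
+       dataset = make_dataset(
+           [subsets],
+           transforms=[resize],     # applied to every subset
+           t_transforms=[augment],  # applied to ``__train__`` only
+       )
+       # ``dataset`` now maps names such as "train", "__train__" and
+       # "test" to SampleListDataset instances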
+    """
+
+    retval = {}
+
+    if len(subsets_groups) == 1:
+        subsets = subsets_groups[0]
+    else:
+        # If multiple subsets groups: aggregation
+        aggregated_subsets = {}
+        for subsets in subsets_groups:
+            for key in subsets.keys():
+                if key in aggregated_subsets:
+                    aggregated_subsets[key] += subsets[key]
+                    # Shuffle if data comes from multiple datasets
+                    random.shuffle(aggregated_subsets[key])
+                else:
+                    aggregated_subsets[key] = subsets[key]
+        subsets = aggregated_subsets
+
+    # Append post_transforms after t_transforms for the train set.  Use
+    # concatenation instead of ``+=`` so the mutable default argument is
+    # never modified in place across calls.
+    t_transforms = t_transforms + post_transforms
+
+    for key in subsets.keys():
+        retval[key] = make_subset(
+            subsets[key], transforms=transforms, suffixes=post_transforms
+        )
+        if key == "train":
+            retval["__train__"] = make_subset(
+                subsets[key], transforms=transforms, suffixes=t_transforms
+            )
+        if key == "validation":
+            # also use it for validation during training
+            retval["__valid__"] = retval[key]
+
+    if (
+        ("__train__" in retval)
+        and ("train" in retval)
+        and ("__valid__" not in retval)
+    ):
+        # if the dataset does not have a validation set, we use the unaugmented
+        # training set as validation set
+        retval["__valid__"] = retval["train"]
+
+    return retval
+
+
+def get_samples_weights(dataset):
+    """Compute the weights of all the samples of the dataset to balance it
+    using the sampler of the dataloader.
+
+    This function takes as input a :py:class:`torch.utils.data.dataset.Dataset`
+    and computes the weights to balance each class in the dataset and the
+    datasets themselves if we have a ConcatDataset.
+
+
+    Parameters
+    ----------
+
+    dataset : torch.utils.data.dataset.Dataset
+        An instance of :py:class:`torch.utils.data.Dataset`;
+        :py:class:`torch.utils.data.ConcatDataset` instances are also
+        supported
+
+
+    Returns
+    -------
+
+    samples_weights : :py:class:`torch.Tensor`
+        the weights for all the samples in the dataset given as input
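+
+    Example
+    -------
+
+    A typical use is to feed the weights to a
+    :py:class:`torch.utils.data.WeightedRandomSampler` (``train_set`` is
+    assumed to be one of the datasets produced by :py:func:`make_dataset`):
+
+    .. code-block:: python
+
+       weights = get_samples_weights(train_set)
+       sampler = torch.utils.data.WeightedRandomSampler(
+           weights, len(weights), replacement=True
+       )
+       loader = torch.utils.data.DataLoader(train_set, sampler=sampler)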
+    """
+    samples_weights = []
+
+    if isinstance(dataset, torch.utils.data.ConcatDataset):
+        for ds in dataset.datasets:
+            # Balance classes, only when labels are integers (e.g. binary)
+            if isinstance(ds._samples[0].label, int):
+                # Groundtruth
+                targets = []
+                for s in ds._samples:
+                    targets.append(s.label)
+                targets = torch.tensor(targets)
+
+                # Count number of samples per class
+                class_sample_count = torch.tensor(
+                    [
+                        (targets == t).sum()
+                        for t in torch.unique(targets, sorted=True)
+                    ]
+                )
+
+                weight = 1.0 / class_sample_count.float()
+
+                samples_weights.append(
+                    torch.tensor([weight[t] for t in targets])
+                )
+
+            else:
+                # We only weight to sample equally from each dataset
+                samples_weights.append(torch.full((len(ds),), 1.0 / len(ds)))
+
+        # Concatenate sample weights from all the datasets
+        samples_weights = torch.cat(samples_weights)
+
+    else:
+        # Balance classes, only when labels are integers (e.g. binary)
+        if isinstance(dataset._samples[0].label, int):
+            # Groundtruth
+            targets = []
+            for s in dataset._samples:
+                targets.append(s.label)
+            targets = torch.tensor(targets)
+
+            # Count number of samples per class
+            class_sample_count = torch.tensor(
+                [
+                    (targets == t).sum()
+                    for t in torch.unique(targets, sorted=True)
+                ]
+            )
+
+            weight = 1.0 / class_sample_count.float()
+
+            samples_weights = torch.tensor([weight[t] for t in targets])
+
+        else:
+            # Equal weights when labels are not integers
+            samples_weights = torch.ones(len(dataset._samples))
+
+    return samples_weights
+
+
+def get_positive_weights(dataset):
+    """Compute the positive weights of each class of the dataset to balance the
+    BCEWithLogitsLoss criterion.
+
+    This function takes as input a :py:class:`torch.utils.data.dataset.Dataset`
+    and computes, for each class, the positive weight to pass to the
+    criterion so that the loss is balanced.
+
+
+    Parameters
+    ----------
+
+    dataset : torch.utils.data.dataset.Dataset
+        An instance of :py:class:`torch.utils.data.Dataset`;
+        :py:class:`torch.utils.data.ConcatDataset` instances are also
+        supported
+
+
+    Returns
+    -------
+
+    positive_weights : :py:class:`torch.Tensor`
+        the positive weight of each class in the dataset given as input
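+
+    Example
+    -------
+
+    A minimal sketch, passing the result to
+    :py:class:`torch.nn.BCEWithLogitsLoss` (``train_set`` is assumed to be
+    one of the datasets produced by :py:func:`make_dataset`):
+
+    .. code-block:: python
+
+       pos_weight = get_positive_weights(train_set)
+       criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)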
+    """
+    targets = []
+
+    if isinstance(dataset, torch.utils.data.ConcatDataset):
+        for ds in dataset.datasets:
+            for s in ds._samples:
+                targets.append(s.label)
+
+    else:
+        for s in dataset._samples:
+            targets.append(s.label)
+
+    targets = torch.tensor(targets)
+
+    # Binary labels
+    if targets.ndim == 1:
+        class_sample_count = [
+            float((targets == t).sum().item())
+            for t in torch.unique(targets, sorted=True)
+        ]
+
+        # Divide the negative count by the positive count (assumes labels
+        # are 0 for negatives and 1 for positives)
+        positive_weights = torch.tensor(
+            [class_sample_count[0] / class_sample_count[1]]
+        ).reshape(-1)
+
+    # Multi-label (one-hot encoded) targets
+    else:
+        class_sample_count = torch.sum(targets, dim=0)
+        negative_class_sample_count = (
+            torch.full((targets.size()[1],), float(targets.size()[0]))
+            - class_sample_count
+        )
+
+        positive_weights = negative_class_sample_count / (
+            class_sample_count + negative_class_sample_count
+        )
+
+    return positive_weights
+
+
+def return_subsets(dataset):
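+    """Splits a dataset dictionary into the subsets used by our engines.
+
+    Parameters
+    ----------
+
+    dataset : dict
+        A dictionary mapping subset names to
+        :py:class:`ptbench.data.utils.SampleListDataset` objects, as produced
+        by :py:func:`make_dataset`.  The ``__train__``, ``__valid__`` and
+        ``__extra_valid__`` entries, when present, take precedence for
+        training and validation.
+
+
+    Returns
+    -------
+
+    subsets : tuple
+        The tuple ``(train_dataset, validation_dataset,
+        extra_validation_datasets, predict_dataset)``
+
+
+    Example
+    -------
+
+    A minimal sketch, assuming ``dataset`` was built by
+    :py:func:`make_dataset`:
+
+    .. code-block:: python
+
+       train, valid, extra_valid, predict = return_subsets(dataset)
+    """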
+    validation_dataset = None
+    extra_validation_datasets = None
+
+    if "__train__" in dataset:
+        logger.info("Found (dedicated) '__train__' set for training")
+        train_dataset = dataset["__train__"]
+    else:
+        train_dataset = dataset["train"]
+
+    if "__valid__" in dataset:
+        logger.info("Found (dedicated) '__valid__' set for validation")
+        validation_dataset = dataset["__valid__"]
+
+    if "__extra_valid__" in dataset:
+        if not isinstance(dataset["__extra_valid__"], list):
+            raise RuntimeError(
+                f"If present, dataset['__extra_valid__'] must be a list, "
+                f"but you passed a {type(dataset['__extra_valid__'])}, "
+                f"which is invalid."
+            )
+        logger.info(
+            f"Found {len(dataset['__extra_valid__'])} extra validation "
+            f"set(s) to be tracked during training"
+        )
+        logger.info(
+            "Extra validation sets are NOT used for model checkpointing!"
+        )
+        extra_validation_datasets = dataset["__extra_valid__"]
+
+    predict_dataset = dataset
+
+    return (
+        train_dataset,
+        validation_dataset,
+        extra_validation_datasets,
+        predict_dataset,
+    )
diff --git a/src/ptbench/data/shenzhen/__init__.py b/src/ptbench/data/shenzhen/__init__.py
index 60a004a3afc76239ffae144556feaaa888850ab7..a854e559bc033cd12eff0b10fc246b47edd73b64 100644
--- a/src/ptbench/data/shenzhen/__init__.py
+++ b/src/ptbench/data/shenzhen/__init__.py
@@ -22,13 +22,11 @@ the daily routine using Philips DR Digital Diagnose systems.
 
 import importlib.resources
 import os
-import random
-
-import torch
 
 from clapper.logging import setup
 
 from ...utils.rc import load_rc
+from .. import make_dataset
 from ..dataset import JSONDataset
 from ..loader import load_pil_baw, make_delayed
 
@@ -93,339 +91,3 @@ def _maker(protocol, resize_size=512, cc_size=512, RGB=False):
         [ElasticDeformation(p=0.8)],
         post_transforms,
     )
-
-
-def make_subset(samples, transforms=[], prefixes=[], suffixes=[]):
-    """Creates a new data set, applying transforms.
-
-    .. note::
-
-       This is a convenience function for our own dataset definitions inside
-       this module, guaranteeting homogenity between dataset definitions
-       provided in this package.  It assumes certain strategies for data
-       augmentation that may not be translatable to other applications.
-
-
-    Parameters
-    ----------
-
-    samples : list
-        List of delayed samples
-
-    transforms : list
-        A list of transforms that needs to be applied to all samples in the set
-
-    prefixes : list
-        A list of data augmentation operations that needs to be applied
-        **before** the transforms above
-
-    suffixes : list
-        A list of data augmentation operations that needs to be applied
-        **after** the transforms above
-
-
-    Returns
-    -------
-
-    subset : :py:class:`ptbench.data.utils.SampleListDataset`
-        A pre-formatted dataset that can be fed to one of our engines
-    """
-    from ...data.utils import SampleListDataset as wrapper
-
-    return wrapper(samples, prefixes + transforms + suffixes)
-
-
-def make_dataset(
-    subsets_groups, transforms=[], t_transforms=[], post_transforms=[]
-):
-    """Creates a new configuration dataset from a list of dictionaries and
-    transforms.
-
-    This function takes as input a list of dictionaries as those that can be
-    returned by :py:meth:`ptbench.data.dataset.JSONDataset.subsets`
-    mapping protocol names (such as ``train``, ``dev`` and ``test``) to
-    :py:class:`ptbench.data.sample.DelayedSample` lists, and a set of
-    transforms, and returns a dictionary applying
-    :py:class:`ptbench.data.utils.SampleListDataset` to these
-    lists, and our standard data augmentation if a ``train`` set exists.
-
-    For example, if ``subsets`` is composed of two sets named ``train`` and
-    ``test``, this function will yield a dictionary with the following entries:
-
-    * ``__train__``: Wraps the ``train`` subset, includes data augmentation
-      (note: datasets with names starting with ``_`` (underscore) are excluded
-      from prediction and evaluation by default, as they contain data
-      augmentation transformations.)
-    * ``train``: Wraps the ``train`` subset, **without** data augmentation
-    * ``test``: Wraps the ``test`` subset, **without** data augmentation
-
-    .. note::
-
-       This is a convenience function for our own dataset definitions inside
-       this module, guaranteeting homogenity between dataset definitions
-       provided in this package.  It assumes certain strategies for data
-       augmentation that may not be translatable to other applications.
-
-
-    Parameters
-    ----------
-
-    subsets : list
-        A list of dictionaries that contains the delayed sample lists
-        for a number of named lists. The subsets will be aggregated in one
-        final subset. If one of the keys is ``train``, our standard dataset
-        augmentation transforms are appended to the definition of that subset.
-        All other subsets remain un-augmented.
-
-    transforms : list
-        A list of transforms that needs to be applied to all samples in the set
-
-    t_transforms : list
-        A list of transforms that needs to be applied to the train samples
-
-    post_transforms : list
-        A list of transforms that needs to be applied to all samples in the set
-        after all the other transforms
-
-
-    Returns
-    -------
-
-    dataset : dict
-        A pre-formatted dataset that can be fed to one of our engines. It maps
-        string names to :py:class:`ptbench.data.utils.SampleListDataset`'s.
-    """
-
-    retval = {}
-
-    if len(subsets_groups) == 1:
-        subsets = subsets_groups[0]
-    else:
-        # If multiple subsets groups: aggregation
-        aggregated_subsets = {}
-        for subsets in subsets_groups:
-            for key in subsets.keys():
-                if key in aggregated_subsets:
-                    aggregated_subsets[key] += subsets[key]
-                    # Shuffle if data comes from multiple datasets
-                    random.shuffle(aggregated_subsets[key])
-                else:
-                    aggregated_subsets[key] = subsets[key]
-        subsets = aggregated_subsets
-
-    # Add post_transforms after t_transforms for the train set
-    t_transforms += post_transforms
-
-    for key in subsets.keys():
-        retval[key] = make_subset(
-            subsets[key], transforms=transforms, suffixes=post_transforms
-        )
-        if key == "train":
-            retval["__train__"] = make_subset(
-                subsets[key], transforms=transforms, suffixes=(t_transforms)
-            )
-        if key == "validation":
-            # also use it for validation during training
-            retval["__valid__"] = retval[key]
-
-    if (
-        ("__train__" in retval)
-        and ("train" in retval)
-        and ("__valid__" not in retval)
-    ):
-        # if the dataset does not have a validation set, we use the unaugmented
-        # training set as validation set
-        retval["__valid__"] = retval["train"]
-
-    return retval
-
-
-def get_samples_weights(dataset):
-    """Compute the weights of all the samples of the dataset to balance it
-    using the sampler of the dataloader.
-
-    This function takes as input a :py:class:`torch.utils.data.dataset.Dataset`
-    and computes the weights to balance each class in the dataset and the
-    datasets themselves if we have a ConcatDataset.
-
-
-    Parameters
-    ----------
-
-    dataset : torch.utils.data.dataset.Dataset
-        An instance of torch.utils.data.dataset.Dataset
-        ConcatDataset are supported
-
-
-    Returns
-    -------
-
-    samples_weights : :py:class:`torch.Tensor`
-        the weights for all the samples in the dataset given as input
-    """
-    samples_weights = []
-
-    if isinstance(dataset, torch.utils.data.ConcatDataset):
-        for ds in dataset.datasets:
-            # Weighting only for binary labels
-            if isinstance(ds._samples[0].label, int):
-                # Groundtruth
-                targets = []
-                for s in ds._samples:
-                    targets.append(s.label)
-                targets = torch.tensor(targets)
-
-                # Count number of samples per class
-                class_sample_count = torch.tensor(
-                    [
-                        (targets == t).sum()
-                        for t in torch.unique(targets, sorted=True)
-                    ]
-                )
-
-                weight = 1.0 / class_sample_count.float()
-
-                samples_weights.append(
-                    torch.tensor([weight[t] for t in targets])
-                )
-
-            else:
-                # We only weight to sample equally from each dataset
-                samples_weights.append(torch.full((len(ds),), 1.0 / len(ds)))
-
-        # Concatenate sample weights from all the datasets
-        samples_weights = torch.cat(samples_weights)
-
-    else:
-        # Weighting only for binary labels
-        if isinstance(dataset._samples[0].label, int):
-            # Groundtruth
-            targets = []
-            for s in dataset._samples:
-                targets.append(s.label)
-            targets = torch.tensor(targets)
-
-            # Count number of samples per class
-            class_sample_count = torch.tensor(
-                [
-                    (targets == t).sum()
-                    for t in torch.unique(targets, sorted=True)
-                ]
-            )
-
-            weight = 1.0 / class_sample_count.float()
-
-            samples_weights = torch.tensor([weight[t] for t in targets])
-
-        else:
-            # Equal weights for non-binary labels
-            samples_weights = torch.ones(len(dataset._samples))
-
-    return samples_weights
-
-
-def get_positive_weights(dataset):
-    """Compute the positive weights of each class of the dataset to balance the
-    BCEWithLogitsLoss criterion.
-
-    This function takes as input a :py:class:`torch.utils.data.dataset.Dataset`
-    and computes the positive weights of each class to use them to have
-    a balanced loss.
-
-
-    Parameters
-    ----------
-
-    dataset : torch.utils.data.dataset.Dataset
-        An instance of torch.utils.data.dataset.Dataset
-        ConcatDataset are supported
-
-
-    Returns
-    -------
-
-    positive_weights : :py:class:`torch.Tensor`
-        the positive weight of each class in the dataset given as input
-    """
-    targets = []
-
-    if isinstance(dataset, torch.utils.data.ConcatDataset):
-        for ds in dataset.datasets:
-            for s in ds._samples:
-                targets.append(s.label)
-
-    else:
-        for s in dataset._samples:
-            targets.append(s.label)
-
-    targets = torch.tensor(targets)
-
-    # Binary labels
-    if len(list(targets.shape)) == 1:
-        class_sample_count = [
-            float((targets == t).sum().item())
-            for t in torch.unique(targets, sorted=True)
-        ]
-
-        # Divide negatives by positives
-        positive_weights = torch.tensor(
-            [class_sample_count[0] / class_sample_count[1]]
-        ).reshape(-1)
-
-    # Multiclass labels
-    else:
-        class_sample_count = torch.sum(targets, dim=0)
-        negative_class_sample_count = (
-            torch.full((targets.size()[1],), float(targets.size()[0]))
-            - class_sample_count
-        )
-
-        positive_weights = negative_class_sample_count / (
-            class_sample_count + negative_class_sample_count
-        )
-
-    return positive_weights
-
-
-def return_subsets(dataset):
-    train_dataset = None
-    validation_dataset = None
-    extra_validation_datasets = None
-    predict_dataset = None
-
-    if "__train__" in dataset:
-        logger.info("Found (dedicated) '__train__' set for training")
-        train_dataset = dataset["__train__"]
-    else:
-        train_dataset = dataset["train"]
-
-    if "__valid__" in dataset:
-        logger.info("Found (dedicated) '__valid__' set for validation")
-        validation_dataset = dataset["__valid__"]
-
-    if "__extra_valid__" in dataset:
-        if not isinstance(dataset["__extra_valid__"], list):
-            raise RuntimeError(
-                f"If present, dataset['__extra_valid__'] must be a list, "
-                f"but you passed a {type(dataset['__extra_valid__'])}, "
-                f"which is invalid."
-            )
-        logger.info(
-            f"Found {len(dataset['__extra_valid__'])} extra validation "
-            f"set(s) to be tracked during training"
-        )
-        logger.info(
-            "Extra validation sets are NOT used for model checkpointing!"
-        )
-        extra_validation_datasets = dataset["__extra_valid__"]
-    else:
-        extra_validation_datasets = None
-
-    predict_dataset = dataset
-
-    return (
-        train_dataset,
-        validation_dataset,
-        extra_validation_datasets,
-        predict_dataset,
-    )
diff --git a/src/ptbench/data/shenzhen/default.py b/src/ptbench/data/shenzhen/default.py
index 3801efbc2a5f66d467a7681be37819e7fda74ede..bbeabcafa98554f6f8bcae71df287e6c3eda939e 100644
--- a/src/ptbench/data/shenzhen/default.py
+++ b/src/ptbench/data/shenzhen/default.py
@@ -12,8 +12,9 @@
 
 from clapper.logging import setup
 
+from .. import return_subsets
 from ..base_datamodule import BaseDataModule
-from . import _maker, return_subsets
+from . import _maker
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
diff --git a/src/ptbench/data/shenzhen/fold_0.py b/src/ptbench/data/shenzhen/fold_0.py
index d65d513bcb4cefcfc553e3c6bde864392db01add..5b4d45602d13a0e6bd0e2724a6c3202c1532eef6 100644
--- a/src/ptbench/data/shenzhen/fold_0.py
+++ b/src/ptbench/data/shenzhen/fold_0.py
@@ -11,8 +11,9 @@
 
 from clapper.logging import setup
 
+from .. import return_subsets
 from ..base_datamodule import BaseDataModule
-from . import _maker, return_subsets
+from . import _maker
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
diff --git a/src/ptbench/data/shenzhen/fold_0_rgb.py b/src/ptbench/data/shenzhen/fold_0_rgb.py
index bcc853dd7c52543d4aa27a4b10083d4407582786..143ef731fae0f7d3c746e384e08e400f36c92511 100644
--- a/src/ptbench/data/shenzhen/fold_0_rgb.py
+++ b/src/ptbench/data/shenzhen/fold_0_rgb.py
@@ -11,8 +11,9 @@
 
 from clapper.logging import setup
 
+from .. import return_subsets
 from ..base_datamodule import BaseDataModule
-from . import _maker, return_subsets
+from . import _maker
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
diff --git a/src/ptbench/data/shenzhen/fold_1.py b/src/ptbench/data/shenzhen/fold_1.py
index b9494f156e69ef9ed2af68d607f8d8db2540381f..f01adef0af4da152ad756036aacb4bf83c5c20cc 100644
--- a/src/ptbench/data/shenzhen/fold_1.py
+++ b/src/ptbench/data/shenzhen/fold_1.py
@@ -11,8 +11,9 @@
 
 from clapper.logging import setup
 
+from .. import return_subsets
 from ..base_datamodule import BaseDataModule
-from . import _maker, return_subsets
+from . import _maker
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
diff --git a/src/ptbench/data/shenzhen/fold_1_rgb.py b/src/ptbench/data/shenzhen/fold_1_rgb.py
index 01e23967e51878a348b89d14e263ff2c5e1d0503..9d457adfa8835e45c7d5dc993ab23ffc8baafeb9 100644
--- a/src/ptbench/data/shenzhen/fold_1_rgb.py
+++ b/src/ptbench/data/shenzhen/fold_1_rgb.py
@@ -12,8 +12,9 @@
 
 from clapper.logging import setup
 
+from .. import return_subsets
 from ..base_datamodule import BaseDataModule
-from . import _maker, return_subsets
+from . import _maker
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
diff --git a/src/ptbench/data/shenzhen/fold_2.py b/src/ptbench/data/shenzhen/fold_2.py
index 8d5cf816ce4ae950f19a71ea0c57662f427e2ee4..04dd656263cc6e71d2c61e3f5e221f7129fa7035 100644
--- a/src/ptbench/data/shenzhen/fold_2.py
+++ b/src/ptbench/data/shenzhen/fold_2.py
@@ -11,8 +11,9 @@
 
 from clapper.logging import setup
 
+from .. import return_subsets
 from ..base_datamodule import BaseDataModule
-from . import _maker, return_subsets
+from . import _maker
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
diff --git a/src/ptbench/data/shenzhen/fold_2_rgb.py b/src/ptbench/data/shenzhen/fold_2_rgb.py
index baf6752d66093a059f295f0e540a9d3ffd6b681f..37cbe10ebd72057535bacd241ecb4dbfb2bded3d 100644
--- a/src/ptbench/data/shenzhen/fold_2_rgb.py
+++ b/src/ptbench/data/shenzhen/fold_2_rgb.py
@@ -11,8 +11,9 @@
 
 from clapper.logging import setup
 
+from .. import return_subsets
 from ..base_datamodule import BaseDataModule
-from . import _maker, return_subsets
+from . import _maker
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
diff --git a/src/ptbench/data/shenzhen/fold_3.py b/src/ptbench/data/shenzhen/fold_3.py
index b42882ccf663b2635dbc106b817f1956a113863e..b43fcb29c9392d6e9bc5c20d9e93969b33bd2708 100644
--- a/src/ptbench/data/shenzhen/fold_3.py
+++ b/src/ptbench/data/shenzhen/fold_3.py
@@ -11,8 +11,9 @@
 
 from clapper.logging import setup
 
+from .. import return_subsets
 from ..base_datamodule import BaseDataModule
-from . import _maker, return_subsets
+from . import _maker
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
diff --git a/src/ptbench/data/shenzhen/fold_3_rgb.py b/src/ptbench/data/shenzhen/fold_3_rgb.py
index a02c2b1d4e7b33321aa45d537f9af2155f4ae321..162a3f82d640633ce6d13d25b2a500dc0fb63a54 100644
--- a/src/ptbench/data/shenzhen/fold_3_rgb.py
+++ b/src/ptbench/data/shenzhen/fold_3_rgb.py
@@ -11,8 +11,9 @@
 
 from clapper.logging import setup
 
+from .. import return_subsets
 from ..base_datamodule import BaseDataModule
-from . import _maker, return_subsets
+from . import _maker
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
diff --git a/src/ptbench/data/shenzhen/fold_4.py b/src/ptbench/data/shenzhen/fold_4.py
index a9ad14710ad37c02c34e6575fbac064db64a4054..58e0a2f221e0a6a48b3466c015c9c5790fc6ae3a 100644
--- a/src/ptbench/data/shenzhen/fold_4.py
+++ b/src/ptbench/data/shenzhen/fold_4.py
@@ -11,8 +11,9 @@
 
 from clapper.logging import setup
 
+from .. import return_subsets
 from ..base_datamodule import BaseDataModule
-from . import _maker, return_subsets
+from . import _maker
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
diff --git a/src/ptbench/data/shenzhen/fold_4_rgb.py b/src/ptbench/data/shenzhen/fold_4_rgb.py
index 3620ba428533a9e97392bde40f77b55b454809bb..0dd4ccf89c553a19030774772447bb66bf0cf9b7 100644
--- a/src/ptbench/data/shenzhen/fold_4_rgb.py
+++ b/src/ptbench/data/shenzhen/fold_4_rgb.py
@@ -11,8 +11,9 @@
 
 from clapper.logging import setup
 
+from .. import return_subsets
 from ..base_datamodule import BaseDataModule
-from . import _maker, return_subsets
+from . import _maker
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
diff --git a/src/ptbench/data/shenzhen/fold_5.py b/src/ptbench/data/shenzhen/fold_5.py
index 426a9d6609e212168b3c2775b791de416a6b7dfd..ff115340f723d33ebdeb49524cc7e3b269738563 100644
--- a/src/ptbench/data/shenzhen/fold_5.py
+++ b/src/ptbench/data/shenzhen/fold_5.py
@@ -11,8 +11,9 @@
 
 from clapper.logging import setup
 
+from .. import return_subsets
 from ..base_datamodule import BaseDataModule
-from . import _maker, return_subsets
+from . import _maker
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
diff --git a/src/ptbench/data/shenzhen/fold_5_rgb.py b/src/ptbench/data/shenzhen/fold_5_rgb.py
index 29e7013885ac1ba917ce40c7a027c827af54d705..46e255e7c37c1f547d25626970324f88d58ba03e 100644
--- a/src/ptbench/data/shenzhen/fold_5_rgb.py
+++ b/src/ptbench/data/shenzhen/fold_5_rgb.py
@@ -11,8 +11,9 @@
 
 from clapper.logging import setup
 
+from .. import return_subsets
 from ..base_datamodule import BaseDataModule
-from . import _maker, return_subsets
+from . import _maker
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
diff --git a/src/ptbench/data/shenzhen/fold_6.py b/src/ptbench/data/shenzhen/fold_6.py
index fb0a91b896723240f26a283cb76c5a2245470fc0..eb81ae882369a1a99bca60e93b1922c7a9d7e9fc 100644
--- a/src/ptbench/data/shenzhen/fold_6.py
+++ b/src/ptbench/data/shenzhen/fold_6.py
@@ -11,8 +11,9 @@
 
 from clapper.logging import setup
 
+from .. import return_subsets
 from ..base_datamodule import BaseDataModule
-from . import _maker, return_subsets
+from . import _maker
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
diff --git a/src/ptbench/data/shenzhen/fold_6_rgb.py b/src/ptbench/data/shenzhen/fold_6_rgb.py
index 35e7e6d7e7af0758aa4aaa01f24887e3eea6711f..b9654d08008bd10f0a5d953f2a79870214218651 100644
--- a/src/ptbench/data/shenzhen/fold_6_rgb.py
+++ b/src/ptbench/data/shenzhen/fold_6_rgb.py
@@ -11,8 +11,9 @@
 
 from clapper.logging import setup
 
+from .. import return_subsets
 from ..base_datamodule import BaseDataModule
-from . import _maker, return_subsets
+from . import _maker
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
diff --git a/src/ptbench/data/shenzhen/fold_7.py b/src/ptbench/data/shenzhen/fold_7.py
index 743b344d77b9324611e1490dc47966bb52dc41a4..79b0d1fff4483ef98eed42b6dce6c6b64076bd1e 100644
--- a/src/ptbench/data/shenzhen/fold_7.py
+++ b/src/ptbench/data/shenzhen/fold_7.py
@@ -11,8 +11,9 @@
 
 from clapper.logging import setup
 
+from .. import return_subsets
 from ..base_datamodule import BaseDataModule
-from . import _maker, return_subsets
+from . import _maker
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
diff --git a/src/ptbench/data/shenzhen/fold_7_rgb.py b/src/ptbench/data/shenzhen/fold_7_rgb.py
index 0a9f83d7a0072aa722a7d1cb881249e4722be36c..8a36acb2c79a4467f5dccad532c17cf528613ab5 100644
--- a/src/ptbench/data/shenzhen/fold_7_rgb.py
+++ b/src/ptbench/data/shenzhen/fold_7_rgb.py
@@ -11,8 +11,9 @@
 
 from clapper.logging import setup
 
+from .. import return_subsets
 from ..base_datamodule import BaseDataModule
-from . import _maker, return_subsets
+from . import _maker
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
diff --git a/src/ptbench/data/shenzhen/fold_8.py b/src/ptbench/data/shenzhen/fold_8.py
index ee7f716714adc35beeab1b4eb450253d408320ad..cf1cd36a8d2d2a40113a775131f9d5b8ba0a092a 100644
--- a/src/ptbench/data/shenzhen/fold_8.py
+++ b/src/ptbench/data/shenzhen/fold_8.py
@@ -11,8 +11,9 @@
 
 from clapper.logging import setup
 
+from .. import return_subsets
 from ..base_datamodule import BaseDataModule
-from . import _maker, return_subsets
+from . import _maker
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
diff --git a/src/ptbench/data/shenzhen/fold_8_rgb.py b/src/ptbench/data/shenzhen/fold_8_rgb.py
index 1351790faf6234c2d7a3b51c500b2d16c94fd2c0..1aa0bcec76d875d8e9b966b4461228230bdf79f2 100644
--- a/src/ptbench/data/shenzhen/fold_8_rgb.py
+++ b/src/ptbench/data/shenzhen/fold_8_rgb.py
@@ -11,8 +11,9 @@
 
 from clapper.logging import setup
 
+from .. import return_subsets
 from ..base_datamodule import BaseDataModule
-from . import _maker, return_subsets
+from . import _maker
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
diff --git a/src/ptbench/data/shenzhen/fold_9.py b/src/ptbench/data/shenzhen/fold_9.py
index dbd8ab3146dca6bda7048c1d47fe1ee0852129c7..e1bb569d0389a8eb67b0ad186e48c6f98aa7cf57 100644
--- a/src/ptbench/data/shenzhen/fold_9.py
+++ b/src/ptbench/data/shenzhen/fold_9.py
@@ -11,8 +11,9 @@
 
 from clapper.logging import setup
 
+from .. import return_subsets
 from ..base_datamodule import BaseDataModule
-from . import _maker, return_subsets
+from . import _maker
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
diff --git a/src/ptbench/data/shenzhen/fold_9_rgb.py b/src/ptbench/data/shenzhen/fold_9_rgb.py
index 729141dc3a134e4d5e815441262a58004b98ed80..c0a577df0a0d3667420c32042816f65ba9ad20ce 100644
--- a/src/ptbench/data/shenzhen/fold_9_rgb.py
+++ b/src/ptbench/data/shenzhen/fold_9_rgb.py
@@ -11,8 +11,9 @@
 
 from clapper.logging import setup
 
+from .. import return_subsets
 from ..base_datamodule import BaseDataModule
-from . import _maker, return_subsets
+from . import _maker
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
diff --git a/src/ptbench/data/shenzhen/rgb.py b/src/ptbench/data/shenzhen/rgb.py
index 6186990ec8023df5030ba84a9be627e80a95ba44..7bdb8fe3ce6826fb98d0d6356f2e1b429670a3d1 100644
--- a/src/ptbench/data/shenzhen/rgb.py
+++ b/src/ptbench/data/shenzhen/rgb.py
@@ -12,8 +12,9 @@
 
 from clapper.logging import setup
 
+from .. import return_subsets
 from ..base_datamodule import BaseDataModule
-from . import _maker, return_subsets
+from . import _maker
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")