From b4b5cdbe714490fd7469fb59afd8857da9450791 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Thu, 1 Jun 2023 14:18:40 +0200
Subject: [PATCH] Moved common functions to __init__ in parent directory

---
 src/ptbench/data/__init__.py            | 344 ++++++++++++++++++++++++
 src/ptbench/data/shenzhen/__init__.py   | 340 +----------------------
 src/ptbench/data/shenzhen/default.py    |   3 +-
 src/ptbench/data/shenzhen/fold_0.py     |   3 +-
 src/ptbench/data/shenzhen/fold_0_rgb.py |   3 +-
 src/ptbench/data/shenzhen/fold_1.py     |   3 +-
 src/ptbench/data/shenzhen/fold_1_rgb.py |   3 +-
 src/ptbench/data/shenzhen/fold_2.py     |   3 +-
 src/ptbench/data/shenzhen/fold_2_rgb.py |   3 +-
 src/ptbench/data/shenzhen/fold_3.py     |   3 +-
 src/ptbench/data/shenzhen/fold_3_rgb.py |   3 +-
 src/ptbench/data/shenzhen/fold_4.py     |   3 +-
 src/ptbench/data/shenzhen/fold_4_rgb.py |   3 +-
 src/ptbench/data/shenzhen/fold_5.py     |   3 +-
 src/ptbench/data/shenzhen/fold_5_rgb.py |   3 +-
 src/ptbench/data/shenzhen/fold_6.py     |   3 +-
 src/ptbench/data/shenzhen/fold_6_rgb.py |   3 +-
 src/ptbench/data/shenzhen/fold_7.py     |   3 +-
 src/ptbench/data/shenzhen/fold_7_rgb.py |   3 +-
 src/ptbench/data/shenzhen/fold_8.py     |   3 +-
 src/ptbench/data/shenzhen/fold_8_rgb.py |   3 +-
 src/ptbench/data/shenzhen/fold_9.py     |   3 +-
 src/ptbench/data/shenzhen/fold_9_rgb.py |   3 +-
 src/ptbench/data/shenzhen/rgb.py        |   3 +-
 24 files changed, 389 insertions(+), 361 deletions(-)

diff --git a/src/ptbench/data/__init__.py b/src/ptbench/data/__init__.py
index 2617ec0f..516af66b 100644
--- a/src/ptbench/data/__init__.py
+++ b/src/ptbench/data/__init__.py
@@ -1 +1,345 @@
 """Data manipulation and raw dataset definitions."""
+
+import random
+
+import torch
+
+from clapper.logging import setup
+
+logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+
+
+def make_subset(samples, transforms=[], prefixes=[], suffixes=[]):
+    """Creates a new data set, applying transforms.
+
+    .. note::
+
+       This is a convenience function for our own dataset definitions inside
+       this module, guaranteeing homogeneity between dataset definitions
+       provided in this package.  It assumes certain strategies for data
+       augmentation that may not be translatable to other applications.
+
+
+    Parameters
+    ----------
+
+    samples : list
+        List of delayed samples
+
+    transforms : list
+        A list of transforms that need to be applied to all samples in the set
+
+    prefixes : list
+        A list of data augmentation operations that need to be applied
+        **before** the transforms above
+
+    suffixes : list
+        A list of data augmentation operations that need to be applied
+        **after** the transforms above
+
+
+    Returns
+    -------
+
+    subset : :py:class:`ptbench.data.utils.SampleListDataset`
+        A pre-formatted dataset that can be fed to one of our engines
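+
+
+    Example
+    -------
+
+    A minimal sketch; ``samples`` is a list of delayed samples and the
+    transform names below are hypothetical placeholders::
+
+        subset = make_subset(samples, transforms=[resize], suffixes=[totensor])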
+    """
+    from .utils import SampleListDataset as wrapper
+
+    return wrapper(samples, prefixes + transforms + suffixes)
+
+
+def make_dataset(
+    subsets_groups, transforms=[], t_transforms=[], post_transforms=[]
+):
+    """Creates a new configuration dataset from a list of dictionaries and
+    transforms.
+
+    This function takes as input a list of dictionaries, such as those
+    returned by :py:meth:`ptbench.data.dataset.JSONDataset.subsets`,
+    mapping protocol names (such as ``train``, ``dev`` and ``test``) to
+    :py:class:`ptbench.data.sample.DelayedSample` lists, plus a set of
+    transforms, and returns a dictionary that applies
+    :py:class:`ptbench.data.utils.SampleListDataset` to these lists, with
+    our standard data augmentation if a ``train`` set exists.
+
+    For example, if ``subsets_groups`` contains a single dictionary with two
+    subsets named ``train`` and ``test``, this function will yield a
+    dictionary with the following entries:
+
+    * ``__train__``: Wraps the ``train`` subset, includes data augmentation
+      (note: datasets with names starting with ``_`` (underscore) are excluded
+      from prediction and evaluation by default, as they contain data
+      augmentation transformations.)
+    * ``train``: Wraps the ``train`` subset, **without** data augmentation
+    * ``test``: Wraps the ``test`` subset, **without** data augmentation
+
+    .. note::
+
+       This is a convenience function for our own dataset definitions inside
+       this module, guaranteeing homogeneity between dataset definitions
+       provided in this package.  It assumes certain strategies for data
+       augmentation that may not be translatable to other applications.
+
+
+    Parameters
+    ----------
+
+    subsets_groups : list
+        A list of dictionaries that contain the delayed sample lists
+        for a number of named subsets.  The subsets will be aggregated into
+        one final subset.  If one of the keys is ``train``, our standard
+        dataset augmentation transforms are appended to the definition of
+        that subset.  All other subsets remain un-augmented.
+
+    transforms : list
+        A list of transforms that need to be applied to all samples in the set
+
+    t_transforms : list
+        A list of transforms that need to be applied to the train samples only
+
+    post_transforms : list
+        A list of transforms that need to be applied to all samples in the
+        set, after all the other transforms
+
+
+    Returns
+    -------
+
+    dataset : dict
+        A pre-formatted dataset that can be fed to one of our engines. It maps
+        string names to :py:class:`ptbench.data.utils.SampleListDataset`'s.
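+
+
+    Example
+    -------
+
+    A hedged sketch, assuming ``raw_subsets`` was returned by
+    :py:meth:`ptbench.data.dataset.JSONDataset.subsets`::
+
+        dataset = make_dataset([raw_subsets])
+        augmented_train = dataset["__train__"]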
+    """
+
+    retval = {}
+
+    if len(subsets_groups) == 1:
+        subsets = subsets_groups[0]
+    else:
+        # If multiple subsets groups: aggregation
+        aggregated_subsets = {}
+        for subsets in subsets_groups:
+            for key in subsets.keys():
+                if key in aggregated_subsets:
+                    aggregated_subsets[key] += subsets[key]
+                    # Shuffle if data comes from multiple datasets
+                    random.shuffle(aggregated_subsets[key])
+                else:
+                    aggregated_subsets[key] = subsets[key]
+        subsets = aggregated_subsets
+
+    # Append post_transforms after t_transforms for the train set.  Re-bind
+    # instead of using +=, so the caller's (or the default) list is not
+    # mutated in-place.
+    t_transforms = t_transforms + post_transforms
+
+    for key in subsets.keys():
+        retval[key] = make_subset(
+            subsets[key], transforms=transforms, suffixes=post_transforms
+        )
+        if key == "train":
+            retval["__train__"] = make_subset(
+                subsets[key], transforms=transforms, suffixes=t_transforms
+            )
+        if key == "validation":
+            # also use it for validation during training
+            retval["__valid__"] = retval[key]
+
+    if (
+        ("__train__" in retval)
+        and ("train" in retval)
+        and ("__valid__" not in retval)
+    ):
+        # if the dataset does not have a validation set, we use the unaugmented
+        # training set as validation set
+        retval["__valid__"] = retval["train"]
+
+    return retval
+
+
+def get_samples_weights(dataset):
+    """Compute the weights of all the samples of the dataset to balance it
+    using the sampler of the dataloader.
+
+    This function takes as input a :py:class:`torch.utils.data.dataset.Dataset`
+    and computes the weights to balance each class in the dataset and the
+    datasets themselves if we have a ConcatDataset.
+
+
+    Parameters
+    ----------
+
+    dataset : torch.utils.data.Dataset
+        An instance of :py:class:`torch.utils.data.Dataset`.
+        :py:class:`torch.utils.data.ConcatDataset` instances are also
+        supported
+
+
+    Returns
+    -------
+
+    samples_weights : :py:class:`torch.Tensor`
+        The weights for all the samples in the dataset given as input
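+
+
+    Example
+    -------
+
+    The returned weights are meant to feed a
+    :py:class:`torch.utils.data.WeightedRandomSampler` (a hedged sketch,
+    assuming ``dataset`` is one of our dataset instances)::
+
+        weights = get_samples_weights(dataset)
+        sampler = torch.utils.data.WeightedRandomSampler(
+            weights, len(weights), replacement=True
+        )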
+    """
+    samples_weights = []
+
+    if isinstance(dataset, torch.utils.data.ConcatDataset):
+        for ds in dataset.datasets:
+            # Weighting only for binary labels
+            if isinstance(ds._samples[0].label, int):
+                # Groundtruth
+                targets = []
+                for s in ds._samples:
+                    targets.append(s.label)
+                targets = torch.tensor(targets)
+
+                # Count number of samples per class
+                class_sample_count = torch.tensor(
+                    [
+                        (targets == t).sum()
+                        for t in torch.unique(targets, sorted=True)
+                    ]
+                )
+
+                weight = 1.0 / class_sample_count.float()
+
+                samples_weights.append(
+                    torch.tensor([weight[t] for t in targets])
+                )
+
+            else:
+                # We only weight to sample equally from each dataset
+                samples_weights.append(torch.full((len(ds),), 1.0 / len(ds)))
+
+        # Concatenate sample weights from all the datasets
+        samples_weights = torch.cat(samples_weights)
+
+    else:
+        # Weighting only for binary labels
+        if isinstance(dataset._samples[0].label, int):
+            # Groundtruth
+            targets = []
+            for s in dataset._samples:
+                targets.append(s.label)
+            targets = torch.tensor(targets)
+
+            # Count number of samples per class
+            class_sample_count = torch.tensor(
+                [
+                    (targets == t).sum()
+                    for t in torch.unique(targets, sorted=True)
+                ]
+            )
+
+            weight = 1.0 / class_sample_count.float()
+
+            samples_weights = torch.tensor([weight[t] for t in targets])
+
+        else:
+            # Equal weights for non-binary labels
+            samples_weights = torch.ones(len(dataset._samples))
+
+    return samples_weights
+
+
+def get_positive_weights(dataset):
+    """Compute the positive weights of each class of the dataset to balance the
+    BCEWithLogitsLoss criterion.
+
+    This function takes as input a :py:class:`torch.utils.data.dataset.Dataset`
+    and computes the positive weights of each class to use them to have
+    a balanced loss.
+
+
+    Parameters
+    ----------
+
+    dataset : torch.utils.data.Dataset
+        An instance of :py:class:`torch.utils.data.Dataset`.
+        :py:class:`torch.utils.data.ConcatDataset` instances are also
+        supported
+
+
+    Returns
+    -------
+
+    positive_weights : :py:class:`torch.Tensor`
+        The positive weight of each class in the dataset given as input
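+
+
+    Example
+    -------
+
+    The returned tensor is meant for the ``pos_weight`` argument of
+    :py:class:`torch.nn.BCEWithLogitsLoss` (a hedged sketch)::
+
+        criterion = torch.nn.BCEWithLogitsLoss(
+            pos_weight=get_positive_weights(dataset)
+        )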
+    """
+    targets = []
+
+    if isinstance(dataset, torch.utils.data.ConcatDataset):
+        for ds in dataset.datasets:
+            for s in ds._samples:
+                targets.append(s.label)
+
+    else:
+        for s in dataset._samples:
+            targets.append(s.label)
+
+    targets = torch.tensor(targets)
+
+    # Binary labels
+    if len(list(targets.shape)) == 1:
+        class_sample_count = [
+            float((targets == t).sum().item())
+            for t in torch.unique(targets, sorted=True)
+        ]
+
+        # Divide negatives by positives
+        positive_weights = torch.tensor(
+            [class_sample_count[0] / class_sample_count[1]]
+        ).reshape(-1)
+
+    # Multiclass labels
+    else:
+        class_sample_count = torch.sum(targets, dim=0)
+        negative_class_sample_count = (
+            torch.full((targets.size()[1],), float(targets.size()[0]))
+            - class_sample_count
+        )
+
+        positive_weights = negative_class_sample_count / (
+            class_sample_count + negative_class_sample_count
+        )
+
+    return positive_weights
+
+
+def return_subsets(dataset):
+    """Splits a dataset dictionary into the subsets required by our engines.
+
+    Returns, in order, the training, validation, extra-validation and
+    prediction datasets extracted from ``dataset``.
+    """
+    train_dataset = None
+    validation_dataset = None
+    extra_validation_datasets = None
+    predict_dataset = None
+
+    if "__train__" in dataset:
+        logger.info("Found (dedicated) '__train__' set for training")
+        train_dataset = dataset["__train__"]
+    else:
+        train_dataset = dataset["train"]
+
+    if "__valid__" in dataset:
+        logger.info("Found (dedicated) '__valid__' set for validation")
+        validation_dataset = dataset["__valid__"]
+
+    if "__extra_valid__" in dataset:
+        if not isinstance(dataset["__extra_valid__"], list):
+            raise RuntimeError(
+                "If present, dataset['__extra_valid__'] must be a list, "
+                f"but you passed a {type(dataset['__extra_valid__'])}, "
+                "which is invalid."
+            )
+        logger.info(
+            f"Found {len(dataset['__extra_valid__'])} extra validation "
+            "set(s) to be tracked during training"
+        )
+        logger.info(
+            "Extra validation sets are NOT used for model checkpointing!"
+        )
+        extra_validation_datasets = dataset["__extra_valid__"]
+    else:
+        extra_validation_datasets = None
+
+    predict_dataset = dataset
+
+    return (
+        train_dataset,
+        validation_dataset,
+        extra_validation_datasets,
+        predict_dataset,
+    )
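+
+
+# Typical call chain (a hedged sketch; ``raw_subsets`` stands in for the
+# output of ``JSONDataset.subsets`` and is not defined in this module):
+#
+#     dataset = make_dataset([raw_subsets])
+#     train, valid, extra_valid, predict = return_subsets(dataset)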
diff --git a/src/ptbench/data/shenzhen/__init__.py b/src/ptbench/data/shenzhen/__init__.py
index 60a004a3..a854e559 100644
--- a/src/ptbench/data/shenzhen/__init__.py
+++ b/src/ptbench/data/shenzhen/__init__.py
@@ -22,13 +22,11 @@ the daily routine using Philips DR Digital Diagnose systems.
 
 import importlib.resources
 import os
-import random
-
-import torch
 
 from clapper.logging import setup
 
 from ...utils.rc import load_rc
+from .. import make_dataset
 from ..dataset import JSONDataset
 from ..loader import load_pil_baw, make_delayed
 
@@ -93,339 +91,3 @@ def _maker(protocol, resize_size=512, cc_size=512, RGB=False):
         [ElasticDeformation(p=0.8)],
         post_transforms,
     )
-
-
-def make_subset(samples, transforms=[], prefixes=[], suffixes=[]):
-    """Creates a new data set, applying transforms.
-
-    .. note::
-
-       This is a convenience function for our own dataset definitions inside
-       this module, guaranteeting homogenity between dataset definitions
-       provided in this package.  It assumes certain strategies for data
-       augmentation that may not be translatable to other applications.
-
-
-    Parameters
-    ----------
-
-    samples : list
-        List of delayed samples
-
-    transforms : list
-        A list of transforms that needs to be applied to all samples in the set
-
-    prefixes : list
-        A list of data augmentation operations that needs to be applied
-        **before** the transforms above
-
-    suffixes : list
-        A list of data augmentation operations that needs to be applied
-        **after** the transforms above
-
-
-    Returns
-    -------
-
-    subset : :py:class:`ptbench.data.utils.SampleListDataset`
-        A pre-formatted dataset that can be fed to one of our engines
-    """
-    from ...data.utils import SampleListDataset as wrapper
-
-    return wrapper(samples, prefixes + transforms + suffixes)
-
-
-def make_dataset(
-    subsets_groups, transforms=[], t_transforms=[], post_transforms=[]
-):
-    """Creates a new configuration dataset from a list of dictionaries and
-    transforms.
-
-    This function takes as input a list of dictionaries as those that can be
-    returned by :py:meth:`ptbench.data.dataset.JSONDataset.subsets`
-    mapping protocol names (such as ``train``, ``dev`` and ``test``) to
-    :py:class:`ptbench.data.sample.DelayedSample` lists, and a set of
-    transforms, and returns a dictionary applying
-    :py:class:`ptbench.data.utils.SampleListDataset` to these
-    lists, and our standard data augmentation if a ``train`` set exists.
-
-    For example, if ``subsets`` is composed of two sets named ``train`` and
-    ``test``, this function will yield a dictionary with the following entries:
-
-    * ``__train__``: Wraps the ``train`` subset, includes data augmentation
-      (note: datasets with names starting with ``_`` (underscore) are excluded
-      from prediction and evaluation by default, as they contain data
-      augmentation transformations.)
-    * ``train``: Wraps the ``train`` subset, **without** data augmentation
-    * ``test``: Wraps the ``test`` subset, **without** data augmentation
-
-    .. note::
-
-       This is a convenience function for our own dataset definitions inside
-       this module, guaranteeting homogenity between dataset definitions
-       provided in this package.  It assumes certain strategies for data
-       augmentation that may not be translatable to other applications.
-
-
-    Parameters
-    ----------
-
-    subsets : list
-        A list of dictionaries that contains the delayed sample lists
-        for a number of named lists. The subsets will be aggregated in one
-        final subset. If one of the keys is ``train``, our standard dataset
-        augmentation transforms are appended to the definition of that subset.
-        All other subsets remain un-augmented.
-
-    transforms : list
-        A list of transforms that needs to be applied to all samples in the set
-
-    t_transforms : list
-        A list of transforms that needs to be applied to the train samples
-
-    post_transforms : list
-        A list of transforms that needs to be applied to all samples in the set
-        after all the other transforms
-
-
-    Returns
-    -------
-
-    dataset : dict
-        A pre-formatted dataset that can be fed to one of our engines. It maps
-        string names to :py:class:`ptbench.data.utils.SampleListDataset`'s.
-    """
-
-    retval = {}
-
-    if len(subsets_groups) == 1:
-        subsets = subsets_groups[0]
-    else:
-        # If multiple subsets groups: aggregation
-        aggregated_subsets = {}
-        for subsets in subsets_groups:
-            for key in subsets.keys():
-                if key in aggregated_subsets:
-                    aggregated_subsets[key] += subsets[key]
-                    # Shuffle if data comes from multiple datasets
-                    random.shuffle(aggregated_subsets[key])
-                else:
-                    aggregated_subsets[key] = subsets[key]
-        subsets = aggregated_subsets
-
-    # Add post_transforms after t_transforms for the train set
-    t_transforms += post_transforms
-
-    for key in subsets.keys():
-        retval[key] = make_subset(
-            subsets[key], transforms=transforms, suffixes=post_transforms
-        )
-        if key == "train":
-            retval["__train__"] = make_subset(
-                subsets[key], transforms=transforms, suffixes=(t_transforms)
-            )
-        if key == "validation":
-            # also use it for validation during training
-            retval["__valid__"] = retval[key]
-
-    if (
-        ("__train__" in retval)
-        and ("train" in retval)
-        and ("__valid__" not in retval)
-    ):
-        # if the dataset does not have a validation set, we use the unaugmented
-        # training set as validation set
-        retval["__valid__"] = retval["train"]
-
-    return retval
-
-
-def get_samples_weights(dataset):
-    """Compute the weights of all the samples of the dataset to balance it
-    using the sampler of the dataloader.
-
-    This function takes as input a :py:class:`torch.utils.data.dataset.Dataset`
-    and computes the weights to balance each class in the dataset and the
-    datasets themselves if we have a ConcatDataset.
-
-
-    Parameters
-    ----------
-
-    dataset : torch.utils.data.dataset.Dataset
-        An instance of torch.utils.data.dataset.Dataset
-        ConcatDataset are supported
-
-
-    Returns
-    -------
-
-    samples_weights : :py:class:`torch.Tensor`
-        the weights for all the samples in the dataset given as input
-    """
-    samples_weights = []
-
-    if isinstance(dataset, torch.utils.data.ConcatDataset):
-        for ds in dataset.datasets:
-            # Weighting only for binary labels
-            if isinstance(ds._samples[0].label, int):
-                # Groundtruth
-                targets = []
-                for s in ds._samples:
-                    targets.append(s.label)
-                targets = torch.tensor(targets)
-
-                # Count number of samples per class
-                class_sample_count = torch.tensor(
-                    [
-                        (targets == t).sum()
-                        for t in torch.unique(targets, sorted=True)
-                    ]
-                )
-
-                weight = 1.0 / class_sample_count.float()
-
-                samples_weights.append(
-                    torch.tensor([weight[t] for t in targets])
-                )
-
-            else:
-                # We only weight to sample equally from each dataset
-                samples_weights.append(torch.full((len(ds),), 1.0 / len(ds)))
-
-        # Concatenate sample weights from all the datasets
-        samples_weights = torch.cat(samples_weights)
-
-    else:
-        # Weighting only for binary labels
-        if isinstance(dataset._samples[0].label, int):
-            # Groundtruth
-            targets = []
-            for s in dataset._samples:
-                targets.append(s.label)
-            targets = torch.tensor(targets)
-
-            # Count number of samples per class
-            class_sample_count = torch.tensor(
-                [
-                    (targets == t).sum()
-                    for t in torch.unique(targets, sorted=True)
-                ]
-            )
-
-            weight = 1.0 / class_sample_count.float()
-
-            samples_weights = torch.tensor([weight[t] for t in targets])
-
-        else:
-            # Equal weights for non-binary labels
-            samples_weights = torch.ones(len(dataset._samples))
-
-    return samples_weights
-
-
-def get_positive_weights(dataset):
-    """Compute the positive weights of each class of the dataset to balance the
-    BCEWithLogitsLoss criterion.
-
-    This function takes as input a :py:class:`torch.utils.data.dataset.Dataset`
-    and computes the positive weights of each class to use them to have
-    a balanced loss.
-
-
-    Parameters
-    ----------
-
-    dataset : torch.utils.data.dataset.Dataset
-        An instance of torch.utils.data.dataset.Dataset
-        ConcatDataset are supported
-
-
-    Returns
-    -------
-
-    positive_weights : :py:class:`torch.Tensor`
-        the positive weight of each class in the dataset given as input
-    """
-    targets = []
-
-    if isinstance(dataset, torch.utils.data.ConcatDataset):
-        for ds in dataset.datasets:
-            for s in ds._samples:
-                targets.append(s.label)
-
-    else:
-        for s in dataset._samples:
-            targets.append(s.label)
-
-    targets = torch.tensor(targets)
-
-    # Binary labels
-    if len(list(targets.shape)) == 1:
-        class_sample_count = [
-            float((targets == t).sum().item())
-            for t in torch.unique(targets, sorted=True)
-        ]
-
-        # Divide negatives by positives
-        positive_weights = torch.tensor(
-            [class_sample_count[0] / class_sample_count[1]]
-        ).reshape(-1)
-
-    # Multiclass labels
-    else:
-        class_sample_count = torch.sum(targets, dim=0)
-        negative_class_sample_count = (
-            torch.full((targets.size()[1],), float(targets.size()[0]))
-            - class_sample_count
-        )
-
-        positive_weights = negative_class_sample_count / (
-            class_sample_count + negative_class_sample_count
-        )
-
-    return positive_weights
-
-
-def return_subsets(dataset):
-    train_dataset = None
-    validation_dataset = None
-    extra_validation_datasets = None
-    predict_dataset = None
-
-    if "__train__" in dataset:
-        logger.info("Found (dedicated) '__train__' set for training")
-        train_dataset = dataset["__train__"]
-    else:
-        train_dataset = dataset["train"]
-
-    if "__valid__" in dataset:
-        logger.info("Found (dedicated) '__valid__' set for validation")
-        validation_dataset = dataset["__valid__"]
-
-    if "__extra_valid__" in dataset:
-        if not isinstance(dataset["__extra_valid__"], list):
-            raise RuntimeError(
-                f"If present, dataset['__extra_valid__'] must be a list, "
-                f"but you passed a {type(dataset['__extra_valid__'])}, "
-                f"which is invalid."
-            )
-        logger.info(
-            f"Found {len(dataset['__extra_valid__'])} extra validation "
-            f"set(s) to be tracked during training"
-        )
-        logger.info(
-            "Extra validation sets are NOT used for model checkpointing!"
-        )
-        extra_validation_datasets = dataset["__extra_valid__"]
-    else:
-        extra_validation_datasets = None
-
-    predict_dataset = dataset
-
-    return (
-        train_dataset,
-        validation_dataset,
-        extra_validation_datasets,
-        predict_dataset,
-    )
diff --git a/src/ptbench/data/shenzhen/default.py b/src/ptbench/data/shenzhen/default.py
index 3801efbc..bbeabcaf 100644
--- a/src/ptbench/data/shenzhen/default.py
+++ b/src/ptbench/data/shenzhen/default.py
@@ -12,8 +12,9 @@
 
 from clapper.logging import setup
 
+from .. import return_subsets
 from ..base_datamodule import BaseDataModule
-from . import _maker, return_subsets
+from . import _maker
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
diff --git a/src/ptbench/data/shenzhen/fold_0.py b/src/ptbench/data/shenzhen/fold_0.py
index d65d513b..5b4d4560 100644
--- a/src/ptbench/data/shenzhen/fold_0.py
+++ b/src/ptbench/data/shenzhen/fold_0.py
@@ -11,8 +11,9 @@
 
 from clapper.logging import setup
 
+from .. import return_subsets
 from ..base_datamodule import BaseDataModule
-from . import _maker, return_subsets
+from . import _maker
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
diff --git a/src/ptbench/data/shenzhen/fold_0_rgb.py b/src/ptbench/data/shenzhen/fold_0_rgb.py
index bcc853dd..143ef731 100644
--- a/src/ptbench/data/shenzhen/fold_0_rgb.py
+++ b/src/ptbench/data/shenzhen/fold_0_rgb.py
@@ -11,8 +11,9 @@
 
 from clapper.logging import setup
 
+from .. import return_subsets
 from ..base_datamodule import BaseDataModule
-from . import _maker, return_subsets
+from . import _maker
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
diff --git a/src/ptbench/data/shenzhen/fold_1.py b/src/ptbench/data/shenzhen/fold_1.py
index b9494f15..f01adef0 100644
--- a/src/ptbench/data/shenzhen/fold_1.py
+++ b/src/ptbench/data/shenzhen/fold_1.py
@@ -11,8 +11,9 @@
 
 from clapper.logging import setup
 
+from .. import return_subsets
 from ..base_datamodule import BaseDataModule
-from . import _maker, return_subsets
+from . import _maker
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
diff --git a/src/ptbench/data/shenzhen/fold_1_rgb.py b/src/ptbench/data/shenzhen/fold_1_rgb.py
index 01e23967..9d457adf 100644
--- a/src/ptbench/data/shenzhen/fold_1_rgb.py
+++ b/src/ptbench/data/shenzhen/fold_1_rgb.py
@@ -12,8 +12,9 @@
 
 from clapper.logging import setup
 
+from .. import return_subsets
 from ..base_datamodule import BaseDataModule
-from . import _maker, return_subsets
+from . import _maker
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
diff --git a/src/ptbench/data/shenzhen/fold_2.py b/src/ptbench/data/shenzhen/fold_2.py
index 8d5cf816..04dd6562 100644
--- a/src/ptbench/data/shenzhen/fold_2.py
+++ b/src/ptbench/data/shenzhen/fold_2.py
@@ -11,8 +11,9 @@
 
 from clapper.logging import setup
 
+from .. import return_subsets
 from ..base_datamodule import BaseDataModule
-from . import _maker, return_subsets
+from . import _maker
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
diff --git a/src/ptbench/data/shenzhen/fold_2_rgb.py b/src/ptbench/data/shenzhen/fold_2_rgb.py
index baf6752d..37cbe10e 100644
--- a/src/ptbench/data/shenzhen/fold_2_rgb.py
+++ b/src/ptbench/data/shenzhen/fold_2_rgb.py
@@ -11,8 +11,9 @@
 
 from clapper.logging import setup
 
+from .. import return_subsets
 from ..base_datamodule import BaseDataModule
-from . import _maker, return_subsets
+from . import _maker
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
diff --git a/src/ptbench/data/shenzhen/fold_3.py b/src/ptbench/data/shenzhen/fold_3.py
index b42882cc..b43fcb29 100644
--- a/src/ptbench/data/shenzhen/fold_3.py
+++ b/src/ptbench/data/shenzhen/fold_3.py
@@ -11,8 +11,9 @@
 
 from clapper.logging import setup
 
+from .. import return_subsets
 from ..base_datamodule import BaseDataModule
-from . import _maker, return_subsets
+from . import _maker
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
diff --git a/src/ptbench/data/shenzhen/fold_3_rgb.py b/src/ptbench/data/shenzhen/fold_3_rgb.py
index a02c2b1d..162a3f82 100644
--- a/src/ptbench/data/shenzhen/fold_3_rgb.py
+++ b/src/ptbench/data/shenzhen/fold_3_rgb.py
@@ -11,8 +11,9 @@
 
 from clapper.logging import setup
 
+from .. import return_subsets
 from ..base_datamodule import BaseDataModule
-from . import _maker, return_subsets
+from . import _maker
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
diff --git a/src/ptbench/data/shenzhen/fold_4.py b/src/ptbench/data/shenzhen/fold_4.py
index a9ad1471..58e0a2f2 100644
--- a/src/ptbench/data/shenzhen/fold_4.py
+++ b/src/ptbench/data/shenzhen/fold_4.py
@@ -11,8 +11,9 @@
 
 from clapper.logging import setup
 
+from .. import return_subsets
 from ..base_datamodule import BaseDataModule
-from . import _maker, return_subsets
+from . import _maker
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
diff --git a/src/ptbench/data/shenzhen/fold_4_rgb.py b/src/ptbench/data/shenzhen/fold_4_rgb.py
index 3620ba42..0dd4ccf8 100644
--- a/src/ptbench/data/shenzhen/fold_4_rgb.py
+++ b/src/ptbench/data/shenzhen/fold_4_rgb.py
@@ -11,8 +11,9 @@
 
 from clapper.logging import setup
 
+from .. import return_subsets
 from ..base_datamodule import BaseDataModule
-from . import _maker, return_subsets
+from . import _maker
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
diff --git a/src/ptbench/data/shenzhen/fold_5.py b/src/ptbench/data/shenzhen/fold_5.py
index 426a9d66..ff115340 100644
--- a/src/ptbench/data/shenzhen/fold_5.py
+++ b/src/ptbench/data/shenzhen/fold_5.py
@@ -11,8 +11,9 @@
 
 from clapper.logging import setup
 
+from .. import return_subsets
 from ..base_datamodule import BaseDataModule
-from . import _maker, return_subsets
+from . import _maker
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
diff --git a/src/ptbench/data/shenzhen/fold_5_rgb.py b/src/ptbench/data/shenzhen/fold_5_rgb.py
index 29e70138..46e255e7 100644
--- a/src/ptbench/data/shenzhen/fold_5_rgb.py
+++ b/src/ptbench/data/shenzhen/fold_5_rgb.py
@@ -11,8 +11,9 @@
 
 from clapper.logging import setup
 
+from .. import return_subsets
 from ..base_datamodule import BaseDataModule
-from . import _maker, return_subsets
+from . import _maker
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
diff --git a/src/ptbench/data/shenzhen/fold_6.py b/src/ptbench/data/shenzhen/fold_6.py
index fb0a91b8..eb81ae88 100644
--- a/src/ptbench/data/shenzhen/fold_6.py
+++ b/src/ptbench/data/shenzhen/fold_6.py
@@ -11,8 +11,9 @@
 
 from clapper.logging import setup
 
+from .. import return_subsets
 from ..base_datamodule import BaseDataModule
-from . import _maker, return_subsets
+from . import _maker
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
diff --git a/src/ptbench/data/shenzhen/fold_6_rgb.py b/src/ptbench/data/shenzhen/fold_6_rgb.py
index 35e7e6d7..b9654d08 100644
--- a/src/ptbench/data/shenzhen/fold_6_rgb.py
+++ b/src/ptbench/data/shenzhen/fold_6_rgb.py
@@ -11,8 +11,9 @@
 
 from clapper.logging import setup
 
+from .. import return_subsets
 from ..base_datamodule import BaseDataModule
-from . import _maker, return_subsets
+from . import _maker
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
diff --git a/src/ptbench/data/shenzhen/fold_7.py b/src/ptbench/data/shenzhen/fold_7.py
index 743b344d..79b0d1ff 100644
--- a/src/ptbench/data/shenzhen/fold_7.py
+++ b/src/ptbench/data/shenzhen/fold_7.py
@@ -11,8 +11,9 @@
 
 from clapper.logging import setup
 
+from .. import return_subsets
 from ..base_datamodule import BaseDataModule
-from . import _maker, return_subsets
+from . import _maker
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
diff --git a/src/ptbench/data/shenzhen/fold_7_rgb.py b/src/ptbench/data/shenzhen/fold_7_rgb.py
index 0a9f83d7..8a36acb2 100644
--- a/src/ptbench/data/shenzhen/fold_7_rgb.py
+++ b/src/ptbench/data/shenzhen/fold_7_rgb.py
@@ -11,8 +11,9 @@
 
 from clapper.logging import setup
 
+from .. import return_subsets
 from ..base_datamodule import BaseDataModule
-from . import _maker, return_subsets
+from . import _maker
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
diff --git a/src/ptbench/data/shenzhen/fold_8.py b/src/ptbench/data/shenzhen/fold_8.py
index ee7f7167..cf1cd36a 100644
--- a/src/ptbench/data/shenzhen/fold_8.py
+++ b/src/ptbench/data/shenzhen/fold_8.py
@@ -11,8 +11,9 @@
 
 from clapper.logging import setup
 
+from .. import return_subsets
 from ..base_datamodule import BaseDataModule
-from . import _maker, return_subsets
+from . import _maker
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
diff --git a/src/ptbench/data/shenzhen/fold_8_rgb.py b/src/ptbench/data/shenzhen/fold_8_rgb.py
index 1351790f..1aa0bcec 100644
--- a/src/ptbench/data/shenzhen/fold_8_rgb.py
+++ b/src/ptbench/data/shenzhen/fold_8_rgb.py
@@ -11,8 +11,9 @@
 
 from clapper.logging import setup
 
+from .. import return_subsets
 from ..base_datamodule import BaseDataModule
-from . import _maker, return_subsets
+from . import _maker
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
diff --git a/src/ptbench/data/shenzhen/fold_9.py b/src/ptbench/data/shenzhen/fold_9.py
index dbd8ab31..e1bb569d 100644
--- a/src/ptbench/data/shenzhen/fold_9.py
+++ b/src/ptbench/data/shenzhen/fold_9.py
@@ -11,8 +11,9 @@
 
 from clapper.logging import setup
 
+from .. import return_subsets
 from ..base_datamodule import BaseDataModule
-from . import _maker, return_subsets
+from . import _maker
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
diff --git a/src/ptbench/data/shenzhen/fold_9_rgb.py b/src/ptbench/data/shenzhen/fold_9_rgb.py
index 729141dc..c0a577df 100644
--- a/src/ptbench/data/shenzhen/fold_9_rgb.py
+++ b/src/ptbench/data/shenzhen/fold_9_rgb.py
@@ -11,8 +11,9 @@
 
 from clapper.logging import setup
 
+from .. import return_subsets
 from ..base_datamodule import BaseDataModule
-from . import _maker, return_subsets
+from . import _maker
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
diff --git a/src/ptbench/data/shenzhen/rgb.py b/src/ptbench/data/shenzhen/rgb.py
index 6186990e..7bdb8fe3 100644
--- a/src/ptbench/data/shenzhen/rgb.py
+++ b/src/ptbench/data/shenzhen/rgb.py
@@ -12,8 +12,9 @@
 
 from clapper.logging import setup
 
+from .. import return_subsets
 from ..base_datamodule import BaseDataModule
-from . import _maker, return_subsets
+from . import _maker
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
-- 
GitLab