From a67626d86cb4238f118e6da7341d87fc28c540a9 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Wed, 28 Jun 2023 22:41:05 +0200
Subject: [PATCH] [datamodule] Slightly streamlines the datamodule approach;
 adds documentation; adds type annotations; adds TODOs

---
 src/ptbench/data/base_datamodule.py   | 196 ----------
 src/ptbench/data/datamodule.py        | 503 ++++++++++++++++++++++++
 src/ptbench/data/dataset.py           | 526 ++++++++------------------
 src/ptbench/data/shenzhen/__init__.py |  19 -
 src/ptbench/data/shenzhen/default.py  |  43 ++-
 src/ptbench/data/shenzhen/loader.py   |  68 ++++
 src/ptbench/data/shenzhen/utils.py    | 114 ------
 src/ptbench/data/transforms.py        |   2 +
 8 files changed, 765 insertions(+), 706 deletions(-)
 delete mode 100644 src/ptbench/data/base_datamodule.py
 create mode 100644 src/ptbench/data/datamodule.py
 create mode 100644 src/ptbench/data/shenzhen/loader.py
 delete mode 100644 src/ptbench/data/shenzhen/utils.py

diff --git a/src/ptbench/data/base_datamodule.py b/src/ptbench/data/base_datamodule.py
deleted file mode 100644
index 8377c663..00000000
--- a/src/ptbench/data/base_datamodule.py
+++ /dev/null
@@ -1,196 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-import multiprocessing
-import sys
-
-import lightning as pl
-import torch
-
-from clapper.logging import setup
-from torch.utils.data import DataLoader, WeightedRandomSampler
-from torchvision import transforms
-
-from .dataset import get_samples_weights
-
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class BaseDataModule(pl.LightningDataModule):
-    def __init__(
-        self,
-        batch_size=1,
-        batch_chunk_count=1,
-        drop_incomplete_batch=False,
-        parallel=-1,
-    ):
-        super().__init__()
-
-        self.batch_size = batch_size
-        self.batch_chunk_count = batch_chunk_count
-
-        self.train_dataset = None
-        self.validation_dataset = None
-        self.extra_validation_datasets = None
-
-        self._raw_data_transforms = []
-        self._model_transforms = []
-        self._augmentation_transforms = []
-
-        self.drop_incomplete_batch = drop_incomplete_batch
-        self.pin_memory = (
-            torch.cuda.is_available()
-        )  # should only be true if GPU available and using it
-
-        self.parallel = parallel
-
-    def setup(self, stage: str):
-        # Implemented by user
-        # Must define self.train_dataset, self.validation_dataset, self.extra_validation_datasets and self.predict_dataset
-        raise NotImplementedError
-
-    def train_dataloader(self):
-        train_samples_weights = get_samples_weights(self.train_dataset)
-
-        multiproc_kwargs = self._setup_multiproc(self.parallel)
-        train_sampler = WeightedRandomSampler(
-            train_samples_weights, len(train_samples_weights), replacement=True
-        )
-
-        return DataLoader(
-            self.train_dataset,
-            batch_size=self._compute_chunk_size(
-                self.batch_size, self.batch_chunk_count
-            ),
-            drop_last=self.drop_incomplete_batch,
-            pin_memory=self.pin_memory,
-            sampler=train_sampler,
-            **multiproc_kwargs,
-        )
-
-    def val_dataloader(self):
-        loaders_dict = {}
-
-        multiproc_kwargs = self._setup_multiproc(self.parallel)
-
-        val_loader = DataLoader(
-            dataset=self.validation_dataset,
-            batch_size=self._compute_chunk_size(
-                self.batch_size, self.batch_chunk_count
-            ),
-            shuffle=False,
-            drop_last=False,
-            pin_memory=self.pin_memory,
-            **multiproc_kwargs,
-        )
-
-        loaders_dict["validation_loader"] = val_loader
-
-        if self.extra_validation_datasets is not None:
-            for set_idx, extra_set in enumerate(self.extra_validation_datasets):
-                extra_val_loader = DataLoader(
-                    dataset=extra_set,
-                    batch_size=self._compute_chunk_size(
-                        self.batch_size, self.batch_chunk_count
-                    ),
-                    shuffle=False,
-                    drop_last=False,
-                    pin_memory=self.pin_memory,
-                    **multiproc_kwargs,
-                )
-
-                loaders_dict[
-                    f"extra_validation_loader{set_idx}"
-                ] = extra_val_loader
-
-        return loaders_dict
-
-    def predict_dataloader(self):
-        loaders_dict = {}
-
-        loaders_dict["train_loader"] = self.train_dataloader()
-        for k, v in self.val_dataloader().items():
-            loaders_dict[k] = v
-
-        return loaders_dict
-
-    def update_module_properties(self, **kwargs):
-        for k, v in kwargs.items():
-            setattr(self, k, v)
-
-    def _compute_chunk_size(self, batch_size, chunk_count):
-        batch_chunk_size = batch_size
-        if batch_size % chunk_count != 0:
-            # batch_size must be divisible by batch_chunk_count.
-            raise RuntimeError(
-                f"--batch-size ({batch_size}) must be divisible by "
-                f"--batch-chunk-size ({chunk_count})."
-            )
-        else:
-            batch_chunk_size = batch_size // chunk_count
-
-        return batch_chunk_size
-
-    def _setup_multiproc(self, parallel):
-        multiproc_kwargs = dict()
-        if parallel < 0:
-            multiproc_kwargs["num_workers"] = 0
-        else:
-            multiproc_kwargs["num_workers"] = (
-                parallel or multiprocessing.cpu_count()
-            )
-
-        if multiproc_kwargs["num_workers"] > 0 and sys.platform == "darwin":
-            multiproc_kwargs[
-                "multiprocessing_context"
-            ] = multiprocessing.get_context("spawn")
-
-        return multiproc_kwargs
-
-    def _build_transforms(self, is_train):
-        all_transforms = self.raw_data_transforms + self.model_transforms
-
-        if is_train:
-            all_transforms = all_transforms + self.augmentation_transforms
-
-        # all_transforms.append(transforms.ToTensor())
-
-        return transforms.Compose(all_transforms)
-
-    @property
-    def raw_data_transforms(self):
-        return self._raw_data_transforms
-
-    @raw_data_transforms.setter
-    def raw_data_transforms(self, transforms):
-        self._raw_data_transforms = transforms
-
-    @property
-    def model_transforms(self):
-        return self._model_transforms
-
-    @model_transforms.setter
-    def model_transforms(self, transforms):
-        self._model_transforms = transforms
-
-    @property
-    def augmentation_transforms(self):
-        return self._augmentation_transforms
-
-    @augmentation_transforms.setter
-    def augmentation_transforms(self, transforms):
-        self._augmentation_transforms = transforms
-
-
-def get_dataset_from_module(module, stage, **module_args):
-    """Instantiates a DataModule and retrieves the corresponding dataset.
-
-    Useful when combining multiple datasets.
-    """
-    module_instance = module(**module_args)
-    module_instance.prepare_data()
-    module_instance.setup(stage=stage)
-    dataset = module_instance.dataset
-
-    return dataset
diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py
new file mode 100644
index 00000000..11e76697
--- /dev/null
+++ b/src/ptbench/data/datamodule.py
@@ -0,0 +1,503 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import collections
+import multiprocessing
+import sys
+import typing
+
+import lightning
+import torch
+import torch.utils.data
+
+from clapper.logging import setup
+
+# TODO: No logging on this module...
+logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+
+
+def _setup_dataloader_multiproc_parameters(
+    parallel: int,
+) -> dict[str, typing.Any]:
+    """Returns a dictionary containing pytorch arguments to be used in data
+    loaders.
+
+    It sets the parameter ``num_workers`` to match the expected pytorch
+    representation.  For macOS machines, it also sets the
+    ``multiprocessing_context`` to use ``spawn`` instead of the default.
+
+    The mapping between the command-line interface ``parallel`` setting and
+    the data loader parameterisation works like this:
+
+    .. list-table:: Relationship between ``parallel`` and DataLoader parameterisation
+       :widths: 15 15 70
+       :header-rows: 1
+
+       * - CLI ``parallel``
+         - :py:class:`torch.utils.data.DataLoader` ``kwargs``
+         - Comments
+       * - ``<0``
+         - 0
+         - Disables multiprocessing entirely, executes everything within the
+           same processing context
+       * - ``0``
+         - :py:func:`multiprocessing.cpu_count`
+         - Runs mini-batch data loading on as many external processes as CPUs
+           available in the current machine
+       * - ``>=1``
+         - ``parallel``
+         - Runs mini-batch data loading on as many external processes as set on
+           ``parallel``
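+
+    A minimal usage sketch (``4`` is an arbitrary example value, and
+    ``dataset`` stands for any torch dataset):
+
+    .. code:: python
+
+       kwargs = _setup_dataloader_multiproc_parameters(4)
+       # kwargs == {"num_workers": 4} (plus a "spawn" context on macOS)
+       loader = torch.utils.data.DataLoader(dataset, batch_size=8, **kwargs)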
+    """
+
+    retval: dict[str, typing.Any] = dict()
+    if parallel < 0:
+        retval["num_workers"] = 0
+    else:
+        retval["num_workers"] = parallel or multiprocessing.cpu_count()
+
+    if retval["num_workers"] > 0 and sys.platform == "darwin":
+        retval["multiprocessing_context"] = multiprocessing.get_context("spawn")
+
+    return retval
+
+
+class _DelayedLoadingDataset(torch.utils.data.Dataset):
+    """A list that loads its samples on demand.
+
+    This list mimics a pytorch Dataset, except raw data loading is done
+    on-the-fly, as the samples are requested through the bracket operator.
+
+
+    Parameters
+    ----------
+
+    split
+        An iterable containing the raw dataset samples loaded from the database
+        splits.
+
+    raw_data_loader
+        A callable that will take the representation of the samples declared
+        inside database splits, load the actual raw data (if one exists), and
+        transform it into a :py:class:`torch.Tensor` containing floats in the
+        interval ``[0, 1]``, and a dictionary with further metadata attributes.
+
+    transforms
+        A set of transforms that should be applied on-the-fly for this dataset.
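+
+    A usage sketch (``split["train"]`` and ``my_loader`` are hypothetical
+    placeholders for a database-split subset and a raw-data loader):
+
+    .. code:: python
+
+       dataset = _DelayedLoadingDataset(split["train"], my_loader, [])
+       tensor, metadata = dataset[0]  # raw data is loaded here, on demand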
+    """
+
+    def __init__(
+        self,
+        split: typing.Sequence[typing.Any],
+        raw_data_loader: typing.Callable[
+            [typing.Any], tuple[torch.Tensor, typing.Mapping]
+        ],
+        transforms: typing.Optional[typing.Sequence[torch.nn.Module]] = None,
+    ):
+        self.split = split
+        self.raw_data_loader = raw_data_loader
+        self.transform = torch.nn.Sequential(*(transforms or []))
+
+    def __getitem__(self, key: int) -> tuple[torch.Tensor, typing.Mapping]:
+        tensor, metadata = self.raw_data_loader(self.split[key])
+        return self.transform(tensor), metadata
+
+    def __len__(self):
+        return len(self.split)
+
+
+class _CachedDataset(torch.utils.data.Dataset):
+    """Basically, a list of preloaded samples.
+
+    This dataset will load all samples from the split during construction
+    instead of delaying that to the indexing.
+
+
+    Parameters
+    ----------
+
+    split
+        An iterable containing the raw dataset samples loaded from the database
+        splits.
+
+    raw_data_loader
+        A callable that will take the representation of the samples declared
+        inside database splits, load the actual raw data (if one exists), and
+        transform it into a :py:class:`torch.Tensor` containing floats in the
+        interval ``[0, 1]``, and a dictionary with further metadata attributes.
+
+    transforms
+        A set of transforms that should be applied on-the-fly for this dataset.
+    """
+
+    def __init__(
+        self,
+        split: typing.Sequence[typing.Any],
+        raw_data_loader: typing.Callable[
+            [typing.Any], tuple[torch.Tensor, typing.Mapping]
+        ],
+        transforms: typing.Optional[typing.Sequence[torch.nn.Module]] = None,
+    ):
+        self.data = [raw_data_loader(k) for k in split]
+        self.transform = torch.nn.Sequential(*(transforms or []))
+
+    def __getitem__(self, key: int) -> tuple[torch.Tensor, typing.Mapping]:
+        tensor, metadata = self.data[key]
+        return self.transform(tensor), metadata
+
+    def __len__(self):
+        return len(self.data)
+
+
+def get_sample_weights(
+    dataset: _DelayedLoadingDataset | _CachedDataset,
+) -> torch.Tensor:
+    """Computes the (inverse) probabilities of samples based on their class.
+
+    This function takes as input a torch Dataset, and computes the weights to
+    balance each class in the dataset, and the datasets themselves if one
+    passes a :py:class:`torch.utils.data.ConcatDataset`.
+
+    This function assumes labels are integers stored under the ``label`` key
+    of each sample's metadata dictionary (the second entry of the sample).  If
+    that is not the case, datasets are only balanced w.r.t. each other (e.g.
+    within a :py:class:`torch.utils.data.ConcatDataset`).  For a single
+    dataset (not a concatenated one) without labels, this amounts to a no-op.
+
+
+    Parameters
+    ----------
+
+    dataset
+        An instance of a torch Dataset.  Instances of
+        :py:class:`torch.utils.data.ConcatDataset` are also supported.
+
+
+    Returns
+    -------
+
+    sample_weights
+        The weights for all the samples in the dataset given as input
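+
+    A typical usage sketch, pairing the returned weights with a biased sampler
+    for the training data loader:
+
+    .. code:: python
+
+       weights = get_sample_weights(dataset)
+       sampler = torch.utils.data.WeightedRandomSampler(
+           weights, len(weights), replacement=True
+       )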
+    """
+    retval = []
+
+    def _calculate_dataset_weights(
+        d: _DelayedLoadingDataset | _CachedDataset,
+    ) -> torch.Tensor:
+        """Calculates the weights for one dataset."""
+        if "label" in d[0][1] and isinstance(d[0][1]["label"], int):
+            # there are labels
+            targets = [k[1]["label"] for k in d]
+            counts = collections.Counter(targets)
+            weights = {k: 1.0 / v for k, v in counts.items()}
+            weight_per_sample = [weights[k[1]["label"]] for k in d]
+            return torch.tensor(weight_per_sample)
+        else:
+            # no labels: weight samples of this dataset uniformly, by count
+            # n.b.: requires that len(d) is implemented
+            weight = 1.0 / len(d)
+            return torch.tensor(len(d) * [weight])
+
+    if isinstance(dataset, torch.utils.data.ConcatDataset):
+        for ds in dataset.datasets:
+            retval.append(_calculate_dataset_weights(ds))  # type: ignore[arg-type]
+
+        # Concatenate sample weights from all the datasets
+        return torch.cat(retval)
+
+    return _calculate_dataset_weights(dataset)
+
+
+class CachingDataModule(lightning.LightningDataModule):
+    """A conveninent data module with CSV or JSON protocol loading, mini-
+    batching, parallelisation and caching, all in one.
+
+    Instances of this class load data-split (a.k.a. protocol) definitions for a
+    database, and can load the data from the disk.  An optional caching
+    mechanism stores the data in CPU memory, which can improve data
+    serving while training and evaluating models.
+
+    This datamodule defines basic operations to handle data loading and
+    mini-batch handling within this package's framework.  It can return
+    :py:class:`torch.utils.data.DataLoader` objects for training, validation,
+    prediction and testing conditions.  Parallelisation is handled by a simple
+    input flag.
+
+    The :py:meth:`setup` method is parameterised by a single string
+    enumeration containing one of: ``fit``, ``validate``, ``test``, or
+    ``predict``, and prepares the datasets required at each of these stages.
+
+
+    Parameters
+    ----------
+
+    database_split
+        A dictionary that contains string keys representing subset names, and
+        values that are iterables over sample representations (potentially on
+        disk).  These objects are passed to the ``raw_data_loader`` for loading
+        the sample data (and metadata) in memory.  The objects represented may
+        be of any format (e.g. list, dictionary, etc), as long as the
+        ``raw_data_loader`` can properly handle it.  To check that the split
+        and the loader function work correctly, you may use
+        :py:func:`..dataset.check_database_split_loading`.  As is, this class
+        expects at least one entry called ``train`` to exist in the input
+        dictionary.  Optional entries are ``validation``, and ``test``. Entries
+        named ``monitor-...`` will be considered extra subsets that do not
+        influence any early stop criteria during training, and are just
+        monitored beyond the ``validation`` dataset.
+
+    raw_data_loader
+        A callable that will take the representation of the samples declared
+        inside database splits, load the actual raw data (if one exists), and
+        transform it into a :py:class:`torch.Tensor` containing floats in the
+        interval ``[0, 1]``, and a dictionary with further metadata attributes.
+        Samples can be cached **after** raw data loading.
+
+    cache_samples
+        If set, then raw data is loaded during ``setup()`` and samples are
+        served from CPU memory afterwards.  Otherwise, samples are loaded from
+        disk on demand.  Serving from CPU memory offers increased speed in
+        exchange for higher memory usage.  Sufficient CPU memory must be
+        available before you set this attribute to ``True``.  It is typically
+        useful for relatively small
+        datasets.
+
+    train_sampler
+        If set, then reset the default random sampler from torch with a
+        (potentially) biased one of your choice.  Note this has no effect on
+        the batching strategy, only on the way the samples are picked from the
+        original dataset.  A typical replacement is:
+
+        .. code:: python
+
+           sample_weights = get_sample_weights(dataset)
+           train_sampler = torch.utils.data.WeightedRandomSampler(
+                   sample_weights, len(sample_weights), replacement=True
+                   )
+
+        The variable ``sample_weights`` represents the probabilities (may not
+        sum to one) of picking each particular sample in the dataset for a
+        batch.  We typically set ``replacement=True`` to avoid issues with
+        missing data from one of the classes in mini-batches.  This is
+        particularly important in highly unbalanced datasets.  The function
+        :py:func:`get_sample_weights` may help you in this aspect.
+
+    data_augmentations
+        A list of torchvision transforms (torch modules) that will be applied
+        on training set samples to create data augmentations during the
+        training of a model. Augmentation transform pipelines are applied
+        *after* the raw data is loaded, and before ``model_transforms``.
+        Augmentation transforms assume they receive a torch tensor representing
+        an image as input (see :py:class:`torchvision.transforms.ToTensor` for
+        details), in the range ``[0, 1]``.
+
+    model_transforms
+        A list of torchvision transforms (torch modules) that will be applied
+        after data augmentation transforms, and just before data is fed into
+        the model for all data loaders produced by this data module.  This part
+        of the pipeline receives data as output by the raw-data-loader, or from
+        data augmentations, if any is specified.
+
+    batch_size
+        Number of samples in every **training** batch (this parameter affects
+        memory requirements for the network).  If the number of samples in the
+        batch is larger than the total number of samples available for
+        training, this value is truncated.  If this number is smaller, then
+        batches of the specified size are created and fed to the network  until
+        there are no more new samples to feed (epoch is finished).  If the
+        total number of training samples is not a multiple of the batch-size,
+        the last batch will be smaller than the first, unless
+        ``drop_incomplete_batch`` is set to ``true``, in which case this batch
+        is not used.
+
+    batch_chunk_count
+        Number of chunks in every batch (this parameter affects memory
+        requirements for the network). The number of samples loaded for every
+        iteration will be ``batch_size/batch_chunk_count``. ``batch_size``
+        needs to be divisible by ``batch_chunk_count``, otherwise an error will
+        be raised. This parameter is used to reduce the number of samples
+        loaded in each iteration, in order to reduce memory usage in exchange
+        for processing time (more iterations). This is especially interesting
+        when running on GPUs with limited RAM. The default of 1 forces the
+        whole batch to be processed at once. Otherwise the batch is broken into
+        ``batch_chunk_count`` pieces, and gradients are accumulated to complete
+        each batch.  For example, ``batch_size=16`` with ``batch_chunk_count=4``
+        loads 4 samples per iteration and accumulates gradients over 4
+        iterations to complete each 16-sample batch.
+
+    drop_incomplete_batch
+        If set, then may drop the last batch in an epoch, in case it is
+        incomplete.  If you set this option, you should also consider
+        increasing the total number of epochs of training, as the total number
+        of training steps may be reduced.
+
+    parallel
+        Use multiprocessing for data loading: if set to -1 (default), disables
+        multiprocessing data loading.  Set to 0 to enable as many data loading
+        instances as there are processing cores available in the system.  Set
+        to >= 1
+        to enable that many multiprocessing instances for data loading.
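+
+    A minimal instantiation sketch (``my_split.json`` and ``my_loader`` are
+    hypothetical placeholders for a database split and a raw-data loader):
+
+    .. code:: python
+
+       from ptbench.data.dataset import JSONDatabaseSplit
+
+       datamodule = CachingDataModule(
+           database_split=JSONDatabaseSplit("my_split.json"),
+           raw_data_loader=my_loader,
+           model_transforms=[],
+           batch_size=4,
+       )
+       datamodule.setup("fit")
+       train_loader = datamodule.train_dataloader()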
+    """
+
+    def __init__(
+        self,
+        database_split: dict[str, typing.Sequence[typing.Any]],
+        raw_data_loader: typing.Callable[
+            [typing.Any], tuple[torch.Tensor, typing.Mapping]
+        ],
+        cache_samples: bool = False,
+        train_sampler: typing.Optional[torch.utils.data.Sampler] = None,
+        data_augmentations: list[torch.nn.Module] = [],
+        model_transforms: list[torch.nn.Module] = [],
+        batch_size: int = 1,
+        batch_chunk_count: int = 1,
+        drop_incomplete_batch: bool = False,
+        parallel: int = -1,
+    ):
+        # validation
+        if batch_size % batch_chunk_count != 0:
+            raise RuntimeError(
+                f"batch_size ({batch_size}) must be divisible by "
+                f"batch_chunk_size ({batch_chunk_count})."
+            )
+
+        super().__init__()
+
+        self.database_split = database_split
+        self.raw_data_loader = raw_data_loader
+        self.cache_samples = cache_samples
+        self.train_sampler = train_sampler
+        self.data_augmentations = data_augmentations
+        self.model_transforms = model_transforms
+
+        self._batch_size = batch_size
+        self._batch_chunk_count = batch_chunk_count
+        self._chunk_size = self._batch_size // self._batch_chunk_count
+
+        self.drop_incomplete_batch = drop_incomplete_batch
+        self._parallel = parallel  # immutable, otherwise would need to call
+        # the next function again
+        self._dataloader_multiproc = _setup_dataloader_multiproc_parameters(
+            parallel
+        )
+
+        self.pin_memory = (
+            torch.cuda.is_available()
+        )  # should only be true if GPU available and using it
+
+        # datasets that have been setup() for the current stage
+        self._datasets: dict[str, _DelayedLoadingDataset | _CachedDataset] = {}
+
+    def setup(self, stage: str) -> None:
+        """Sets up datasets for different tasks on the pipeline.
+
+        This method should set up (load, pre-process, etc) all subsets required
+        for a particular ``stage`` (fit, validate, test, predict), and keep
+        them ready to be used by one of the ``*_dataloader()`` methods that are
+        pertinent for such stage.
+
+        If you have set ``cache_samples``, samples are loaded at this stage and
+        cached in memory.
+
+
+        Parameters
+        ----------
+
+        stage
+            Name of the stage to which the setup is applicable.  Can be one of
+            ``fit``, ``validate``, ``test`` or ``predict``.  Each stage
+            typically uses the following data loaders:
+
+            * ``fit``: uses both train and validation datasets
+            * ``validate``: uses only the validation dataset
+            * ``test``: uses only the test dataset
+            * ``predict``: uses only the test dataset
+        """
+
+        def _setup(name, transforms):
+            if self.cache_samples:
+                self._datasets[name] = _CachedDataset(
+                    self.database_split[name], self.raw_data_loader, transforms
+                )
+            else:
+                self._datasets[name] = _DelayedLoadingDataset(
+                    self.database_split[name], self.raw_data_loader, transforms
+                )
+
+        if stage == "fit":
+            _setup("train", self.data_augmentations + self.model_transforms)
+            _setup("validation", self.model_transforms)
+            for k in self.database_split:
+                if k.startswith("monitor-"):
+                    _setup(k, self.model_transforms)
+
+        elif stage == "validate":
+            _setup("validation", self.model_transforms)
+            for k in self.database_split:
+                if k.startswith("monitor-"):
+                    _setup(k, self.model_transforms)
+
+        elif stage == "test":
+            _setup("test", self.model_transforms)
+
+        elif stage == "predict":
+            _setup("test", self.model_transforms)
+
+    def train_dataloader(self):
+        """Returns the train data loader."""
+
+        return torch.utils.data.DataLoader(
+            self._datasets["train"],
+            batch_size=self._chunk_size,
+            drop_last=self.drop_incomplete_batch,
+            pin_memory=self.pin_memory,
+            sampler=self.train_sampler,
+            **self._dataloader_multiproc,
+        )
+
+    def val_dataloader(self):
+        """Returns the validation data loader(s)"""
+
+        extra_valid = [
+            k for k in self.database_split.keys() if k.startswith("monitor-")
+        ]
+
+        # TODO: do we really need the train sampler here?
+        validation_loader_opts = {
+            "batch_size": self._chunk_size,
+            "shuffle": False,
+            "drop_last": self.drop_incomplete_batch,
+            "pin_memory": self.pin_memory,
+        }
+        validation_loader_opts.update(self._dataloader_multiproc)
+
+        # TODO: not sure this is the right way to handle multiple validation
+        # loaders, please check and fix
+        if not extra_valid:
+            return torch.utils.data.DataLoader(
+                self._datasets["validation"],
+                **validation_loader_opts,
+            )
+
+        else:
+            return [
+                torch.utils.data.DataLoader(
+                    self._datasets[k],
+                    **validation_loader_opts,
+                )
+                for k in ["validation"] + extra_valid
+            ]
+
+    def test_dataloader(self):
+        """Returns the test data loader(s)"""
+
+        return torch.utils.data.DataLoader(
+            self._datasets["test"],
+            batch_size=self._chunk_size,
+            shuffle=False,
+            drop_last=self.drop_incomplete_batch,
+            pin_memory=self.pin_memory,
+            **self._dataloader_multiproc,
+        )
+
+    def predict_dataloader(self):
+        """Returns the prediction data loader(s)"""
+
+        return self.test_dataloader()
diff --git a/src/ptbench/data/dataset.py b/src/ptbench/data/dataset.py
index 15dc32a9..6b373db4 100644
--- a/src/ptbench/data/dataset.py
+++ b/src/ptbench/data/dataset.py
@@ -2,453 +2,261 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
+import collections.abc
 import csv
+import importlib.abc
 import json
 import logging
-import os
 import pathlib
+import typing
 
 import torch
-
-from tqdm import tqdm
+import torch.utils.data
 
 logger = logging.getLogger(__name__)
 
 
-class JSONProtocol:
-    """Generic multi-protocol/subset filelist dataset that yields samples.
+class JSONDatabaseSplit(collections.abc.Mapping):
+    """Defines a loader that understands a database split (train, test, etc) in
+    JSON format.
 
-    To create a new dataset, you need to provide one or more JSON formatted
-    filelists (one per protocol) with the following contents:
+    To create a new database split, you need to provide a JSON formatted
+    dictionary in a file, with contents similar to the following:
 
     .. code-block:: json
 
        {
            "subset1": [
                [
-                   "value1",
-                   "value2",
-                   "value3"
+                   "sample1-data1",
+                   "sample1-data2",
+                   "sample1-data3",
                ],
                [
-                   "value4",
-                   "value5",
-                   "value6"
+                   "sample2-data1",
+                   "sample2-data2",
+                   "sample2-data3",
                ]
            ],
            "subset2": [
+               [
+                   "sample42-data1",
+                   "sample42-data2",
+                   "sample42-data3",
+               ],
            ]
        }
 
-    Your dataset many contain any number of subsets, but all sample entries
-    must contain the same number of fields.
+    Your database split may contain any number of subsets (dictionary keys).
+    We recommend that all sample entries are formatted similarly, so that
+    raw-data loading remains uniform.  Use the function
+    :py:func:`check_database_split_loading` to test raw data loading and
+    fine-tune the dataset split, or its loading.
+
+    Objects of this class behave like a dictionary in which keys are subset
+    names in the split, and values represent samples data and meta-data.
 
 
     Parameters
     ----------
 
-    protocols : list, dict
-        Paths to one or more JSON formatted files containing the various
-        protocols to be recognized by this dataset, or a dictionary, mapping
-        protocol names to paths (or opened file objects) of CSV files.
-        Internally, we save a dictionary where keys default to the basename of
-        paths (list input).
-
-    fieldnames : list, tuple
-        An iterable over the field names (strings) to assign to each entry in
-        the JSON file.  It should have as many items as fields in each entry of
-        the JSON file.
-
-    loader : object
-        A function that receives as input, a context dictionary (with at least
-        a "protocol" and "subset" keys indicating which protocol and subset are
-        being served), and a dictionary with ``{fieldname: value}`` entries,
-        and returns an object with at least 2 attributes:
-
-        * ``key``: which must be a unique string for every sample across
-          subsets in a protocol, and
-        * ``data``: which contains the data associated witht this sample
+    path
+        Absolute path to a JSON formatted file containing the database split to be
+        recognized by this object.
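+
+    A short usage sketch (assuming a hypothetical ``split.json`` with the
+    layout illustrated above):
+
+    .. code:: python
+
+       split = JSONDatabaseSplit("split.json")
+       print(list(split))  # the subset names, e.g. ["subset1", "subset2"]
+       samples = split["subset1"]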
     """
 
-    def __init__(self, protocols, fieldnames):
-        if isinstance(protocols, dict):
-            self._protocols = protocols
-        else:
-            self._protocols = {
-                os.path.basename(
-                    str(k).replace("".join(pathlib.Path(k).suffixes), "")
-                ): k
-                for k in protocols
-            }
-        self.fieldnames = fieldnames
-
-    def check(self, limit=0):
-        """For each protocol, check if all data can be correctly accessed.
-
-        This function assumes each sample has a ``data`` and a ``key``
-        attribute.  The ``key`` attribute should be a string, or representable
-        as such.
-
-
-        Parameters
-        ----------
-
-        limit : int
-            Maximum number of samples to check (in each protocol/subset
-            combination) in this dataset.  If set to zero, then check
-            everything.
-
+    def __init__(self, path: pathlib.Path | str | importlib.abc.Traversable):
+        if isinstance(path, str):
+            path = pathlib.Path(path)
+        self.path = path
+        self.subsets = self._load_split_from_disk()
 
-        Returns
-        -------
+    def _load_split_from_disk(self) -> dict[str, typing.Sequence[typing.Any]]:
+        """Loads all subsets in a split from its file system representation.
 
-        errors : int
-            Number of errors found
-        """
-        logger.info("Checking dataset...")
-        errors = 0
-        for proto in self._protocols:
-            logger.info(f"Checking protocol '{proto}'...")
-            for name, samples in self.subsets(proto).items():
-                logger.info(f"Checking subset '{name}'...")
-                if limit:
-                    logger.info(f"Checking at most first '{limit}' samples...")
-                    samples = samples[:limit]
-                for pos, sample in enumerate(samples):
-                    try:
-                        sample.data  # may trigger data loading
-                        logger.info(f"{sample.key}: OK")
-                    except Exception as e:
-                        logger.error(
-                            f"Found error loading entry {pos} in subset {name} "
-                            f"of protocol {proto} from file "
-                            f"'{self._protocols[proto]}': {e}"
-                        )
-                        errors += 1
-        return errors
-
-    def subsets(self, protocol):
-        """Returns all subsets in a protocol.
-
-        This method will load JSON information for a given protocol and return
-        all subsets of the given protocol after converting each entry through
-        the loader function.
-
-        Parameters
-        ----------
-
-        protocol : str
-            Name of the protocol data to load
+        This method will load JSON information for the current split and
+        return all subsets of the given split.
 
 
         Returns
         -------
 
         subsets : dict
-            A dictionary mapping subset names to lists of objects (respecting
-            the ``key``, ``data`` interface).
+            A dictionary mapping subset names to lists of JSON objects
         """
-        fileobj = self._protocols[protocol]
-        if isinstance(fileobj, (str, bytes, pathlib.Path)):
-            if str(fileobj).endswith(".bz2"):
-                import bz2
 
-                with bz2.open(self._protocols[protocol]) as f:
-                    data = json.load(f)
-            else:
-                with open(self._protocols[protocol]) as f:
-                    data = json.load(f)
+        if str(self.path).endswith(".bz2"):
+            logger.debug(f"Loading database split from {str(self.path)}...")
+            with __import__("bz2").open(self.path) as f:
+                return json.load(f)
         else:
-            data = json.load(fileobj)
-            fileobj.seek(0)
+            with self.path.open() as f:
+                return json.load(f)
 
-        retval = {}
-        for subset, samples in data.items():
-            logger.info(f"Loading subset {subset} samples.")
+    def __getitem__(self, key: str) -> typing.Sequence[typing.Any]:
+        """Accesses subset ``key`` from this split."""
+        return self.subsets[key]
 
-            retval[subset] = [
-                dict(zip(self.fieldnames, k))
-                for n, k in enumerate(tqdm(samples))
-            ]
+    def __iter__(self):
+        """Iterates over the subsets."""
+        return iter(self.subsets)
+
+    def __len__(self) -> int:
+        """How many subsets we currently have."""
+        return len(self.subsets)
 
-        return retval
 
+class CSVDatabaseSplit(collections.abc.Mapping):
+    """Defines a loader that understands a database split (train, test, etc) in
+    CSV format.
 
-class CSVDataset:
-    """Generic multi-subset filelist dataset that yields samples.
+    To create a new database split, you need to provide one or more CSV
+    formatted files, each representing a subset of this split, containing the
+    sample data (one per row).  Example:
 
-    To create a new dataset, you only need to provide a CSV formatted filelist
-    using any separator (e.g. comma, space, semi-colon) with the following
-    information:
+    Inside the directory ``my-split/``, one can find files ``train.csv``,
+    ``validation.csv``, and ``test.csv``.  Each file has a structure similar to
+    the following:
 
     .. code-block:: text
 
-       value1,value2,value3
-       value4,value5,value6
+       sample1-value1,sample1-value2,sample1-value3
+       sample2-value1,sample2-value2,sample2-value3
        ...
 
-    Notice that all rows must have the same number of entries.
+    Each file in the provided directory defines the name of a subset in the
+    split.
+    So, the file ``train.csv`` will contain the data from the ``train`` subset,
+    and so on.
+
+    Objects of this class behave like a dictionary in which keys are subset
+    names in the split, and values represent samples data and meta-data.
+
 
     Parameters
     ----------
 
-    subsets : list, dict
-        Paths to one or more CSV formatted files containing the various subsets
-        to be recognized by this dataset, or a dictionary, mapping subset names
-        to paths (or opened file objects) of CSV files.  Internally, we save a
-        dictionary where keys default to the basename of paths (list input).
-
-    fieldnames : list, tuple
-        An iterable over the field names (strings) to assign to each column in
-        the CSV file.  It should have as many items as fields in each row of
-        the CSV file(s).
-
-    loader : object
-        A function that receives as input, a context dictionary (with, at
-        least, a "subset" key indicating which subset is being served), and a
-        dictionary with ``{key: path}`` entries, and returns a dictionary with
-        the loaded data.
+    directory
+        Absolute path to a directory containing the database split laid out
+        as a set of CSV files, one per subset.
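+
+    A short usage sketch (assuming the hypothetical ``my-split/`` directory
+    described above):
+
+    .. code:: python
+
+       split = CSVDatabaseSplit("my-split")
+       for subset_name, samples in split.items():
+           print(subset_name, len(samples))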
     """
 
-    def __init__(self, subsets, fieldnames, loader):
-        if isinstance(subsets, dict):
-            self._subsets = subsets
-        else:
-            self._subsets = {
-                os.path.basename(
-                    str(k).replace("".join(pathlib.Path(k).suffixes), "")
-                ): k
-                for k in subsets
-            }
-        self.fieldnames = fieldnames
-        self._loader = loader
-
-    def check(self, limit=0):
-        """For each subset, check if all data can be correctly accessed.
-
-        This function assumes each sample has a ``data`` and a ``key``
-        attribute.  The ``key`` attribute should be a string, or representable
-        as such.
-
-
-        Parameters
-        ----------
-
-        limit : int
-            Maximum number of samples to check (in each protocol/subset
-            combination) in this dataset.  If set to zero, then check
-            everything.
+    def __init__(
+        self, directory: pathlib.Path | str | importlib.abc.Traversable
+    ):
+        if isinstance(directory, str):
+            directory = pathlib.Path(directory)
+        assert (
+            directory.is_dir()
+        ), f"`{str(directory)}` is not a valid directory"
+        self.directory = directory
+        self.subsets = self._load_split_from_disk()
 
+    def _load_split_from_disk(
+        self,
+    ) -> dict[str, list[typing.Any]]:
+        """Loads all subsets in a split from its file system representation.
 
-        Returns
-        -------
+        This method will load CSV information for the current split and
+        return all subsets of the given split.
 
-        errors : int
-            Number of errors found
-        """
-        logger.info("Checking dataset...")
-        errors = 0
-        for name in self._subsets.keys():
-            logger.info(f"Checking subset '{name}'...")
-            samples = self.samples(name)
-            if limit:
-                logger.info(f"Checking at most first '{limit}' samples...")
-                samples = samples[:limit]
-            for pos, sample in enumerate(samples):
-                try:
-                    sample.data  # may trigger data loading
-                    logger.info(f"{sample.key}: OK")
-                except Exception as e:
-                    logger.error(
-                        f"Found error loading entry {pos} in subset {name} "
-                        f"from file '{self._subsets[name]}': {e}"
-                    )
-                    errors += 1
-        return errors
-
-    def subsets(self):
-        """Returns all available subsets at once.
 
         Returns
         -------
 
         subsets : dict
-            A dictionary mapping subset names to lists of objects (respecting
-            the ``key``, ``data`` interface).
-        """
-        return {k: self.samples(k) for k in self._subsets.keys()}
-
-    def samples(self, subset):
-        """Returns all samples in a subset.
-
-        This method will load CSV information for a given subset and return
-        all samples of the given subset after passing each entry through the
-        loading function.
-
-
-        Parameters
-        ----------
-
-        subset : str
-            Name of the subset data to load
-
-
-        Returns
-        -------
-
-        subset : list
-            A lists of objects (respecting the ``key``, ``data`` interface).
+            A dictionary mapping subset names to lists of CSV records
         """
-        fileobj = self._subsets[subset]
-        if isinstance(fileobj, (str, bytes, pathlib.Path)):
-            with open(self._subsets[subset], newline="") as f:
-                cf = csv.reader(f)
-                samples = [k for k in cf]
-        else:
-            cf = csv.reader(fileobj)
-            samples = [k for k in cf]
-            fileobj.seek(0)
 
-        return [
-            self._loader(
-                dict(subset=subset, order=n), dict(zip(self.fieldnames, k))
-            )
-            for n, k in enumerate(samples)
-        ]
-
-
-class CachedDataset(torch.utils.data.Dataset):
-    def __init__(
-        self, json_protocol, protocol, subset, raw_data_loader, transforms
-    ):
-        self.json_protocol = json_protocol
-        self.subset = subset
-        self.raw_data_loader = raw_data_loader
-        self.transforms = transforms
-
-        self._samples = json_protocol.subsets(protocol)[self.subset]
-        # Dict entry with relative path to files, used during prediction
-        for s in self._samples:
-            s["name"] = s["data"]
-
-        logger.info(f"Caching {self.subset} samples")
-        for sample in tqdm(self._samples):
-            sample["data"] = self.raw_data_loader(sample["data"])
-
-    def __getitem__(self, idx):
-        sample = self._samples[idx].copy()
-        sample["data"] = self.transforms(sample["data"])
-        return sample
-
-    def __len__(self):
-        return len(self._samples)
-
-
-class RuntimeDataset(torch.utils.data.Dataset):
-    def __init__(
-        self, json_protocol, protocol, subset, raw_data_loader, transforms
-    ):
-        self.json_protocol = json_protocol
-        self.subset = subset
-        self.raw_data_loader = raw_data_loader
-        self.transforms = transforms
+        retval = {}
+        for subset in self.directory.iterdir():
+            if str(subset).endswith(".csv.bz2"):
+                logger.debug(f"Loading database split from {subset}...")
+                with __import__("bz2").open(subset) as f:
+                    reader = csv.reader(f)
+                    retval[subset.name[: -len(".csv.bz2")]] = [
+                        k for k in reader
+                    ]
+            elif str(subset).endswith(".csv"):
+                with subset.open() as f:
+                    reader = csv.reader(f)
+                    retval[subset.name[: -len(".csv")]] = [k for k in reader]
+            else:
+                logger.debug(
+                    f"Ignoring file {subset} in CSVDatabaseSplit readout"
+                )
+        return retval
 
-        self._samples = json_protocol.subsets(protocol)[self.subset]
-        # Dict entry with relative path to files
-        for s in self._samples:
-            s["name"] = s["data"]
+    def __getitem__(self, key: str) -> typing.Sequence[typing.Any]:
+        """Accesses subset ``key`` from this split."""
+        return self.subsets[key]
 
-    def __getitem__(self, idx):
-        sample = self._samples[idx].copy()
-        sample["data"] = self.transforms(self.raw_data_loader(sample["data"]))
-        return sample
+    def __iter__(self):
+        """Iterates over the subsets."""
+        return iter(self.subsets)
 
-    def __len__(self):
-        return len(self._samples)
+    def __len__(self) -> int:
+        """How many subsets we currently have."""
+        return len(self.subsets)
 
 
-def get_samples_weights(dataset):
-    """Compute the weights of all the samples of the dataset to balance it
-    using the sampler of the dataloader.
+def check_database_split_loading(
+    database_split: typing.Mapping[str, typing.Sequence[typing.Any]],
+    loader: typing.Callable[[typing.Any], torch.Tensor],
+    limit: int = 0,
+) -> int:
+    """For each subset in the split, check if all data can be correctly loaded
+    using the provided loader function.
 
-    This function takes as input a :py:class:`torch.utils.data.dataset.Dataset`
-    and computes the weights to balance each class in the dataset and the
-    datasets themselves if we have a ConcatDataset.
+    This function will return the number of errors loading samples, and will
+    log more detailed information to the logging stream.
 
 
     Parameters
     ----------
 
-    dataset : torch.utils.data.dataset.Dataset
-        An instance of torch.utils.data.dataset.Dataset
-        ConcatDataset are supported
+    database_split
+        A mapping that contains the database split.  Each key represents the
+        name of a subset in the split.  Each value is a sequence of
+        (potentially complex) objects, each representing a single sample.
+
+    loader
+        A callable that transforms sample entries in the database split
+        into :py:class:`torch.Tensor` objects that can be used for training
+        or inference.
+
+    limit
+        Maximum number of samples to check (in each split/subset
+        combination) in this dataset.  If set to zero, then check
+        everything.
 
 
     Returns
     -------
 
-    samples_weights : :py:class:`torch.Tensor`
-        the weights for all the samples in the dataset given as input
+    errors
+        Number of errors found
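+
+    A usage sketch (``my_split`` and ``my_loader`` stand for your own split
+    and loader):
+
+    .. code:: python
+
+       errors = check_database_split_loading(my_split, my_loader, limit=5)
+       assert errors == 0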
     """
-    samples_weights = []
-
-    if isinstance(dataset, torch.utils.data.ConcatDataset):
-        for ds in dataset.datasets:
-            # Weighting only for binary labels
-            if isinstance(ds._samples[0]["label"], int):
-                # Groundtruth
-                targets = []
-                for s in ds._samples:
-                    targets.append(s["label"])
-                targets = torch.tensor(targets)
-
-                # Count number of samples per class
-                class_sample_count = torch.tensor(
-                    [
-                        (targets == t).sum()
-                        for t in torch.unique(targets, sorted=True)
-                    ]
+    logger.info(
+        "Checking if all samples in all subsets of this split can be loaded..."
+    )
+    errors = 0
+    for subset, subset_samples in database_split.items():
+        samples = subset_samples if not limit else subset_samples[:limit]
+        for pos, sample in enumerate(samples):
+            try:
+                data = loader(sample)
+                assert isinstance(data, torch.Tensor)
+            except Exception as e:
+                logger.error(
+                    f"Found error loading entry {pos} in subset `{subset}`: {e}"
                 )
-
-                weight = 1.0 / class_sample_count.float()
-
-                samples_weights.append(
-                    torch.tensor([weight[t] for t in targets])
-                )
-
-            else:
-                # We only weight to sample equally from each dataset
-                samples_weights.append(torch.full((len(ds),), 1.0 / len(ds)))
-
-        # Concatenate sample weights from all the datasets
-        samples_weights = torch.cat(samples_weights)
-
-    else:
-        # Weighting only for binary labels
-        if isinstance(dataset._samples[0]["label"], int):
-            # Groundtruth
-            targets = []
-            for s in dataset._samples:
-                targets.append(s["label"])
-            targets = torch.tensor(targets)
-
-            # Count number of samples per class
-            class_sample_count = torch.tensor(
-                [
-                    (targets == t).sum()
-                    for t in torch.unique(targets, sorted=True)
-                ]
-            )
-
-            weight = 1.0 / class_sample_count.float()
-
-            samples_weights = torch.tensor([weight[t] for t in targets])
-
-        else:
-            # Equal weights for non-binary labels
-            samples_weights = torch.ones(len(dataset._samples))
-
-    return samples_weights
+                errors += 1
+    return errors
 
 
 def get_positive_weights(dataset):
diff --git a/src/ptbench/data/shenzhen/__init__.py b/src/ptbench/data/shenzhen/__init__.py
index 54eb6326..1645962e 100644
--- a/src/ptbench/data/shenzhen/__init__.py
+++ b/src/ptbench/data/shenzhen/__init__.py
@@ -13,22 +13,11 @@ the daily routine using Philips DR Digital Diagnose systems.
 * Reference: [MONTGOMERY-SHENZHEN-2014]_
 * Original resolution (height x width or width x height): 3000 x 3000 or less
 * Split reference: none
-* Protocol ``default``:
-
   * Training samples: 64% of TB and healthy CXR (including labels)
   * Validation samples: 16% of TB and healthy CXR (including labels)
   * Test samples: 20% of TB and healthy CXR (including labels)
 """
 import importlib.resources
-import os
-
-from clapper.logging import setup
-
-from ...utils.rc import load_rc
-from ..loader import load_pil_baw
-
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
 
 _protocols = [
     importlib.resources.files(__name__).joinpath("default.json.bz2"),
@@ -43,11 +32,3 @@ _protocols = [
     importlib.resources.files(__name__).joinpath("fold_8.json.bz2"),
     importlib.resources.files(__name__).joinpath("fold_9.json.bz2"),
 ]
-
-_datadir = load_rc().get("datadir.shenzhen", os.path.realpath(os.curdir))
-
-
-def _raw_data_loader(img_path):
-    raw_data = load_pil_baw(os.path.join(_datadir, img_path))
-
-    return raw_data
diff --git a/src/ptbench/data/shenzhen/default.py b/src/ptbench/data/shenzhen/default.py
index b5fb23f5..c9068112 100644
--- a/src/ptbench/data/shenzhen/default.py
+++ b/src/ptbench/data/shenzhen/default.py
@@ -2,27 +2,34 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-"""Shenzhen dataset for TB detection (default protocol)
+"""Shenzhen datamodule for computer-aided diagnosis (default protocol)
 
-* Split reference: first 64% of TB and healthy CXR for "train" 16% for
-* "validation", 20% for "test"
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.shenzhen` for dataset details
+See :py:mod:`ptbench.data.shenzhen` for dataset details.
+
+This configuration:
+
+* raw data (default): :py:obj:`ptbench.data.shenzhen.loader._transform`
+* augmentations: elastic deformation (probability = 80%)
+* output image resolution: 512x512 pixels
 """
 
-from clapper.logging import setup
+import importlib.resources
 
+from ..datamodule import CachingDataModule
+from ..dataset import JSONDatabaseSplit
 from ..transforms import ElasticDeformation
-from .utils import ShenzhenDataModule
-
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-protocol_name = "default"
-
-augmentation_transforms = [ElasticDeformation(p=0.8)]
-
-datamodule = ShenzhenDataModule(
-    protocol="default",
-    model_transforms=[],
-    augmentation_transforms=augmentation_transforms,
+from .loader import raw_data_loader
+
+datamodule = CachingDataModule(
+    database_split=JSONDatabaseSplit(
+        importlib.resources.files(__name__).joinpath("default.json.bz2")
+    ),
+    raw_data_loader=raw_data_loader,
+    cache_samples=False,
+    # train_sampler: typing.Optional[torch.utils.data.Sampler] = None,
+    data_augmentations=[ElasticDeformation(p=0.8)],
+    # model_transforms = [],
+    # batch_size = 1,
+    # batch_chunk_count = 1,
+    # drop_incomplete_batch = False,
+    # parallel = -1,
 )
diff --git a/src/ptbench/data/shenzhen/loader.py b/src/ptbench/data/shenzhen/loader.py
new file mode 100644
index 00000000..6578fc8f
--- /dev/null
+++ b/src/ptbench/data/shenzhen/loader.py
@@ -0,0 +1,68 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+"""Shenzhen dataset for computer-aided diagnosis.
+
+The standard digital image database for Tuberculosis is created by the National
+Library of Medicine, Maryland, USA in collaboration with Shenzhen No.3 People’s
+Hospital, Guangdong Medical College, Shenzhen, China. The Chest X-rays are from
+out-patient clinics, and were captured as part of the daily routine using
+Philips DR Digital Diagnose systems.
+
+* Reference: [MONTGOMERY-SHENZHEN-2014]_
+* Original resolution (height x width or width x height): 3000 x 3000 or less
+* Split reference: none
+* Protocol ``default``:
+
+  * Training samples: 64% of TB and healthy CXR (including labels)
+  * Validation samples: 16% of TB and healthy CXR (including labels)
+  * Test samples: 20% of TB and healthy CXR (including labels)
+"""
+
+import os
+import typing
+
+import torch.nn
+import torchvision.transforms
+
+from ...utils.rc import load_rc
+from ..loader import load_pil_baw
+from ..transforms import RemoveBlackBorders
+
+_datadir = load_rc().get("datadir.shenzhen", os.path.realpath(os.curdir))
+"""This variable contains the base directory where the database raw data is
+stored."""
+
+_transform = torchvision.transforms.Compose(
+    [
+        RemoveBlackBorders(),
+        torchvision.transforms.Resize(512),
+        torchvision.transforms.CenterCrop(512),
+        torchvision.transforms.ToTensor(),
+    ]
+)
+"""Transforms that are always applied to the loaded raw images."""
+
+
+def raw_data_loader(
+    sample: tuple[str, int]
+) -> tuple[torch.Tensor, typing.Mapping]:
+    """Loads a single image sample from the disk.
+
+    Parameters
+    ----------
+
+    sample
+        A tuple containing the path suffix, within the dataset root folder,
+        where to find the image to be loaded, and an integer representing the
+        sample label.
+
+
+    Returns
+    -------
+
+    sample
+        The preprocessed image as a :py:class:`torch.Tensor`, and a dictionary
+        with the sample metadata (currently, only the integer ``label``).
+    """
+    tensor = _transform(load_pil_baw(os.path.join(_datadir, sample[0])))
+    return tensor, dict(label=sample[1])  # type: ignore[arg-type]
diff --git a/src/ptbench/data/shenzhen/utils.py b/src/ptbench/data/shenzhen/utils.py
deleted file mode 100644
index 1521b674..00000000
--- a/src/ptbench/data/shenzhen/utils.py
+++ /dev/null
@@ -1,114 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-"""Shenzhen dataset for computer-aided diagnosis.
-
-The standard digital image database for Tuberculosis is created by the
-National Library of Medicine, Maryland, USA in collaboration with Shenzhen
-No.3 People’s Hospital, Guangdong Medical College, Shenzhen, China.
-The Chest X-rays are from out-patient clinics, and were captured as part of
-the daily routine using Philips DR Digital Diagnose systems.
-
-* Reference: [MONTGOMERY-SHENZHEN-2014]_
-* Original resolution (height x width or width x height): 3000 x 3000 or less
-* Split reference: none
-* Protocol ``default``:
-
-  * Training samples: 64% of TB and healthy CXR (including labels)
-  * Validation samples: 16% of TB and healthy CXR (including labels)
-  * Test samples: 20% of TB and healthy CXR (including labels)
-"""
-from clapper.logging import setup
-from torchvision import transforms
-
-from ..base_datamodule import BaseDataModule
-from ..dataset import CachedDataset, JSONProtocol, RuntimeDataset
-from ..shenzhen import _protocols, _raw_data_loader
-from ..transforms import RemoveBlackBorders
-
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class ShenzhenDataModule(BaseDataModule):
-    def __init__(
-        self,
-        protocol="default",
-        model_transforms=[],
-        augmentation_transforms=[],
-        batch_size=1,
-        batch_chunk_count=1,
-        drop_incomplete_batch=False,
-        cache_samples=False,
-        parallel=-1,
-    ):
-        super().__init__(
-            batch_size=batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            batch_chunk_count=batch_chunk_count,
-            parallel=parallel,
-        )
-
-        self._cache_samples = cache_samples
-        self._has_setup_fit = False
-        self._has_setup_predict = False
-        self._protocol = protocol
-
-        self.raw_data_transforms = [
-            RemoveBlackBorders(),
-            transforms.Resize(512),
-            transforms.CenterCrop(512),
-            transforms.ToTensor(),
-        ]
-
-        self.model_transforms = model_transforms
-
-        self.augmentation_transforms = augmentation_transforms
-
-    def setup(self, stage: str):
-        json_protocol = JSONProtocol(
-            protocols=_protocols,
-            fieldnames=("data", "label"),
-        )
-
-        if self._cache_samples:
-            dataset = CachedDataset
-        else:
-            dataset = RuntimeDataset
-
-        if not self._has_setup_fit and stage == "fit":
-            self.train_dataset = dataset(
-                json_protocol,
-                self._protocol,
-                "train",
-                _raw_data_loader,
-                self._build_transforms(is_train=True),
-            )
-
-            self.validation_dataset = dataset(
-                json_protocol,
-                self._protocol,
-                "validation",
-                _raw_data_loader,
-                self._build_transforms(is_train=False),
-            )
-
-            self._has_setup_fit = True
-
-        if not self._has_setup_predict and stage == "predict":
-            self.train_dataset = dataset(
-                json_protocol,
-                self._protocol,
-                "train",
-                _raw_data_loader,
-                self._build_transforms(is_train=False),
-            )
-            self.validation_dataset = dataset(
-                json_protocol,
-                self._protocol,
-                "validation",
-                _raw_data_loader,
-                self._build_transforms(is_train=False),
-            )
-
-            self._has_setup_predict = True
diff --git a/src/ptbench/data/transforms.py b/src/ptbench/data/transforms.py
index 6d3c17d2..c85c3e9d 100644
--- a/src/ptbench/data/transforms.py
+++ b/src/ptbench/data/transforms.py
@@ -58,6 +58,8 @@ class RemoveBlackBorders:
 class ElasticDeformation:
     """Elastic deformation of 2D image slightly adapted from [SIMARD-2003]_.
 
+    TODO: needs to be converted into a torch.nn.Module to become scriptable!
+
     Source: https://gist.github.com/oeway/2e3b989e0343f0884388ed7ed82eb3b0
     """
 
-- 
GitLab