diff --git a/doc/references.rst b/doc/references.rst
index 5a349fa9dd960aa393f8f9d5dfd6f696676317e6..056707406d9e3dbf5d73fe5bc03cdd7f49701cbb 100644
--- a/doc/references.rst
+++ b/doc/references.rst
@@ -1,5 +1,6 @@
-
-.. coding=utf-8
+.. SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+..
+.. SPDX-License-Identifier: GPL-3.0-or-later
 
 ============
  References
diff --git a/pyproject.toml b/pyproject.toml
index bb93be327219f1e58eadc46317074d91eb3235e8..7ce5443548baaf829fa3c877ef8a35f805554ed4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -15,33 +15,33 @@ dynamic = ["readme"]
 license = { text = "GNU General Public License v3 (GPLv3)" }
 authors = [{ name = "Geoffrey Raposo", email = "geoffrey@raposo.ch" }]
 maintainers = [
-    { name = "Andre Anjos", email = "andre.anjos@idiap.ch" },
-    { name = "Daniel Carron", email = "daniel.carron@idiap.ch" },
+  { name = "Andre Anjos", email = "andre.anjos@idiap.ch" },
+  { name = "Daniel Carron", email = "daniel.carron@idiap.ch" },
 ]
 classifiers = [
-    "Development Status :: 4 - Beta",
-    "Intended Audience :: Developers",
-    "License :: OSI Approved :: GNU General Public License v3 (GPLv3)",
-    "Natural Language :: English",
-    "Programming Language :: Python :: 3",
-    "Topic :: Software Development :: Libraries :: Python Modules",
+  "Development Status :: 4 - Beta",
+  "Intended Audience :: Developers",
+  "License :: OSI Approved :: GNU General Public License v3 (GPLv3)",
+  "Natural Language :: English",
+  "Programming Language :: Python :: 3",
+  "Topic :: Software Development :: Libraries :: Python Modules",
 ]
 dependencies = [
-    "clapper",
-    "click",
-    "numpy",
-    "pandas",
-    "scipy",
-    "scikit-learn",
-    "tqdm",
-    "psutil",
-    "tabulate",
-    "matplotlib",
-    "pillow",
-    "torch>=1.8",
-    "torchvision>=0.10",
-    "lightning>=2.0.3",
-    "tensorboard",
+  "clapper",
+  "click",
+  "numpy",
+  "pandas",
+  "scipy",
+  "scikit-learn",
+  "tqdm",
+  "psutil",
+  "tabulate",
+  "matplotlib",
+  "pillow",
+  "torch>=1.8",
+  "torchvision>=0.10",
+  "lightning>=2.0.3",
+  "tensorboard",
 ]
 
 [project.urls]
@@ -53,13 +53,13 @@ changelog = "https://gitlab.idiap.ch/biosignal/software/ptbench/-/releases"
 [project.optional-dependencies]
 qa = ["pre-commit"]
 doc = [
-    "sphinx",
-    "furo",
-    "sphinx-autodoc-typehints",
-    "auto-intersphinx",
-    "sphinx-copybutton",
-    "sphinx-inline-tabs",
-    "sphinx-click",
+  "sphinx",
+  "furo",
+  "sphinx-autodoc-typehints",
+  "auto-intersphinx",
+  "sphinx-copybutton",
+  "sphinx-inline-tabs",
+  "sphinx-click",
 ]
 test = ["pytest", "pytest-cov", "coverage"]
 
diff --git a/src/ptbench/configs/models/alexnet.py b/src/ptbench/configs/models/alexnet.py
index cf8bfd35aa10ad3493be9483ba0347fc8ebb7da1..2361b886d500fee740e456f2505a36da4fdaf4e3 100644
--- a/src/ptbench/configs/models/alexnet.py
+++ b/src/ptbench/configs/models/alexnet.py
@@ -6,19 +6,30 @@
 
 from torch import empty
 from torch.nn import BCEWithLogitsLoss
+from torch.optim import SGD
 
+from ...data.transforms import ElasticDeformation
 from ...models.alexnet import Alexnet
 
-# config
+# optimizer
+optimizer = SGD
 optimizer_configs = {"lr": 0.01, "momentum": 0.1}
 
-# optimizer
-optimizer = "SGD"
 # criterion
 criterion = BCEWithLogitsLoss(pos_weight=empty(1))
 criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
 
+# augmentations
+augmentation_transforms = [
+    ElasticDeformation(p=0.8),
+]
+
 # model
 model = Alexnet(
-    criterion, criterion_valid, optimizer, optimizer_configs, pretrained=False
+    criterion,
+    criterion_valid,
+    optimizer,
+    optimizer_configs,
+    pretrained=False,
+    augmentation_transforms=augmentation_transforms,
 )
diff --git a/src/ptbench/configs/models/alexnet_pretrained.py b/src/ptbench/configs/models/alexnet_pretrained.py
index 1d196be6f79ea5c70987c1d1a66eaf32e8e7ca4c..0dc7e5d67d007cf5e7e358e7fa75243a47047c4b 100644
--- a/src/ptbench/configs/models/alexnet_pretrained.py
+++ b/src/ptbench/configs/models/alexnet_pretrained.py
@@ -6,19 +6,30 @@
 
 from torch import empty
 from torch.nn import BCEWithLogitsLoss
+from torch.optim import SGD
 
+from ...data.transforms import ElasticDeformation
 from ...models.alexnet import Alexnet
 
-# config
-optimizer_configs = {"lr": 0.001, "momentum": 0.1}
-
 # optimizer
-optimizer = "SGD"
+optimizer = SGD
+optimizer_configs = {"lr": 0.01, "momentum": 0.1}
+
 # criterion
 criterion = BCEWithLogitsLoss(pos_weight=empty(1))
 criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
 
+# augmentations
+augmentation_transforms = [
+    ElasticDeformation(p=0.8),
+]
+
 # model
 model = Alexnet(
-    criterion, criterion_valid, optimizer, optimizer_configs, pretrained=True
+    criterion,
+    criterion_valid,
+    optimizer,
+    optimizer_configs,
+    pretrained=True,
+    augmentation_transforms=augmentation_transforms,
 )
diff --git a/src/ptbench/configs/models/densenet.py b/src/ptbench/configs/models/densenet.py
index 6759490854fd05ac8d8d3a9eec5b1494e0cfb0f2..5d612b2a18146ba306b419a808936e9c3c7042f7 100644
--- a/src/ptbench/configs/models/densenet.py
+++ b/src/ptbench/configs/models/densenet.py
@@ -6,20 +6,30 @@
 
 from torch import empty
 from torch.nn import BCEWithLogitsLoss
+from torch.optim import Adam
 
+from ...data.transforms import ElasticDeformation
 from ...models.densenet import Densenet
 
-# config
-optimizer_configs = {"lr": 0.0001}
-
 # optimizer
-optimizer = "Adam"
+optimizer = Adam
+optimizer_configs = {"lr": 0.0001}
 
 # criterion
 criterion = BCEWithLogitsLoss(pos_weight=empty(1))
 criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
 
+# augmentations
+augmentation_transforms = [
+    ElasticDeformation(p=0.8),
+]
+
 # model
 model = Densenet(
-    criterion, criterion_valid, optimizer, optimizer_configs, pretrained=False
+    criterion,
+    criterion_valid,
+    optimizer,
+    optimizer_configs,
+    pretrained=False,
+    augmentation_transforms=augmentation_transforms,
 )
diff --git a/src/ptbench/configs/models/densenet_pretrained.py b/src/ptbench/configs/models/densenet_pretrained.py
index b018a52203061b847cdae9f09b5edfa713930302..f8908fdb1e87a62df41dca0ecb75ff1fc79b1012 100644
--- a/src/ptbench/configs/models/densenet_pretrained.py
+++ b/src/ptbench/configs/models/densenet_pretrained.py
@@ -6,20 +6,30 @@
 
 from torch import empty
 from torch.nn import BCEWithLogitsLoss
+from torch.optim import Adam
 
+from ...data.transforms import ElasticDeformation
 from ...models.densenet import Densenet
 
-# config
-optimizer_configs = {"lr": 0.01}
-
 # optimizer
-optimizer = "Adam"
+optimizer = Adam
+optimizer_configs = {"lr": 0.0001}
 
 # criterion
 criterion = BCEWithLogitsLoss(pos_weight=empty(1))
 criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
 
+# augmentations
+augmentation_transforms = [
+    ElasticDeformation(p=0.8),
+]
+
 # model
 model = Densenet(
-    criterion, criterion_valid, optimizer, optimizer_configs, pretrained=True
+    criterion,
+    criterion_valid,
+    optimizer,
+    optimizer_configs,
+    pretrained=True,
+    augmentation_transforms=augmentation_transforms,
 )
diff --git a/src/ptbench/configs/models/pasa.py b/src/ptbench/configs/models/pasa.py
index 3ee0b92164b5531b65049b94e71b01b07e2ad27e..d1e1b0a3ae8d9e3e32a7ec19a49e21f01bb694d9 100644
--- a/src/ptbench/configs/models/pasa.py
+++ b/src/ptbench/configs/models/pasa.py
@@ -11,20 +11,16 @@ Screening and Visualization".
 Reference: [PASA-2019]_
 """
 
-from torch import empty
 from torch.nn import BCEWithLogitsLoss
-
-from ...models.pasa import PASA
-
-# config
-optimizer_configs = {"lr": 8e-5}
-
-# optimizer
-optimizer = "Adam"
-
-# criterion
-criterion = BCEWithLogitsLoss(pos_weight=empty(1))
-criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
-
-# model
-model = PASA(criterion, criterion_valid, optimizer, optimizer_configs)
+from torch.optim import Adam
+
+from ...data.transforms import ElasticDeformation
+from ...models.pasa import Pasa
+
+model = Pasa(
+    train_loss=BCEWithLogitsLoss(),
+    validation_loss=BCEWithLogitsLoss(),
+    optimizer_type=Adam,
+    optimizer_arguments=dict(lr=8e-5),
+    augmentation_transforms=[ElasticDeformation(p=0.8)],
+)
diff --git a/src/ptbench/data/base_datamodule.py b/src/ptbench/data/base_datamodule.py
deleted file mode 100644
index 8377c66328241f549dbb2f10946f9cc973ef7a6f..0000000000000000000000000000000000000000
--- a/src/ptbench/data/base_datamodule.py
+++ /dev/null
@@ -1,196 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-import multiprocessing
-import sys
-
-import lightning as pl
-import torch
-
-from clapper.logging import setup
-from torch.utils.data import DataLoader, WeightedRandomSampler
-from torchvision import transforms
-
-from .dataset import get_samples_weights
-
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class BaseDataModule(pl.LightningDataModule):
-    def __init__(
-        self,
-        batch_size=1,
-        batch_chunk_count=1,
-        drop_incomplete_batch=False,
-        parallel=-1,
-    ):
-        super().__init__()
-
-        self.batch_size = batch_size
-        self.batch_chunk_count = batch_chunk_count
-
-        self.train_dataset = None
-        self.validation_dataset = None
-        self.extra_validation_datasets = None
-
-        self._raw_data_transforms = []
-        self._model_transforms = []
-        self._augmentation_transforms = []
-
-        self.drop_incomplete_batch = drop_incomplete_batch
-        self.pin_memory = (
-            torch.cuda.is_available()
-        )  # should only be true if GPU available and using it
-
-        self.parallel = parallel
-
-    def setup(self, stage: str):
-        # Implemented by user
-        # Must define self.train_dataset, self.validation_dataset, self.extra_validation_datasets and self.predict_dataset
-        raise NotImplementedError
-
-    def train_dataloader(self):
-        train_samples_weights = get_samples_weights(self.train_dataset)
-
-        multiproc_kwargs = self._setup_multiproc(self.parallel)
-        train_sampler = WeightedRandomSampler(
-            train_samples_weights, len(train_samples_weights), replacement=True
-        )
-
-        return DataLoader(
-            self.train_dataset,
-            batch_size=self._compute_chunk_size(
-                self.batch_size, self.batch_chunk_count
-            ),
-            drop_last=self.drop_incomplete_batch,
-            pin_memory=self.pin_memory,
-            sampler=train_sampler,
-            **multiproc_kwargs,
-        )
-
-    def val_dataloader(self):
-        loaders_dict = {}
-
-        multiproc_kwargs = self._setup_multiproc(self.parallel)
-
-        val_loader = DataLoader(
-            dataset=self.validation_dataset,
-            batch_size=self._compute_chunk_size(
-                self.batch_size, self.batch_chunk_count
-            ),
-            shuffle=False,
-            drop_last=False,
-            pin_memory=self.pin_memory,
-            **multiproc_kwargs,
-        )
-
-        loaders_dict["validation_loader"] = val_loader
-
-        if self.extra_validation_datasets is not None:
-            for set_idx, extra_set in enumerate(self.extra_validation_datasets):
-                extra_val_loader = DataLoader(
-                    dataset=extra_set,
-                    batch_size=self._compute_chunk_size(
-                        self.batch_size, self.batch_chunk_count
-                    ),
-                    shuffle=False,
-                    drop_last=False,
-                    pin_memory=self.pin_memory,
-                    **multiproc_kwargs,
-                )
-
-                loaders_dict[
-                    f"extra_validation_loader{set_idx}"
-                ] = extra_val_loader
-
-        return loaders_dict
-
-    def predict_dataloader(self):
-        loaders_dict = {}
-
-        loaders_dict["train_loader"] = self.train_dataloader()
-        for k, v in self.val_dataloader().items():
-            loaders_dict[k] = v
-
-        return loaders_dict
-
-    def update_module_properties(self, **kwargs):
-        for k, v in kwargs.items():
-            setattr(self, k, v)
-
-    def _compute_chunk_size(self, batch_size, chunk_count):
-        batch_chunk_size = batch_size
-        if batch_size % chunk_count != 0:
-            # batch_size must be divisible by batch_chunk_count.
-            raise RuntimeError(
-                f"--batch-size ({batch_size}) must be divisible by "
-                f"--batch-chunk-size ({chunk_count})."
-            )
-        else:
-            batch_chunk_size = batch_size // chunk_count
-
-        return batch_chunk_size
-
-    def _setup_multiproc(self, parallel):
-        multiproc_kwargs = dict()
-        if parallel < 0:
-            multiproc_kwargs["num_workers"] = 0
-        else:
-            multiproc_kwargs["num_workers"] = (
-                parallel or multiprocessing.cpu_count()
-            )
-
-        if multiproc_kwargs["num_workers"] > 0 and sys.platform == "darwin":
-            multiproc_kwargs[
-                "multiprocessing_context"
-            ] = multiprocessing.get_context("spawn")
-
-        return multiproc_kwargs
-
-    def _build_transforms(self, is_train):
-        all_transforms = self.raw_data_transforms + self.model_transforms
-
-        if is_train:
-            all_transforms = all_transforms + self.augmentation_transforms
-
-        # all_transforms.append(transforms.ToTensor())
-
-        return transforms.Compose(all_transforms)
-
-    @property
-    def raw_data_transforms(self):
-        return self._raw_data_transforms
-
-    @raw_data_transforms.setter
-    def raw_data_transforms(self, transforms):
-        self._raw_data_transforms = transforms
-
-    @property
-    def model_transforms(self):
-        return self._model_transforms
-
-    @model_transforms.setter
-    def model_transforms(self, transforms):
-        self._model_transforms = transforms
-
-    @property
-    def augmentation_transforms(self):
-        return self._augmentation_transforms
-
-    @augmentation_transforms.setter
-    def augmentation_transforms(self, transforms):
-        self._augmentation_transforms = transforms
-
-
-def get_dataset_from_module(module, stage, **module_args):
-    """Instantiates a DataModule and retrieves the corresponding dataset.
-
-    Useful when combining multiple datasets.
-    """
-    module_instance = module(**module_args)
-    module_instance.prepare_data()
-    module_instance.setup(stage=stage)
-    dataset = module_instance.dataset
-
-    return dataset
diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py
new file mode 100644
index 0000000000000000000000000000000000000000..abcf11d79286ecdb16cc374fb920f68418ec603a
--- /dev/null
+++ b/src/ptbench/data/datamodule.py
@@ -0,0 +1,722 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import collections
+import logging
+import multiprocessing
+import sys
+import typing
+
+import lightning
+import torch
+import torch.backends
+import torch.utils.data
+import torchvision.transforms
+import tqdm
+
+from .typing import (
+    DatabaseSplit,
+    DataLoader,
+    Dataset,
+    RawDataLoader,
+    Sample,
+    Transform,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def _setup_dataloader_multiproc_parameters(
+    parallel: int,
+) -> dict[str, typing.Any]:
+    """Returns a dictionary containing pytorch arguments to be used in data
+    loaders.
+
+    It sets the parameter ``num_workers`` to match the expected pytorch
+    representation.  For macOS machines, it also sets the
+    ``multiprocessing_context`` to use ``spawn`` instead of the default.
+
+    The mapping from the command-line interface ``parallel`` setting to the
+    data-loader parameterisation works like this:
+
+    .. list-table:: Relationship between ``parallel`` and DataLoader parameterisation
+       :widths: 15 15 70
+       :header-rows: 1
+
+       * - CLI ``parallel``
+         - :py:class:`torch.utils.data.DataLoader` ``kwargs``
+         - Comments
+       * - ``<0``
+         - 0
+         - Disables multiprocessing entirely, executes everything within the
+           same processing context
+       * - ``0``
+         - :py:func:`multiprocessing.cpu_count`
+         - Runs mini-batch data loading on as many external processes as CPUs
+           available in the current machine
+       * - ``>=1``
+         - ``parallel``
+         - Runs mini-batch data loading on as many external processes as set on
+           ``parallel``
+    """
+
+    retval: dict[str, typing.Any] = dict()
+    if parallel < 0:
+        retval["num_workers"] = 0
+    else:
+        retval["num_workers"] = parallel or multiprocessing.cpu_count()
+
+    if retval["num_workers"] > 0 and sys.platform == "darwin":
+        retval["multiprocessing_context"] = multiprocessing.get_context("spawn")
+
+    return retval
+
+
+class _DelayedLoadingDataset(Dataset):
+    """A list that loads its samples on demand.
+
+    This list mimics a pytorch Dataset, except raw data loading is done
+    on-the-fly, as the samples are requested through the bracket operator.
+
+
+    Parameters
+    ----------
+
+    split
+        An iterable containing the raw dataset samples loaded from the database
+        splits.
+
+    loader
+        An object instance that can load samples and labels from storage.
+
+    transforms
+        A set of transforms that should be applied on-the-fly for this dataset,
+        to fit the output of the raw-data-loader to the model of interest.
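+
+    Example (a sketch; ``split`` and ``loader`` stand for concrete instances
+    of the types above):
+
+    .. code-block:: python
+
+       dataset = _DelayedLoadingDataset(split, loader)
+       tensor, metadata = dataset[0]  # raw data is loaded here, on access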
+    """
+
+    def __init__(
+        self,
+        split: typing.Sequence[typing.Any],
+        loader: RawDataLoader,
+        transforms: typing.Sequence[Transform] = [],
+    ):
+        self.split = split
+        self.loader = loader
+        self.transform = torchvision.transforms.Compose(transforms)
+
+    def labels(self) -> list[int]:
+        """Returns the integer labels for all samples in the dataset."""
+        return [self.loader.label(k) for k in self.split]
+
+    def __getitem__(self, key: int) -> Sample:
+        tensor, metadata = self.loader.sample(self.split[key])
+        return self.transform(tensor), metadata
+
+    def __len__(self):
+        return len(self.split)
+
+    def __iter__(self):
+        for x in range(len(self)):
+            yield self[x]
+
+
+class _CachedDataset(Dataset):
+    """Basically, a list of preloaded samples.
+
+    This dataset will load all samples from the split during construction
+    instead of delaying that to the indexing.  Beyond raw-data loading, the
+    ``transforms`` given upon construction are applied as each sample is
+    retrieved.
+
+
+    Parameters
+    ----------
+
+    split
+        An iterable containing the raw dataset samples loaded from the database
+        splits.
+
+    loader
+        An object instance that can load samples and labels from storage.
+
+    parallel
+        Use multiprocessing for data loading: if set to -1 (default), disables
+        multiprocessing data loading.  Set to 0 to enable as many data loading
+        instances as processing cores as available in the system.  Set to >= 1
+        to enable that many multiprocessing instances for data loading.
+
+    transforms
+        A set of transforms that should be applied to the cached samples for
+        this dataset, to fit the output of the raw-data-loader to the model of
+        interest.
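+
+    Example (a sketch; ``split`` and ``loader`` stand for concrete instances
+    of the types above):
+
+    .. code-block:: python
+
+       dataset = _CachedDataset(split, loader, parallel=0)  # loads all now
+       tensor, metadata = dataset[0]  # served from memory, then transformed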
+    """
+
+    def __init__(
+        self,
+        split: typing.Sequence[typing.Any],
+        loader: RawDataLoader,
+        parallel: int = -1,
+        transforms: typing.Sequence[Transform] = [],
+    ):
+        self.transform = torchvision.transforms.Compose(transforms)
+
+        if parallel < 0:
+            self.data = [
+                loader.sample(k) for k in tqdm.tqdm(split, unit="sample")
+            ]
+        else:
+            instances = parallel or multiprocessing.cpu_count()
+            logger.info(f"Caching dataset using {instances} processes...")
+            with multiprocessing.Pool(instances) as p:
+                self.data = list(
+                    tqdm.tqdm(p.imap(loader.sample, split), total=len(split))
+                )
+
+    def labels(self) -> list[int]:
+        """Returns the integer labels for all samples in the dataset."""
+        return [k[1]["label"] for k in self.data]
+
+    def __getitem__(self, key: int) -> Sample:
+        tensor, metadata = self.data[key]
+        return self.transform(tensor), metadata
+
+    def __len__(self):
+        return len(self.data)
+
+    def __iter__(self):
+        for x in range(len(self)):
+            yield self[x]
+
+
+def _make_balanced_random_sampler(
+    dataset: Dataset,
+    target: str = "label",
+) -> torch.utils.data.WeightedRandomSampler:
+    """Generates a pytorch sampler that samples according to class
+    probabilities.
+
+    This function takes as input a torch Dataset, and computes the weights to
+    balance each class in the dataset, and the datasets themselves if one
+    passes a :py:class:`torch.utils.data.ConcatDataset`.
+
+    In this implementation, we balance **both** class and dataset-origin
+    probabilities, which is what you expect from a truly *equitable* random
+    sampler.
+
+    Take this example for illustration:
+
+    * Dataset 1: N = 10 samples, 9 samples with target=0, 1 sample with target=1
+    * Dataset 2: N = 6 samples, 3 samples with target=0, 3 samples with target=1
+
+    So:
+
+    +---------+--------+---------+--------+-------------------+
+    | Dataset | Target | Samples | Weight | Normalised weight |
+    +=========+========+=========+========+===================+
+    |    1    |   0    |    9    |  1/9   |       1/36        |
+    +---------+--------+---------+--------+-------------------+
+    |    1    |   1    |    1    |  1/1   |       1/4         |
+    +---------+--------+---------+--------+-------------------+
+    |    2    |   0    |    3    |  1/3   |       1/12        |
+    +---------+--------+---------+--------+-------------------+
+    |    2    |   1    |    3    |  1/3   |       1/12        |
+    +---------+--------+---------+--------+-------------------+
+
+    Legend:
+
+    * Weight: the weights computed by this method
+    * Normalised weight: the weight per sample used by the random sampler,
+      after normalising the weights by the sum of all weights in the
+      concatenated dataset, such that the sum of all normalized weights times
+      the number of samples is 1.
+
+    The properties of this algorithm are as follows:
+
+    1. The probability of picking a sample from any target is the same (0.5 in
+       this case).  To verify this, notice that the probability of picking a
+       sample with ``target=0`` is :math:`1/4 \times 1 + 1/12 \times 3 = 0.5`.
+    2. The probability of picking a sample with ``target=0`` from Dataset 2 is
+       3 times higher than that from Dataset 1.  As there are 3 times fewer
+       samples in Dataset 2 with ``target=0``, this makes choosing samples from
+       Dataset 1 proportionally less likely.
+    3. The probability of picking a sample with ``target=1`` from Dataset 2 is
+       3 times lower than that from Dataset 1.  As there are 3 times fewer
+       samples in Dataset 1 with ``target=1``, this makes choosing samples from
+       Dataset 2 proportionally less likely.
+
+    This function assumes targets are stored on a dictionary entry named
+    ``target`` inside the metadata information for each :py:data:`Sample`, and
+    that its value is an integer.
+
+    We then instantiate a pytorch sampler using the inverse probabilities (the
+    more samples there are of a class, the less likely it becomes to be
+    sampled).
+
+
+    Parameters
+    ----------
+
+    dataset
+        An instance of torch Dataset.
+        :py:class:`torch.utils.data.ConcatDataset` are supported.
+
+    target
+        The name of a metadata key pointing to an integer property that allows
+        balancing the dataset.
+
+
+    Returns
+    -------
+
+    sampler
+        A sampler, to be used in a dataloader equipped with the same dataset
+        used to calculate the relative sample weights.
+
+
+    Raises
+    ------
+
+    RuntimeError
+        If requested to balance a dataset (single, not-concatenated) without an
+        existing target.
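+
+    Example usage (a sketch; ``train_dataset`` stands for any dataset
+    produced by this module):
+
+    .. code-block:: python
+
+       sampler = _make_balanced_random_sampler(train_dataset)
+       loader = torch.utils.data.DataLoader(
+           train_dataset, batch_size=4, sampler=sampler
+       )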
+    """
+
+    def _calculate_weights(targets: list[int]) -> list[float]:
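+        # inverse class frequency: a class with ``v`` samples gets weight
+        # 1/v, making every class equally likely to be drawn overall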
+        counts = collections.Counter(targets)
+        weights = {k: 1.0 / v for k, v in counts.items()}
+        return [weights[k] for k in targets]
+
+    if isinstance(dataset, torch.utils.data.ConcatDataset):
+        # There are two possible cases: targets/no-targets
+        metadata_example = dataset.datasets[0][0][1]
+        if target in metadata_example and isinstance(
+            metadata_example[target], int
+        ):
+            # there are integer targets, let's balance with those
+            logger.info(
+                f"Balancing sample selection probabilities **and** "
+                f"concatenated-datasets using metadata targets `{target}`"
+            )
+            targets = [
+                k
+                for ds in dataset.datasets
+                for k in typing.cast(Dataset, ds).labels()
+            ]
+            weights = _calculate_weights(targets)
+        else:
+            logger.warning(
+                f"Balancing samples **and** concatenated-datasets "
+                f"WITHOUT metadata targets (`{target}` not available)"
+            )
+            weights = [
+                k
+                for ds in dataset.datasets
+                for k in len(typing.cast(typing.Sized, ds))
+                * [1.0 / len(typing.cast(typing.Sized, ds))]
+            ]
+
+    else:
+        metadata_example = dataset[0][1]
+        if target in metadata_example and isinstance(
+            metadata_example[target], int
+        ):
+            logger.info(
+                f"Balancing samples from dataset using metadata "
+                f"targets `{target}`"
+            )
+            weights = _calculate_weights(dataset.labels())
+        else:
+            raise RuntimeError(
+                f"Cannot balance samples without metadata targets `{target}`"
+            )
+
+    return torch.utils.data.WeightedRandomSampler(
+        weights, len(weights), replacement=True
+    )
+
+
+class CachingDataModule(lightning.LightningDataModule):
+    """A conveninent data module with CSV or JSON protocol loading, mini-
+    batching, parallelisation and caching, all in one.
+
+    Instances of this class load data-split (a.k.a. protocol) definitions for a
+    database, and can load the data from the disk.  An optional caching
+    mechanism stores the data at associated CPU memory, which can improve data
+    serving while training and evaluating models.
+
+    This datamodule defines basic operations to handle data loading and
+    mini-batch handling within this package's framework.  It can return
+    :py:class:`torch.utils.data.DataLoader` objects for training, validation,
+    prediction and testing conditions.  Parallelisation is handled by a simple
+    input flag.
+
+    The basic :py:meth:`setup` function is parameterised by a single string
+    enumeration containing one of: ``fit``, ``validate``, ``test``, or
+    ``predict``.
+
+
+    Parameters
+    ----------
+
+    database_split
+        A dictionary that contains string keys representing subset names, and
+        values that are iterables over sample representations (potentially on
+        disk).  These objects are passed to the ``raw_data_loader`` for loading
+        the sample data (and metadata) in memory.  The objects represented may
+        be of any format (e.g. list, dictionary, etc), for as long as the
+        ``raw_data_loader`` can properly handle it.  To check that the split
+        and the loader function work correctly, you may use
+        :py:func:`..dataset.check_database_split_loading`.  As is, this class
+        expects at least one entry called ``train`` to exist in the input
+        dictionary.  Optional entries are ``validation``, and ``test``. Entries
+        named ``monitor-...`` will be considered extra subsets that do not
+        influence any early stop criteria during training, and are just
+        monitored beyond the ``validation`` dataset.
+
+    raw_data_loader
+        An object instance that can load samples and labels from storage.
+
+    cache_samples
+        If set, raw data loading is done during ``setup()``, and samples are
+        served from CPU memory afterwards.  Otherwise, samples are loaded from
+        disk on demand.  Serving from CPU memory offers increased speed in
+        exchange for CPU memory usage.  Sufficient CPU memory must be available
+        before you set this attribute to ``True``.  It is typically useful for
+        relatively small datasets.
+
+    balance_sampler_by_class
+        If set, then modifies the random sampler used during training to
+        balance sample picking probability, making sampling across classes
+        **and** datasets equitable.
+
+    model_transforms
+        A list of transforms (torch modules) that will be applied after
+        raw-data loading, and just before data is fed into the model or
+        eventual data-augmentation transformations, for all data loaders
+        produced by this data module.  This part of the pipeline receives data
+        as output by the raw-data loader, and applies model-related transforms
+        (e.g. resize adaptations), if any are specified.
+
+    batch_size
+        Number of samples in every **training** batch (this parameter affects
+        memory requirements for the network).  If the number of samples in the
+        batch is larger than the total number of samples available for
+        training, this value is truncated.  If this number is smaller, then
+        batches of the specified size are created and fed to the network until
+        there are no more new samples to feed (epoch is finished).  If the
+        total number of training samples is not a multiple of the batch-size,
+        the last batch will be smaller than the first, unless
+        ``drop_incomplete_batch`` is set to ``true``, in which case this batch
+        is not used.
+
+    batch_chunk_count
+        Number of chunks in every batch (this parameter affects memory
+        requirements for the network). The number of samples loaded for every
+        iteration will be ``batch_size/batch_chunk_count``. ``batch_size``
+        needs to be divisible by ``batch_chunk_count``, otherwise an error will
+        be raised. This parameter is used to reduce number of samples loaded in
+        each iteration, in order to reduce the memory usage in exchange for
+        processing time (more iterations). This is especially interesting when
+        one is running on GPUs with limited RAM. The default of 1 forces the
+        whole batch to be processed at once. Otherwise the batch is broken into
+        batch-chunk-count pieces, and gradients are accumulated to complete
+        each batch.
+
+    drop_incomplete_batch
+        If set, then may drop the last batch in an epoch, in case it is
+        incomplete.  If you set this option, you should also consider
+        increasing the total number of epochs of training, as the total number
+        of training steps may be reduced.
+
+    parallel
+        Use multiprocessing for data loading: if set to -1 (default), disables
+        multiprocessing data loading.  Set to 0 to enable as many data loading
+        instances as processing cores as available in the system.  Set to >= 1
+        to enable that many multiprocessing instances for data loading.
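+
+    Example usage (a sketch; ``my_split`` and ``my_loader`` stand for
+    concrete ``DatabaseSplit`` and ``RawDataLoader`` instances):
+
+    .. code-block:: python
+
+       datamodule = CachingDataModule(
+           database_split=my_split,
+           raw_data_loader=my_loader,
+           cache_samples=True,
+           batch_size=16,
+           batch_chunk_count=4,
+       )
+       datamodule.setup("fit")
+       train_loader = datamodule.train_dataloader()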
+    """
+
+    DatasetDictionary = dict[str, Dataset]
+
+    def __init__(
+        self,
+        database_split: DatabaseSplit,
+        raw_data_loader: RawDataLoader,
+        cache_samples: bool = False,
+        balance_sampler_by_class: bool = False,
+        model_transforms: list[Transform] = [],
+        batch_size: int = 1,
+        batch_chunk_count: int = 1,
+        drop_incomplete_batch: bool = False,
+        parallel: int = -1,
+    ):
+        super().__init__()
+
+        self.set_chunk_size(batch_size, batch_chunk_count)
+
+        self.database_split = database_split
+        self.raw_data_loader = raw_data_loader
+        self.cache_samples = cache_samples
+        self.model_transforms = model_transforms
+        self.drop_incomplete_batch = drop_incomplete_batch
+
+        # datasets that have been setup() for the current stage; also reset
+        # by the ``parallel`` setter below
+        self._datasets: CachingDataModule.DatasetDictionary = {}
+
+        self.parallel = parallel  # setter also sets dataloader multiproc kwargs
+
+        self.pin_memory = (
+            torch.cuda.is_available() or torch.backends.mps.is_available()
+        )  # should only be true if GPU available and using it
+
+        # must come last: the setter may trigger dataset setup, which relies
+        # on the attributes assigned above
+        self._train_sampler = None
+        self.balance_sampler_by_class = balance_sampler_by_class
+
+    @property
+    def parallel(self) -> int:
+        """Whether to use multiprocessing for data loading.
+
+        Use multiprocessing for data loading: if set to -1 (default),
+        disables multiprocessing data loading.  Set to 0 to enable as
+        many data loading instances as processing cores as available in
+        the system.  Set to >= 1 to enable that many multiprocessing
+        instances for data loading.
+        """
+        return self._parallel
+
+    @parallel.setter
+    def parallel(self, value: int) -> None:
+        self._parallel = value
+        self._dataloader_multiproc = _setup_dataloader_multiproc_parameters(
+            value
+        )
+        # datasets that have been setup() for the current stage
+        self._datasets = {}
+
+    @property
+    def balance_sampler_by_class(self):
+        """Whether to balance samples across labels/datasets.
+
+        If set, then modifies the random sampler used during training
+        to balance sample picking probability, making sampling across
+        classes **and** datasets equitable.
+        """
+        return self._train_sampler is not None
+
+    @balance_sampler_by_class.setter
+    def balance_sampler_by_class(self, value: bool):
+        if value:
+            if "train" not in self._datasets:
+                self._setup_dataset("train")
+            self._train_sampler = _make_balanced_random_sampler(
+                self._datasets["train"]
+            )
+        else:
+            self._train_sampler = None
+
+    def set_chunk_size(self, batch_size: int, batch_chunk_count: int) -> None:
+        """Coherently sets the batch-chunk-size after validation.
+
+        Parameters
+        ----------
+
+        batch_size
+            Number of samples in every **training** batch (this parameter affects
+            memory requirements for the network).  If the number of samples in the
+            batch is larger than the total number of samples available for
+            training, this value is truncated.  If this number is smaller, then
+            batches of the specified size are created and fed to the network until
+            there are no more new samples to feed (epoch is finished).  If the
+            total number of training samples is not a multiple of the batch-size,
+            the last batch will be smaller than the first, unless
+            ``drop_incomplete_batch`` is set to ``true``, in which case this batch
+            is not used.
+
+        batch_chunk_count
+            Number of chunks in every batch (this parameter affects memory
+            requirements for the network). The number of samples loaded for every
+            iteration will be ``batch_size/batch_chunk_count``. ``batch_size``
+            needs to be divisible by ``batch_chunk_count``, otherwise an error will
+            be raised. This parameter is used to reduce number of samples loaded in
+            each iteration, in order to reduce the memory usage in exchange for
+            processing time (more iterations). This is especially interesting
+            when one is running on GPUs with limited RAM. The default of 1
+            forces the
+            whole batch to be processed at once. Otherwise the batch is broken into
+            batch-chunk-count pieces, and gradients are accumulated to complete
+            each batch.
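+
+            For example, ``batch_size=16`` with ``batch_chunk_count=4``
+            yields a chunk size of 4: each iteration loads 4 samples, and
+            gradients are accumulated over 4 chunks to complete one
+            16-sample batch.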
+        """
+
+        # validation
+        if batch_size % batch_chunk_count != 0:
+            raise RuntimeError(
+                f"batch_size ({batch_size}) must be divisible by "
+                f"batch_chunk_size ({batch_chunk_count})."
+            )
+
+        self._batch_size = batch_size
+        self._batch_chunk_count = batch_chunk_count
+        self._chunk_size = self._batch_size // self._batch_chunk_count
+
+    def _setup_dataset(self, name: str) -> None:
+        """Sets-up a single dataset from the input data split.
+
+        Parameters
+        ----------
+
+        name
+            Name of the dataset to set up.
+        """
+
+        if name in self._datasets:
+            logger.info(
+                f"Dataset `{name}` is already setup. "
+                f"Not re-instantiating it."
+            )
+            return
+        if self.cache_samples:
+            logger.info(
+                f"Loading dataset:`{name}` into memory (caching)."
+                f" Trade-off: CPU RAM: more | Disk: less"
+            )
+            self._datasets[name] = _CachedDataset(
+                self.database_split[name],
+                self.raw_data_loader,
+                self.parallel,
+                self.model_transforms,
+            )
+        else:
+            logger.info(
+                f"Loading dataset:`{name}` without caching."
+                f" Trade-off: CPU RAM: less | Disk: more"
+            )
+            self._datasets[name] = _DelayedLoadingDataset(
+                self.database_split[name],
+                self.raw_data_loader,
+                self.model_transforms,
+            )
+
+    def _val_dataset_keys(self) -> list[str]:
+        """Returns list of validation dataset names."""
+        return ["validation"] + [
+            k for k in self.database_split.keys() if k.startswith("monitor-")
+        ]
+
+    def setup(self, stage: str) -> None:
+        """Sets up datasets for different tasks on the pipeline.
+
+        This method should set up (load, pre-process, etc) all datasets
+        required for a particular ``stage`` (fit, validate, test, predict),
+        and keep them ready to be used on one of the ``*_dataloader()``
+        methods that are pertinent for such stage.
+
+        If you have set ``cache_samples``, samples are loaded at this stage and
+        cached in memory.
+
+
+        Parameters
+        ----------
+
+        stage
+            Name of the stage to which the setup is applicable.  Can be one of
+            ``fit``, ``validate``, ``test`` or ``predict``.  Each stage
+            typically uses the following data loaders:
+
+            * ``fit``: uses both train and validation datasets
+            * ``validate``: uses only the validation dataset
+            * ``test``: uses only the test dataset
+            * ``predict``: uses only the test dataset
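+
+            For example, ``datamodule.setup("fit")`` prepares the ``train``
+            dataset, the ``validation`` dataset, and any ``monitor-*``
+            datasets.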
+        """
+
+        if stage == "fit":
+            for k in ["train"] + self._val_dataset_keys():
+                self._setup_dataset(k)
+
+        elif stage == "validate":
+            for k in self._val_dataset_keys():
+                self._setup_dataset(k)
+
+        elif stage == "test":
+            self._setup_dataset("test")
+
+        elif stage == "predict":
+            self._setup_dataset("test")
+
+    def teardown(self, stage: str) -> None:
+        """Unset-up datasets for different tasks on the pipeline.
+
+        This method unsets (unload, remove from memory, etc) all datasets required
+        for a particular ``stage`` (fit, validate, test, predict).
+
+        If you have set ``cache_samples``, samples are loaded, this may
+        effectivley release all the associated memory.
+
+
+        Parameters
+        ----------
+
+        stage
+            Name of the stage to which the teardown is applicable.  Can be one of
+            ``fit``, ``validate``, ``test`` or ``predict``.  Each stage
+            typically uses the following data loaders:
+
+            * ``fit``: uses both train and validation datasets
+            * ``validate``: uses only the validation dataset
+            * ``test``: uses only the test dataset
+            * ``predict``: uses only the test dataset
+        """
+
+        self._datasets = {}
+
+    def train_dataloader(self) -> DataLoader:
+        """Returns the train data loader."""
+
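+        # note: in pytorch, ``sampler`` and ``shuffle=True`` are mutually
+        # exclusive; when the balanced sampler is active, it already
+        # randomises the sample order, so shuffling is disabled here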
+        return torch.utils.data.DataLoader(
+            self._datasets["train"],
+            shuffle=(self._train_sampler is None),
+            batch_size=self._chunk_size,
+            drop_last=self.drop_incomplete_batch,
+            pin_memory=self.pin_memory,
+            sampler=self._train_sampler,
+            **self._dataloader_multiproc,
+        )
+
+    def unshuffled_train_dataloader(self) -> DataLoader:
+        """Returns the train data loader without shuffling."""
+
+        return torch.utils.data.DataLoader(
+            self._datasets["train"],
+            shuffle=False,
+            batch_size=self._chunk_size,
+            drop_last=False,
+            **self._dataloader_multiproc,
+        )
+
+    def val_dataloader(self) -> dict[str, DataLoader]:
+        """Returns the validation data loader(s)"""
+
+        validation_loader_opts = {
+            "batch_size": self._chunk_size,
+            "shuffle": False,
+            "drop_last": self.drop_incomplete_batch,
+            "pin_memory": self.pin_memory,
+        }
+        validation_loader_opts.update(self._dataloader_multiproc)
+
+        return {
+            k: torch.utils.data.DataLoader(
+                self._datasets[k], **validation_loader_opts
+            )
+            for k in self._val_dataset_keys()
+        }
+
+    def test_dataloader(self) -> dict[str, DataLoader]:
+        """Returns the test data loader(s)"""
+
+        return dict(
+            test=torch.utils.data.DataLoader(
+                self._datasets["test"],
+                batch_size=self._chunk_size,
+                shuffle=False,
+                drop_last=self.drop_incomplete_batch,
+                pin_memory=self.pin_memory,
+                **self._dataloader_multiproc,
+            )
+        )
+
+    def predict_dataloader(self) -> dict[str, DataLoader]:
+        """Returns the prediction data loader(s)"""
+
+        return self.test_dataloader()
diff --git a/src/ptbench/data/dataset.py b/src/ptbench/data/dataset.py
deleted file mode 100644
index 15dc32a96f6662b5ba8afbcdf06632b142840f79..0000000000000000000000000000000000000000
--- a/src/ptbench/data/dataset.py
+++ /dev/null
@@ -1,586 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-import csv
-import json
-import logging
-import os
-import pathlib
-
-import torch
-
-from tqdm import tqdm
-
-logger = logging.getLogger(__name__)
-
-
-class JSONProtocol:
-    """Generic multi-protocol/subset filelist dataset that yields samples.
-
-    To create a new dataset, you need to provide one or more JSON formatted
-    filelists (one per protocol) with the following contents:
-
-    .. code-block:: json
-
-       {
-           "subset1": [
-               [
-                   "value1",
-                   "value2",
-                   "value3"
-               ],
-               [
-                   "value4",
-                   "value5",
-                   "value6"
-               ]
-           ],
-           "subset2": [
-           ]
-       }
-
-    Your dataset many contain any number of subsets, but all sample entries
-    must contain the same number of fields.
-
-
-    Parameters
-    ----------
-
-    protocols : list, dict
-        Paths to one or more JSON formatted files containing the various
-        protocols to be recognized by this dataset, or a dictionary, mapping
-        protocol names to paths (or opened file objects) of CSV files.
-        Internally, we save a dictionary where keys default to the basename of
-        paths (list input).
-
-    fieldnames : list, tuple
-        An iterable over the field names (strings) to assign to each entry in
-        the JSON file.  It should have as many items as fields in each entry of
-        the JSON file.
-
-    loader : object
-        A function that receives as input, a context dictionary (with at least
-        a "protocol" and "subset" keys indicating which protocol and subset are
-        being served), and a dictionary with ``{fieldname: value}`` entries,
-        and returns an object with at least 2 attributes:
-
-        * ``key``: which must be a unique string for every sample across
-          subsets in a protocol, and
-        * ``data``: which contains the data associated witht this sample
-    """
-
-    def __init__(self, protocols, fieldnames):
-        if isinstance(protocols, dict):
-            self._protocols = protocols
-        else:
-            self._protocols = {
-                os.path.basename(
-                    str(k).replace("".join(pathlib.Path(k).suffixes), "")
-                ): k
-                for k in protocols
-            }
-        self.fieldnames = fieldnames
-
-    def check(self, limit=0):
-        """For each protocol, check if all data can be correctly accessed.
-
-        This function assumes each sample has a ``data`` and a ``key``
-        attribute.  The ``key`` attribute should be a string, or representable
-        as such.
-
-
-        Parameters
-        ----------
-
-        limit : int
-            Maximum number of samples to check (in each protocol/subset
-            combination) in this dataset.  If set to zero, then check
-            everything.
-
-
-        Returns
-        -------
-
-        errors : int
-            Number of errors found
-        """
-        logger.info("Checking dataset...")
-        errors = 0
-        for proto in self._protocols:
-            logger.info(f"Checking protocol '{proto}'...")
-            for name, samples in self.subsets(proto).items():
-                logger.info(f"Checking subset '{name}'...")
-                if limit:
-                    logger.info(f"Checking at most first '{limit}' samples...")
-                    samples = samples[:limit]
-                for pos, sample in enumerate(samples):
-                    try:
-                        sample.data  # may trigger data loading
-                        logger.info(f"{sample.key}: OK")
-                    except Exception as e:
-                        logger.error(
-                            f"Found error loading entry {pos} in subset {name} "
-                            f"of protocol {proto} from file "
-                            f"'{self._protocols[proto]}': {e}"
-                        )
-                        errors += 1
-        return errors
-
-    def subsets(self, protocol):
-        """Returns all subsets in a protocol.
-
-        This method will load JSON information for a given protocol and return
-        all subsets of the given protocol after converting each entry through
-        the loader function.
-
-        Parameters
-        ----------
-
-        protocol : str
-            Name of the protocol data to load
-
-
-        Returns
-        -------
-
-        subsets : dict
-            A dictionary mapping subset names to lists of objects (respecting
-            the ``key``, ``data`` interface).
-        """
-        fileobj = self._protocols[protocol]
-        if isinstance(fileobj, (str, bytes, pathlib.Path)):
-            if str(fileobj).endswith(".bz2"):
-                import bz2
-
-                with bz2.open(self._protocols[protocol]) as f:
-                    data = json.load(f)
-            else:
-                with open(self._protocols[protocol]) as f:
-                    data = json.load(f)
-        else:
-            data = json.load(fileobj)
-            fileobj.seek(0)
-
-        retval = {}
-        for subset, samples in data.items():
-            logger.info(f"Loading subset {subset} samples.")
-
-            retval[subset] = [
-                dict(zip(self.fieldnames, k))
-                for n, k in enumerate(tqdm(samples))
-            ]
-
-        return retval
-
-
-class CSVDataset:
-    """Generic multi-subset filelist dataset that yields samples.
-
-    To create a new dataset, you only need to provide a CSV formatted filelist
-    using any separator (e.g. comma, space, semi-colon) with the following
-    information:
-
-    .. code-block:: text
-
-       value1,value2,value3
-       value4,value5,value6
-       ...
-
-    Notice that all rows must have the same number of entries.
-
-    Parameters
-    ----------
-
-    subsets : list, dict
-        Paths to one or more CSV formatted files containing the various subsets
-        to be recognized by this dataset, or a dictionary, mapping subset names
-        to paths (or opened file objects) of CSV files.  Internally, we save a
-        dictionary where keys default to the basename of paths (list input).
-
-    fieldnames : list, tuple
-        An iterable over the field names (strings) to assign to each column in
-        the CSV file.  It should have as many items as fields in each row of
-        the CSV file(s).
-
-    loader : object
-        A function that receives as input, a context dictionary (with, at
-        least, a "subset" key indicating which subset is being served), and a
-        dictionary with ``{key: path}`` entries, and returns a dictionary with
-        the loaded data.
-    """
-
-    def __init__(self, subsets, fieldnames, loader):
-        if isinstance(subsets, dict):
-            self._subsets = subsets
-        else:
-            self._subsets = {
-                os.path.basename(
-                    str(k).replace("".join(pathlib.Path(k).suffixes), "")
-                ): k
-                for k in subsets
-            }
-        self.fieldnames = fieldnames
-        self._loader = loader
-
-    def check(self, limit=0):
-        """For each subset, check if all data can be correctly accessed.
-
-        This function assumes each sample has a ``data`` and a ``key``
-        attribute.  The ``key`` attribute should be a string, or representable
-        as such.
-
-
-        Parameters
-        ----------
-
-        limit : int
-            Maximum number of samples to check (in each protocol/subset
-            combination) in this dataset.  If set to zero, then check
-            everything.
-
-
-        Returns
-        -------
-
-        errors : int
-            Number of errors found
-        """
-        logger.info("Checking dataset...")
-        errors = 0
-        for name in self._subsets.keys():
-            logger.info(f"Checking subset '{name}'...")
-            samples = self.samples(name)
-            if limit:
-                logger.info(f"Checking at most first '{limit}' samples...")
-                samples = samples[:limit]
-            for pos, sample in enumerate(samples):
-                try:
-                    sample.data  # may trigger data loading
-                    logger.info(f"{sample.key}: OK")
-                except Exception as e:
-                    logger.error(
-                        f"Found error loading entry {pos} in subset {name} "
-                        f"from file '{self._subsets[name]}': {e}"
-                    )
-                    errors += 1
-        return errors
-
-    def subsets(self):
-        """Returns all available subsets at once.
-
-        Returns
-        -------
-
-        subsets : dict
-            A dictionary mapping subset names to lists of objects (respecting
-            the ``key``, ``data`` interface).
-        """
-        return {k: self.samples(k) for k in self._subsets.keys()}
-
-    def samples(self, subset):
-        """Returns all samples in a subset.
-
-        This method will load CSV information for a given subset and return
-        all samples of the given subset after passing each entry through the
-        loading function.
-
-
-        Parameters
-        ----------
-
-        subset : str
-            Name of the subset data to load
-
-
-        Returns
-        -------
-
-        subset : list
-            A lists of objects (respecting the ``key``, ``data`` interface).
-        """
-        fileobj = self._subsets[subset]
-        if isinstance(fileobj, (str, bytes, pathlib.Path)):
-            with open(self._subsets[subset], newline="") as f:
-                cf = csv.reader(f)
-                samples = [k for k in cf]
-        else:
-            cf = csv.reader(fileobj)
-            samples = [k for k in cf]
-            fileobj.seek(0)
-
-        return [
-            self._loader(
-                dict(subset=subset, order=n), dict(zip(self.fieldnames, k))
-            )
-            for n, k in enumerate(samples)
-        ]
-
-
-class CachedDataset(torch.utils.data.Dataset):
-    def __init__(
-        self, json_protocol, protocol, subset, raw_data_loader, transforms
-    ):
-        self.json_protocol = json_protocol
-        self.subset = subset
-        self.raw_data_loader = raw_data_loader
-        self.transforms = transforms
-
-        self._samples = json_protocol.subsets(protocol)[self.subset]
-        # Dict entry with relative path to files, used during prediction
-        for s in self._samples:
-            s["name"] = s["data"]
-
-        logger.info(f"Caching {self.subset} samples")
-        for sample in tqdm(self._samples):
-            sample["data"] = self.raw_data_loader(sample["data"])
-
-    def __getitem__(self, idx):
-        sample = self._samples[idx].copy()
-        sample["data"] = self.transforms(sample["data"])
-        return sample
-
-    def __len__(self):
-        return len(self._samples)
-
-
-class RuntimeDataset(torch.utils.data.Dataset):
-    def __init__(
-        self, json_protocol, protocol, subset, raw_data_loader, transforms
-    ):
-        self.json_protocol = json_protocol
-        self.subset = subset
-        self.raw_data_loader = raw_data_loader
-        self.transforms = transforms
-
-        self._samples = json_protocol.subsets(protocol)[self.subset]
-        # Dict entry with relative path to files
-        for s in self._samples:
-            s["name"] = s["data"]
-
-    def __getitem__(self, idx):
-        sample = self._samples[idx].copy()
-        sample["data"] = self.transforms(self.raw_data_loader(sample["data"]))
-        return sample
-
-    def __len__(self):
-        return len(self._samples)
-
-
-def get_samples_weights(dataset):
-    """Compute the weights of all the samples of the dataset to balance it
-    using the sampler of the dataloader.
-
-    This function takes as input a :py:class:`torch.utils.data.dataset.Dataset`
-    and computes the weights to balance each class in the dataset and the
-    datasets themselves if we have a ConcatDataset.
-
-
-    Parameters
-    ----------
-
-    dataset : torch.utils.data.dataset.Dataset
-        An instance of torch.utils.data.dataset.Dataset
-        ConcatDataset are supported
-
-
-    Returns
-    -------
-
-    samples_weights : :py:class:`torch.Tensor`
-        the weights for all the samples in the dataset given as input
-    """
-    samples_weights = []
-
-    if isinstance(dataset, torch.utils.data.ConcatDataset):
-        for ds in dataset.datasets:
-            # Weighting only for binary labels
-            if isinstance(ds._samples[0]["label"], int):
-                # Groundtruth
-                targets = []
-                for s in ds._samples:
-                    targets.append(s["label"])
-                targets = torch.tensor(targets)
-
-                # Count number of samples per class
-                class_sample_count = torch.tensor(
-                    [
-                        (targets == t).sum()
-                        for t in torch.unique(targets, sorted=True)
-                    ]
-                )
-
-                weight = 1.0 / class_sample_count.float()
-
-                samples_weights.append(
-                    torch.tensor([weight[t] for t in targets])
-                )
-
-            else:
-                # We only weight to sample equally from each dataset
-                samples_weights.append(torch.full((len(ds),), 1.0 / len(ds)))
-
-        # Concatenate sample weights from all the datasets
-        samples_weights = torch.cat(samples_weights)
-
-    else:
-        # Weighting only for binary labels
-        if isinstance(dataset._samples[0]["label"], int):
-            # Groundtruth
-            targets = []
-            for s in dataset._samples:
-                targets.append(s["label"])
-            targets = torch.tensor(targets)
-
-            # Count number of samples per class
-            class_sample_count = torch.tensor(
-                [
-                    (targets == t).sum()
-                    for t in torch.unique(targets, sorted=True)
-                ]
-            )
-
-            weight = 1.0 / class_sample_count.float()
-
-            samples_weights = torch.tensor([weight[t] for t in targets])
-
-        else:
-            # Equal weights for non-binary labels
-            samples_weights = torch.ones(len(dataset._samples))
-
-    return samples_weights
-
-
-def get_positive_weights(dataset):
-    """Compute the positive weights of each class of the dataset to balance the
-    BCEWithLogitsLoss criterion.
-
-    This function takes as input a :py:class:`torch.utils.data.dataset.Dataset`
-    and computes the positive weights of each class to use them to have
-    a balanced loss.
-
-
-    Parameters
-    ----------
-
-    dataset : torch.utils.data.dataset.Dataset
-        An instance of torch.utils.data.dataset.Dataset
-        ConcatDataset are supported
-
-
-    Returns
-    -------
-
-    positive_weights : :py:class:`torch.Tensor`
-        the positive weight of each class in the dataset given as input
-    """
-    targets = []
-
-    if isinstance(dataset, torch.utils.data.ConcatDataset):
-        for ds in dataset.datasets:
-            for s in ds._samples:
-                targets.append(s["label"])
-
-    else:
-        for s in dataset._samples:
-            targets.append(s["label"])
-
-    targets = torch.tensor(targets)
-
-    # Binary labels
-    if len(list(targets.shape)) == 1:
-        class_sample_count = [
-            float((targets == t).sum().item())
-            for t in torch.unique(targets, sorted=True)
-        ]
-
-        # Divide negatives by positives
-        positive_weights = torch.tensor(
-            [class_sample_count[0] / class_sample_count[1]]
-        ).reshape(-1)
-
-    # Multiclass labels
-    else:
-        class_sample_count = torch.sum(targets, dim=0)
-        negative_class_sample_count = (
-            torch.full((targets.size()[1],), float(targets.size()[0]))
-            - class_sample_count
-        )
-
-        positive_weights = negative_class_sample_count / (
-            class_sample_count + negative_class_sample_count
-        )
-
-    return positive_weights
-
-
-def reweight_BCEWithLogitsLoss(datamodule, model, criterion, criterion_valid):
-    from torch.nn import BCEWithLogitsLoss
-
-    datamodule.prepare_data()
-    datamodule.setup(stage="fit")
-
-    train_dataset = datamodule.train_dataset
-    validation_dataset = datamodule.validation_dataset
-
-    # Redefine a weighted criterion if possible
-    if isinstance(criterion, torch.nn.BCEWithLogitsLoss):
-        positive_weights = get_positive_weights(train_dataset)
-        model.hparams.criterion = BCEWithLogitsLoss(pos_weight=positive_weights)
-    else:
-        logger.warning("Weighted criterion not supported")
-
-    if validation_dataset is not None:
-        # Redefine a weighted valid criterion if possible
-        if (
-            isinstance(criterion_valid, torch.nn.BCEWithLogitsLoss)
-            or criterion_valid is None
-        ):
-            positive_weights = get_positive_weights(validation_dataset)
-            model.hparams.criterion_valid = BCEWithLogitsLoss(
-                pos_weight=positive_weights
-            )
-        else:
-            logger.warning("Weighted valid criterion not supported")
-
-
-def normalize_data(normalization, model, datamodule):
-    from torch.utils.data import DataLoader
-
-    datamodule.prepare_data()
-    datamodule.setup(stage="fit")
-
-    train_dataset = datamodule.train_dataset
-
-    # Create z-normalization model layer if needed
-    if normalization == "imagenet":
-        model.normalizer.set_mean_std(
-            [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
-        )
-        logger.info("Z-normalization with ImageNet mean and std")
-    elif normalization == "current":
-        # Compute mean/std of current train subset
-        temp_dl = DataLoader(
-            dataset=train_dataset, batch_size=len(train_dataset)
-        )
-
-        data = next(iter(temp_dl))
-        mean = data[1].mean(dim=[0, 2, 3])
-        std = data[1].std(dim=[0, 2, 3])
-
-        model.normalizer.set_mean_std(mean, std)
-
-        # Format mean and std for logging
-        mean = str(
-            [
-                round(x, 3)
-                for x in ((mean * 10**3).round() / (10**3)).tolist()
-            ]
-        )
-        std = str(
-            [
-                round(x, 3)
-                for x in ((std * 10**3).round() / (10**3)).tolist()
-            ]
-        )
-        logger.info(f"Z-normalization with mean {mean} and std {std}")
diff --git a/src/ptbench/data/image_utils.py b/src/ptbench/data/image_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac31b9ce7fbce85fb688b394c99d591b83049f7f
--- /dev/null
+++ b/src/ptbench/data/image_utils.py
@@ -0,0 +1,100 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+
+"""Data loading code."""
+
+import pathlib
+
+import numpy
+import PIL.Image
+
+
+class SingleAutoLevel16to8:
+    """Converts a 16-bit image to 8-bit representation using "auto-level".
+
+    This transform assumes that the input image is gray-scaled.
+
+    To auto-level, we calculate the maximum and the minimum of the image, and
+    map that range onto the [0,255] range of the destination image.
+    """
+
+    def __call__(self, img):
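+        # N.B.: assumes the input image is not constant; for a constant image
+        # irange would be zero and the division below would fail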
+        imin, imax = img.getextrema()
+        irange = imax - imin
+        return PIL.Image.fromarray(
+            numpy.round(
+                255.0 * (numpy.array(img).astype(float) - imin) / irange
+            ).astype("uint8"),
+        ).convert("L")
+
+
+class RemoveBlackBorders:
+    """Remove black borders of CXR."""
+
+    def __init__(self, threshold=0):
+        self.threshold = threshold
+
+    def __call__(self, img):
+        img = numpy.asarray(img)
+        mask = img > self.threshold
+        # keep only rows and columns with at least one pixel above threshold
+        return PIL.Image.fromarray(img[numpy.ix_(mask.any(1), mask.any(0))])
+
+
+def load_pil(path: str | pathlib.Path) -> PIL.Image.Image:
+    """Loads a sample data.
+
+    Parameters
+    ----------
+
+    path
+        The full path leading to the image to be loaded
+
+
+    Returns
+    -------
+
+    image
+        A PIL image
+    """
+    return PIL.Image.open(path)
+
+
+def load_pil_baw(path: str | pathlib.Path) -> PIL.Image.Image:
+    """Loads a sample data.
+
+    Parameters
+    ----------
+
+    path
+        The full path leading to the image to be loaded
+
+
+    Returns
+    -------
+
+    image
+        A PIL image in grayscale mode
+    """
+    return load_pil(path).convert("L")
+
+
+def load_pil_rgb(path: str | pathlib.Path) -> PIL.Image.Image:
+    """Loads a sample data.
+
+    Parameters
+    ----------
+
+    path
+        The full path leading to the image to be loaded
+
+
+    Returns
+    -------
+
+    image
+        A PIL image in RGB mode
+    """
+    return load_pil(path).convert("RGB")
diff --git a/src/ptbench/data/loader.py b/src/ptbench/data/loader.py
deleted file mode 100644
index d6e86ed06dd04abb230830ebc8b13e2ea9720cfa..0000000000000000000000000000000000000000
--- a/src/ptbench/data/loader.py
+++ /dev/null
@@ -1,65 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-
-"""Data loading code."""
-
-import PIL.Image
-
-
-def load_pil(path):
-    """Loads a sample data.
-
-    Parameters
-    ----------
-
-    path : str
-        The full path leading to the image to be loaded
-
-
-    Returns
-    -------
-
-    image : PIL.Image.Image
-        A PIL image
-    """
-    return PIL.Image.open(path)
-
-
-def load_pil_baw(path):
-    """Loads a sample data.
-
-    Parameters
-    ----------
-
-    path : str
-        The full path leading to the image to be loaded
-
-
-    Returns
-    -------
-
-    image : PIL.Image.Image
-        A PIL image in grayscale mode
-    """
-    return load_pil(path).convert("L")
-
-
-def load_pil_rgb(path):
-    """Loads a sample data.
-
-    Parameters
-    ----------
-
-    path : str
-        The full path leading to the image to be loaded
-
-
-    Returns
-    -------
-
-    image : PIL.Image.Image
-        A PIL image in RGB mode
-    """
-    return load_pil(path).convert("RGB")
diff --git a/src/ptbench/data/padchest/__init__.py b/src/ptbench/data/padchest/__init__.py
index af1dd3ec9f8cbaa1b3e32f5195363592602d94ac..f6f39a9ebe98eb68a88374b8d83a70447a4ad995 100644
--- a/src/ptbench/data/padchest/__init__.py
+++ b/src/ptbench/data/padchest/__init__.py
@@ -264,7 +264,7 @@ json_dataset = JSONDataset(
 def _maker(protocol, resize_size=512, cc_size=512, RGB=True):
     import torchvision.transforms as transforms
 
-    from ..transforms import SingleAutoLevel16to8
+    from ..image_utils import SingleAutoLevel16to8
 
     post_transforms = []
     if not RGB:
diff --git a/src/ptbench/data/shenzhen/__init__.py b/src/ptbench/data/shenzhen/__init__.py
index 54eb632684a83e16ac28481f8499137f594b662a..1645962e8cc00443399dd60b88f017c71824e086 100644
--- a/src/ptbench/data/shenzhen/__init__.py
+++ b/src/ptbench/data/shenzhen/__init__.py
@@ -13,22 +13,11 @@ the daily routine using Philips DR Digital Diagnose systems.
 * Reference: [MONTGOMERY-SHENZHEN-2014]_
 * Original resolution (height x width or width x height): 3000 x 3000 or less
 * Split reference: none
-* Protocol ``default``:
-
-  * Training samples: 64% of TB and healthy CXR (including labels)
-  * Validation samples: 16% of TB and healthy CXR (including labels)
-  * Test samples: 20% of TB and healthy CXR (including labels)
 """
 import importlib.resources
-import os
-
-from clapper.logging import setup
-
-from ...utils.rc import load_rc
-from ..loader import load_pil_baw
-
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
 
 _protocols = [
     importlib.resources.files(__name__).joinpath("default.json.bz2"),
@@ -43,11 +32,3 @@ _protocols = [
     importlib.resources.files(__name__).joinpath("fold_8.json.bz2"),
     importlib.resources.files(__name__).joinpath("fold_9.json.bz2"),
 ]
-
-_datadir = load_rc().get("datadir.shenzhen", os.path.realpath(os.curdir))
-
-
-def _raw_data_loader(img_path):
-    raw_data = load_pil_baw(os.path.join(_datadir, img_path))
-
-    return raw_data
diff --git a/src/ptbench/data/shenzhen/default.py b/src/ptbench/data/shenzhen/default.py
index b5fb23f59277f1f6511c6707929856cc2357be2f..bfe93f44faaa9df235f357c1cc3a927412f4a011 100644
--- a/src/ptbench/data/shenzhen/default.py
+++ b/src/ptbench/data/shenzhen/default.py
@@ -2,27 +2,46 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-"""Shenzhen dataset for TB detection (default protocol)
+"""Shenzhen datamodule for computer-aided diagnosis (default protocol)
 
-* Split reference: first 64% of TB and healthy CXR for "train" 16% for
-* "validation", 20% for "test"
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.shenzhen` for dataset details
-"""
+See :py:mod:`ptbench.data.shenzhen` for more database details.
+
+This configuration:
+
+* Raw data input (on disk):
+
+  * PNG images (black and white, encoded as color images)
+  * Variable width and height:
 
-from clapper.logging import setup
+    * widths: from 1130 to 3001 pixels
+    * heights: from 948 to 3001 pixels
 
-from ..transforms import ElasticDeformation
-from .utils import ShenzhenDataModule
+* Output image:
 
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+  * Transforms:
+
+    * Load raw PNG with :py:mod:`PIL`
+    * Remove black borders
+    * Torch resizing (smallest edge = 512px)
+    * Torch center cropping (512px, 512px)
+
+  * Final specifications:
+
+    * Fixed resolution: 512x512 pixels
+    * Color RGB encoding
+"""
 
-protocol_name = "default"
+import importlib.resources
 
-augmentation_transforms = [ElasticDeformation(p=0.8)]
+from ..datamodule import CachingDataModule
+from ..split import JSONDatabaseSplit
+from .loader import RawDataLoader
 
-datamodule = ShenzhenDataModule(
-    protocol="default",
-    model_transforms=[],
-    augmentation_transforms=augmentation_transforms,
+datamodule = CachingDataModule(
+    database_split=JSONDatabaseSplit(
+        importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(
+            "default.json.bz2"
+        )
+    ),
+    raw_data_loader=RawDataLoader(),
 )
diff --git a/src/ptbench/data/shenzhen/loader.py b/src/ptbench/data/shenzhen/loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..e983fe7028d858118efb9ba41e560aee8e95845a
--- /dev/null
+++ b/src/ptbench/data/shenzhen/loader.py
@@ -0,0 +1,105 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+"""Shenzhen dataset for computer-aided diagnosis.
+
+The standard digital image database for Tuberculosis is created by the National
+Library of Medicine, Maryland, USA in collaboration with Shenzhen No.3 People’s
+Hospital, Guangdong Medical College, Shenzhen, China. The Chest X-rays are from
+out-patient clinics, and were captured as part of the daily routine using
+Philips DR Digital Diagnose systems.
+
+* Reference: [MONTGOMERY-SHENZHEN-2014]_
+* Original resolution (height x width or width x height): 3000 x 3000 or less
+* Split reference: none
+* Protocol ``default``:
+
+  * Training samples: 64% of TB and healthy CXR (including labels)
+  * Validation samples: 16% of TB and healthy CXR (including labels)
+  * Test samples: 20% of TB and healthy CXR (including labels)
+"""
+
+import os
+
+import torchvision.transforms
+
+from ...utils.rc import load_rc
+from ..image_utils import RemoveBlackBorders, load_pil_baw
+from ..typing import RawDataLoader as _BaseRawDataLoader
+from ..typing import Sample
+
+
+class RawDataLoader(_BaseRawDataLoader):
+    """A specialized raw-data-loader for the Shenzen dataset.
+
+    Attributes
+    ----------
+
+    datadir
+        This variable contains the base directory where the database raw data
+        is stored.
+
+    transform
+        Transforms that are always applied to the loaded raw images.
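+
+    Example (a minimal sketch; the path is illustrative and requires the
+    Shenzhen raw data installed under the configured ``datadir.shenzhen``):
+
+    .. code-block:: python
+
+       loader = RawDataLoader()
+       tensor, metadata = loader.sample(("CXR_png/CHNCXR_0001_0.png", 0))
+       assert metadata["label"] == 0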
+    """
+
+    datadir: str
+    transform: torchvision.transforms.Compose
+
+    def __init__(self):
+        self.datadir = load_rc().get(
+            "datadir.shenzhen", os.path.realpath(os.curdir)
+        )
+
+        self.transform = torchvision.transforms.Compose(
+            [
+                RemoveBlackBorders(),
+                torchvision.transforms.Resize(512),
+                torchvision.transforms.CenterCrop(512),
+                torchvision.transforms.ToTensor(),
+            ]
+        )
+
+    def sample(self, sample: tuple[str, int]) -> Sample:
+        """Loads a single image sample from the disk.
+
+        Parameters
+        ----------
+
+        sample:
+            A tuple containing the path suffix (relative to the dataset root
+            folder) of the image to be loaded, and an integer representing
+            the sample label.
+
+
+        Returns
+        -------
+
+        sample
+            The sample representation
+        """
+        tensor = self.transform(
+            load_pil_baw(os.path.join(self.datadir, sample[0]))
+        )
+        return tensor, dict(label=sample[1])  # type: ignore[arg-type]
+
+    def label(self, sample: tuple[str, int]) -> int:
+        """Loads a single image sample label from the disk.
+
+        Parameters
+        ----------
+
+        sample:
+            A tuple containing the path suffix (relative to the dataset root
+            folder) of the image to be loaded, and an integer representing
+            the sample label.
+
+
+        Returns
+        -------
+
+        label
+            The integer label associated with the sample
+        """
+        return sample[1]
diff --git a/src/ptbench/data/shenzhen/rgb.py b/src/ptbench/data/shenzhen/rgb.py
index 2d93c07edb4c0824caa8149737ec42783966ad4c..f45f601d6e84a04e4610319d1f27631ff32ee69c 100644
--- a/src/ptbench/data/shenzhen/rgb.py
+++ b/src/ptbench/data/shenzhen/rgb.py
@@ -2,81 +2,40 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-"""Shenzhen dataset for TB detection (cross validation fold 0, RGB)
+"""Shenzhen datamodule for computer-aided diagnosis (default protocol)
 
-* Split reference: first 80% of TB and healthy CXR for "train", rest for "test"
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.shenzhen` for dataset details
-"""
-
-from clapper.logging import setup
-
-from ....data import return_subsets
-from ....data.base_datamodule import BaseDataModule
-from ....data.dataset import JSONProtocol
-from ....data.shenzhen import _cached_loader, _delayed_loader, _protocols
+See :py:mod:`ptbench.data.shenzhen` for dataset details.
 
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        cache_samples=False,
-        multiproc_kwargs=None,
-        data_transforms=[],
-        model_transforms=[],
-        train_transforms=[],
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-        self.cache_samples = cache_samples
-        self.has_setup_fit = False
+This configuration:
+
+* raw data (default): :py:obj:`ptbench.data.shenzhen._transforms`
+* model transforms: conversion to RGB
+* output image resolution: 512x512 pixels
+"""
 
-        self.data_transforms = data_transforms
-        self.model_transforms = model_transforms
-        self.train_transforms = train_transforms
+import importlib.resources
 
-        """[
-            transforms.ToPILImage(),
-            transforms.Lambda(lambda x: x.convert("RGB")),
-            transforms.ToTensor(),
-        ]"""
+from torchvision import transforms
 
-    def setup(self, stage: str):
-        if self.cache_samples:
-            logger.info(
-                "Argument cache_samples set to True. Samples will be loaded in memory."
-            )
-            samples_loader = _cached_loader
-        else:
-            logger.info(
-                "Argument cache_samples set to False. Samples will be loaded at runtime."
-            )
-            samples_loader = _delayed_loader
+from ..datamodule import CachingDataModule
+from ..split import JSONDatabaseSplit
+from .loader import RawDataLoader
 
-        self.json_protocol = JSONProtocol(
-            protocols=_protocols,
-            fieldnames=("data", "label"),
-            loader=samples_loader,
-            post_transforms=self.post_transforms,
+datamodule = CachingDataModule(
+    database_split=JSONDatabaseSplit(
+        importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(
+            "default.json.bz2"
         )
-
-        if not self.has_setup_fit and stage == "fit":
-            (
-                self.train_dataset,
-                self.validation_dataset,
-                self.extra_validation_datasets,
-            ) = return_subsets(self.json_protocol, "default", stage)
-            self.has_setup_fit = True
-
-
-datamodule = DefaultModule
+    ),
+    raw_data_loader=RawDataLoader(),
+    cache_samples=False,
+    # train_sampler: typing.Optional[torch.utils.data.Sampler] = None,
+    model_transforms=[
+        transforms.ToPILImage(),
+        transforms.Lambda(lambda x: x.convert("RGB")),
+        transforms.ToTensor(),
+    ],
+    # batch_size = 1,
+    # batch_chunk_count = 1,
+    # drop_incomplete_batch = False,
+    # parallel = -1,
+)
diff --git a/src/ptbench/data/shenzhen/utils.py b/src/ptbench/data/shenzhen/utils.py
deleted file mode 100644
index 1521b674212feec942e73df081e7b20c19d89e29..0000000000000000000000000000000000000000
--- a/src/ptbench/data/shenzhen/utils.py
+++ /dev/null
@@ -1,114 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-"""Shenzhen dataset for computer-aided diagnosis.
-
-The standard digital image database for Tuberculosis is created by the
-National Library of Medicine, Maryland, USA in collaboration with Shenzhen
-No.3 People’s Hospital, Guangdong Medical College, Shenzhen, China.
-The Chest X-rays are from out-patient clinics, and were captured as part of
-the daily routine using Philips DR Digital Diagnose systems.
-
-* Reference: [MONTGOMERY-SHENZHEN-2014]_
-* Original resolution (height x width or width x height): 3000 x 3000 or less
-* Split reference: none
-* Protocol ``default``:
-
-  * Training samples: 64% of TB and healthy CXR (including labels)
-  * Validation samples: 16% of TB and healthy CXR (including labels)
-  * Test samples: 20% of TB and healthy CXR (including labels)
-"""
-from clapper.logging import setup
-from torchvision import transforms
-
-from ..base_datamodule import BaseDataModule
-from ..dataset import CachedDataset, JSONProtocol, RuntimeDataset
-from ..shenzhen import _protocols, _raw_data_loader
-from ..transforms import RemoveBlackBorders
-
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class ShenzhenDataModule(BaseDataModule):
-    def __init__(
-        self,
-        protocol="default",
-        model_transforms=[],
-        augmentation_transforms=[],
-        batch_size=1,
-        batch_chunk_count=1,
-        drop_incomplete_batch=False,
-        cache_samples=False,
-        parallel=-1,
-    ):
-        super().__init__(
-            batch_size=batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            batch_chunk_count=batch_chunk_count,
-            parallel=parallel,
-        )
-
-        self._cache_samples = cache_samples
-        self._has_setup_fit = False
-        self._has_setup_predict = False
-        self._protocol = protocol
-
-        self.raw_data_transforms = [
-            RemoveBlackBorders(),
-            transforms.Resize(512),
-            transforms.CenterCrop(512),
-            transforms.ToTensor(),
-        ]
-
-        self.model_transforms = model_transforms
-
-        self.augmentation_transforms = augmentation_transforms
-
-    def setup(self, stage: str):
-        json_protocol = JSONProtocol(
-            protocols=_protocols,
-            fieldnames=("data", "label"),
-        )
-
-        if self._cache_samples:
-            dataset = CachedDataset
-        else:
-            dataset = RuntimeDataset
-
-        if not self._has_setup_fit and stage == "fit":
-            self.train_dataset = dataset(
-                json_protocol,
-                self._protocol,
-                "train",
-                _raw_data_loader,
-                self._build_transforms(is_train=True),
-            )
-
-            self.validation_dataset = dataset(
-                json_protocol,
-                self._protocol,
-                "validation",
-                _raw_data_loader,
-                self._build_transforms(is_train=False),
-            )
-
-            self._has_setup_fit = True
-
-        if not self._has_setup_predict and stage == "predict":
-            self.train_dataset = dataset(
-                json_protocol,
-                self._protocol,
-                "train",
-                _raw_data_loader,
-                self._build_transforms(is_train=False),
-            )
-            self.validation_dataset = dataset(
-                json_protocol,
-                self._protocol,
-                "validation",
-                _raw_data_loader,
-                self._build_transforms(is_train=False),
-            )
-
-            self._has_setup_predict = True
diff --git a/src/ptbench/data/split.py b/src/ptbench/data/split.py
new file mode 100644
index 0000000000000000000000000000000000000000..78fe7e33886cd47265d0a7f217bcb55f799c9fa4
--- /dev/null
+++ b/src/ptbench/data/split.py
@@ -0,0 +1,252 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import bz2
+import csv
+import importlib.abc
+import json
+import logging
+import pathlib
+import typing
+
+import torch
+
+from .typing import DatabaseSplit, RawDataLoader
+
+logger = logging.getLogger(__name__)
+
+
+class JSONDatabaseSplit(DatabaseSplit):
+    """Defines a loader that understands a database split (train, test, etc) in
+    JSON format.
+
+    To create a new database split, you need to provide a JSON formatted
+    dictionary in a file, with contents similar to the following:
+
+    .. code-block:: json
+
+       {
+           "subset1": [
+               [
+                   "sample1-data1",
+                   "sample1-data2",
+                   "sample1-data3",
+               ],
+               [
+                   "sample2-data1",
+                   "sample2-data2",
+                   "sample2-data3",
+               ]
+           ],
+           "subset2": [
+               [
+                   "sample42-data1",
+                   "sample42-data2",
+                   "sample42-data3",
+               ],
+           ]
+       }
+
+    Your database split may contain any number of subsets (dictionary keys).
+    For simplicity, we recommend all sample entries are formatted similarly so
+    that raw-data loading is simplified.  Use the function
+    :py:func:`check_database_split_loading` to test raw-data loading and
+    fine-tune the dataset split, or its loading.
+
+    Objects of this class behave like a dictionary in which keys are subset
+    names in the split, and values represent samples data and meta-data.
+
+
+    Parameters
+    ----------
+
+    path
+        Absolute path to a JSON formatted file containing the database split to be
+        recognized by this object.
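+
+    Example (a minimal sketch, assuming a file ``split.json``, formatted as
+    above, exists in the current directory):
+
+    .. code-block:: python
+
+       split = JSONDatabaseSplit("split.json")
+       for subset_name in split:
+           print(subset_name, len(split[subset_name]))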
+    """
+
+    def __init__(self, path: pathlib.Path | str | importlib.abc.Traversable):
+        if isinstance(path, str):
+            path = pathlib.Path(path)
+        self.path = path
+        self.subsets = self._load_split_from_disk()
+
+    def _load_split_from_disk(self) -> DatabaseSplit:
+        """Loads all subsets in a split from its file system representation.
+
+        This method will load the JSON information for the current split and
+        return all of its subsets.
+
+
+        Returns
+        -------
+
+        subsets : dict
+            A dictionary mapping subset names to lists of JSON objects
+        """
+
+        if str(self.path).endswith(".bz2"):
+            logger.debug(f"Loading database split from {str(self.path)}...")
+            with __import__("bz2").open(self.path) as f:
+                return json.load(f)
+        else:
+            with self.path.open() as f:
+                return json.load(f)
+
+    def __getitem__(self, key: str) -> typing.Sequence[typing.Any]:
+        """Accesses subset ``key`` from this split."""
+        return self.subsets[key]
+
+    def __iter__(self):
+        """Iterates over the subsets."""
+        return iter(self.subsets)
+
+    def __len__(self) -> int:
+        """How many subsets we currently have."""
+        return len(self.subsets)
+
+
+class CSVDatabaseSplit(DatabaseSplit):
+    """Defines a loader that understands a database split (train, test, etc) in
+    CSV format.
+
+    To create a new database split, you need to provide one or more CSV
+    formatted files, each representing a subset of this split, containing the
+    sample data (one per row).  Example:
+
+    Inside the directory ``my-split/``, one can find the files ``train.csv``,
+    ``validation.csv``, and ``test.csv``.  Each file has a structure similar to
+    the following:
+
+    .. code-block:: text
+
+       sample1-value1,sample1-value2,sample1-value3
+       sample2-value1,sample2-value2,sample2-value3
+       ...
+
+    Each file in the provided directory defines a subset name in the split.
+    So, the file ``train.csv`` will contain the data from the ``train`` subset,
+    and so on.
+
+    Objects of this class behave like a dictionary in which keys are subset
+    names in the split, and values represent samples data and meta-data.
+
+
+    Parameters
+    ----------
+
+    directory
+        Absolute path to a directory containing the database split laid out
+        as a set of CSV files, one per subset.
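+
+    Example (a minimal sketch, assuming a directory ``my-split/`` laid out
+    as described above):
+
+    .. code-block:: python
+
+       split = CSVDatabaseSplit("my-split")
+       train_samples = split["train"]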
+    """
+
+    def __init__(
+        self, directory: pathlib.Path | str | importlib.abc.Traversable
+    ):
+        if isinstance(directory, str):
+            directory = pathlib.Path(directory)
+        assert (
+            directory.is_dir()
+        ), f"`{str(directory)}` is not a valid directory"
+        self.directory = directory
+        self.subsets = self._load_split_from_disk()
+
+    def _load_split_from_disk(self) -> DatabaseSplit:
+        """Loads all subsets in a split from its file system representation.
+
+        This method will load the CSV information for the current split and
+        return all of its subsets.
+
+
+        Returns
+        -------
+
+        subsets : dict
+            A dictionary mapping subset names to lists of samples (one per
+            CSV row)
+        """
+
+        retval: DatabaseSplit = {}
+        for subset in self.directory.iterdir():
+            if str(subset).endswith(".csv.bz2"):
+                logger.debug(f"Loading database split from {subset}...")
+                with bz2.open(subset, mode="rt") as f:
+                    reader = csv.reader(f)
+                    retval[subset.name[: -len(".csv.bz2")]] = [
+                        k for k in reader
+                    ]
+            elif str(subset).endswith(".csv"):
+                with subset.open() as f:
+                    reader = csv.reader(f)
+                    retval[subset.name[: -len(".csv")]] = [k for k in reader]
+            else:
+                logger.debug(
+                    f"Ignoring file {subset} in CSVDatabaseSplit readout"
+                )
+        return retval
+
+    def __getitem__(self, key: str) -> typing.Sequence[typing.Any]:
+        """Accesses subset ``key`` from this split."""
+        return self.subsets[key]
+
+    def __iter__(self):
+        """Iterates over the subsets."""
+        return iter(self.subsets)
+
+    def __len__(self) -> int:
+        """How many subsets we currently have."""
+        return len(self.subsets)
+
+
+def check_database_split_loading(
+    database_split: DatabaseSplit,
+    loader: RawDataLoader,
+    limit: int = 0,
+) -> int:
+    """For each subset in the split, check if all data can be correctly loaded
+    using the provided loader function.
+
+    This function will return the number of errors loading samples, and will
+    log more detailed information to the logging stream.
+
+
+    Parameters
+    ----------
+
+    database_split
+        A mapping that contains the database split.  Each key represents the
+        name of a subset in the split.  Each value is a sequence of
+        (potentially complex) objects, each representing a single sample.
+
+    loader
+        A loader object that knows how to handle full-samples or just labels.
+
+    limit
+        Maximum number of samples to check (in each split/subset
+        combination) in this dataset.  If set to zero, then check
+        everything.
+
+
+    Returns
+    -------
+
+    errors
+        Number of errors found
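+
+    Example (a minimal sketch, assuming ``split`` and ``loader`` objects
+    built as illustrated in the classes above):
+
+    .. code-block:: python
+
+       errors = check_database_split_loading(split, loader, limit=5)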
+    """
+    logger.info(
+        "Checking if can load all samples in all subsets of this split..."
+    )
+    errors = 0
+    for name in database_split:
+        subset = database_split[name]
+        samples = subset if not limit else subset[:limit]
+        for pos, sample in enumerate(samples):
+            try:
+                data, _ = loader.sample(sample)
+                assert isinstance(data, torch.Tensor)
+            except Exception as e:
+                logger.info(
+                    f"Found error loading entry {pos} in subset `{name}`: {e}"
+                )
+                errors += 1
+    return errors
diff --git a/src/ptbench/data/transforms.py b/src/ptbench/data/transforms.py
index 6d3c17d243a4987f889bae8da0bb6a21ce73459b..ad516e194570dca2b5794959f816a0e717733a9d 100644
--- a/src/ptbench/data/transforms.py
+++ b/src/ptbench/data/transforms.py
@@ -22,39 +22,6 @@ from scipy.ndimage import gaussian_filter, map_coordinates
 from torchvision import transforms
 
 
-class SingleAutoLevel16to8:
-    """Converts a 16-bit image to 8-bit representation using "auto-level".
-
-    This transform assumes that the input image is gray-scaled.
-
-    To auto-level, we calculate the maximum and the minimum of the
-    image, and
-    consider such a range should be mapped to the [0,255] range of the
-    destination image.
-    """
-
-    def __call__(self, img):
-        imin, imax = img.getextrema()
-        irange = imax - imin
-        return PIL.Image.fromarray(
-            numpy.round(
-                255.0 * (numpy.array(img).astype(float) - imin) / irange
-            ).astype("uint8"),
-        ).convert("L")
-
-
-class RemoveBlackBorders:
-    """Remove black borders of CXR."""
-
-    def __init__(self, threshold=0):
-        self.threshold = threshold
-
-    def __call__(self, img):
-        img = numpy.asarray(img)
-        mask = numpy.asarray(img) > self.threshold
-        return PIL.Image.fromarray(img[numpy.ix_(mask.any(1), mask.any(0))])
-
-
 class ElasticDeformation:
     """Elastic deformation of 2D image slightly adapted from [SIMARD-2003]_.
 
@@ -68,7 +35,7 @@ class ElasticDeformation:
         spline_order=1,
         mode="nearest",
         random_state=numpy.random,
-        p=1,
+        p=1.0,
     ):
         self.alpha = alpha
         self.sigma = sigma
@@ -79,13 +46,15 @@ class ElasticDeformation:
 
     def __call__(self, img):
         if random.random() < self.p:
-            img = transforms.ToPILImage()(img)
+            assert img.ndim == 3
 
+            # Input tensor is of shape C x H x W
+            # If the tensor only contains one channel, this conversion results in H x W.
+            # With 3 channels, we get H x W x C
+            img = transforms.ToPILImage()(img)
             img = numpy.asarray(img)
 
-            assert img.ndim == 2
-
-            shape = img.shape
+            shape = img.shape[:2]
 
             dx = (
                 gaussian_filter(
@@ -114,9 +83,22 @@ class ElasticDeformation:
                 numpy.reshape(y + dy, (-1, 1)),
             ]
             result = numpy.empty_like(img)
-            result[:, :] = map_coordinates(
-                img[:, :], indices, order=self.spline_order, mode=self.mode
-            ).reshape(shape)
+
+            if img.ndim == 2:
+                result[:, :] = map_coordinates(
+                    img[:, :], indices, order=self.spline_order, mode=self.mode
+                ).reshape(shape)
+
+            else:
+                for i in range(img.shape[2]):
+                    result[:, :, i] = map_coordinates(
+                        img[:, :, i],
+                        indices,
+                        order=self.spline_order,
+                        mode=self.mode,
+                    ).reshape(shape)
+
             return transforms.ToTensor()(PIL.Image.fromarray(result))
+
         else:
             return img
diff --git a/src/ptbench/data/typing.py b/src/ptbench/data/typing.py
new file mode 100644
index 0000000000000000000000000000000000000000..344c1294df6777f88acdde23dec51d40fc51e31e
--- /dev/null
+++ b/src/ptbench/data/typing.py
@@ -0,0 +1,75 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+"""Defines most common types used in code."""
+
+import typing
+
+import torch
+import torch.utils.data
+
+Sample = tuple[torch.Tensor, typing.Mapping[str, typing.Any]]
+"""Definition of a sample.
+
+First parameter
+    The actual data that is input to the model
+
+Second parameter
+    A dictionary containing a named set of meta-data.  One of the most common
+    is the ``label`` entry.
+"""
+
+
+class RawDataLoader:
+    """A loader object can load samples and labels from storage."""
+
+    def sample(self, _: typing.Any) -> Sample:
+        """Loads whole samples from media."""
+        raise NotImplementedError("You must implement the `sample()` method")
+
+    def label(self, k: typing.Any) -> int:
+        """Loads only sample label from media.
+
+        If you do not override this implementation, then, by default,
+        this method will call :py:meth:`sample` to load the whole sample
+        and extract the label.
+        """
+        return self.sample(k)[1]["label"]
+
+
+Transform = typing.Callable[[torch.Tensor], torch.Tensor]
+"""A callable, that transforms tensors into (other) tensors.
+
+Typically used in data-processing pipelines inside pytorch.
+"""
+
+TransformSequence = typing.Sequence[Transform]
+"""A sequence of transforms."""
+
+DatabaseSplit = dict[str, typing.Sequence[typing.Any]]
+"""The definition of a database script.
+
+A database script maps subset names to sequences of objects that,
+through RawDataLoader's eventually become Samples in the processing
+pipeline.
+"""
+
+
+class Dataset(torch.utils.data.Dataset[Sample], typing.Iterable, typing.Sized):
+    """Our own definition of a pytorch Dataset, with interesting properties.
+
+    We iterate over Sample objects in this case.  Our datasets always
+    provide a dunder len method.
+    """
+
+    def labels(self) -> list[int]:
+        """Returns the integer labels for all samples in the dataset."""
+        raise NotImplementedError("You must implement the `labels()` method")
+
+
+DataLoader = torch.utils.data.DataLoader[Sample]
+"""Our own augmentation definition of a pytorch DataLoader.
+
+We iterate over Sample objects in this case.
+"""
diff --git a/src/ptbench/engine/callbacks.py b/src/ptbench/engine/callbacks.py
index d0ac43f98e21b8ce6803797d6a1fde38c6302660..350140a8516ddef43e89323e65b746ca7c479182 100644
--- a/src/ptbench/engine/callbacks.py
+++ b/src/ptbench/engine/callbacks.py
@@ -1,155 +1,403 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
 import csv
+import logging
 import os
+import pathlib
 import time
+import typing
 
-from collections import defaultdict
+import lightning.pytorch
+import lightning.pytorch.callbacks
+import torch
 
-import numpy
+from ..utils.resources import ResourceMonitor
 
-from lightning.pytorch import Callback
-from lightning.pytorch.callbacks import BasePredictionWriter
+logger = logging.getLogger(__name__)
 
 
-# This ensures CSVLogger logs training and evaluation metrics on the same line
-# CSVLogger only accepts numerical values, not strings
-class LoggingCallback(Callback):
-    """Lightning callback to log various training metrics and device
-    information."""
+class LoggingCallback(lightning.pytorch.Callback):
+    """Callback to log various training metrics and device information.
 
-    def __init__(self, resource_monitor):
-        super().__init__()
-        self.training_loss = []
-        self.validation_loss = []
-        self.extra_validation_loss = defaultdict(list)
-        self.start_training_time = 0
-        self.start_epoch_time = 0
+    It ensures CSVLogger logs training and evaluation metrics on the same
+    line.  Note that CSVLogger only accepts numerical values, not strings.
 
-        self.resource_monitor = resource_monitor
-        self.max_queue_retries = 2
 
-    def on_train_start(self, trainer, pl_module):
-        self.start_training_time = time.time()
+    Parameters
+    ----------
 
-    def on_train_epoch_start(self, trainer, pl_module):
-        self.start_epoch_time = time.time()
+    resource_monitor
+        A monitor that watches resource usage (CPU/GPU) in a separate process,
+        asynchronously with respect to the main training code.
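+
+    Example (a minimal sketch, assuming a ``ResourceMonitor`` instance named
+    ``monitor`` already exists):
+
+    .. code-block:: python
+
+       trainer = lightning.pytorch.Trainer(
+           callbacks=[LoggingCallback(monitor)],
+       )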
+    """
 
-    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
-        self.training_loss.append(outputs["loss"].item())
+    def __init__(self, resource_monitor: ResourceMonitor):
+        super().__init__()
 
-    def on_validation_batch_end(
-        self, trainer, pl_module, outputs, batch, batch_idx
+        # lists of number of samples/batch and average losses
+        # - we use this later to compute overall epoch losses
+        self._training_epoch_loss: tuple[list[int], list[float]] = ([], [])
+        self._validation_epoch_loss: dict[
+            int, tuple[list[int], list[float]]
+        ] = {}
+
+        # timers
+        self._start_training_time = 0.0
+        self._start_training_epoch_time = 0.0
+        self._start_validation_epoch_time = 0.0
+
+        # log accumulators for a single flush at each training cycle
+        self._to_log: dict[str, float] = {}
+
+        # helpers for CPU and GPU utilisation
+        self._resource_monitor = resource_monitor
+        self._max_queue_retries = 2
+
+    def on_train_start(
+        self,
+        trainer: lightning.pytorch.Trainer,
+        pl_module: lightning.pytorch.LightningModule,
     ):
-        self.validation_loss.append(outputs["validation_loss"].item())
+        """Callback to be executed **before** the whole training starts.
 
-        if len(outputs) > 1:
-            extra_validation_keys = outputs.keys().remove("validation_loss")
-            for extra_validation_loss_key in extra_validation_keys:
-                self.extra_validation_loss[extra_validation_loss_key].append(
-                    outputs[extra_validation_loss_key]
-                )
+        This method is executed whenever you *start* training a module.
 
-    def on_validation_epoch_end(self, trainer, pl_module):
-        self.resource_monitor.trigger_summary()
 
-        self.epoch_time = time.time() - self.start_epoch_time
-        eta_seconds = self.epoch_time * (
-            trainer.max_epochs - trainer.current_epoch
-        )
-        current_time = time.time() - self.start_training_time
+        Parameters
+        ----------
 
-        def _compute_batch_loss(losses, num_chunks):
-            # When accumulating gradients, partial losses need to be summed per batch before averaging
-            if num_chunks != 1:
-                # The loss we get is scaled by the number of accumulation steps
-                losses = numpy.multiply(losses, num_chunks)
+        trainer
+            The Lightning trainer object
 
-                if len(losses) % num_chunks > 0:
-                    num_splits = (len(losses) // num_chunks) + 1
-                else:
-                    num_splits = len(losses) // num_chunks
+        pl_module
+            The lightning module that is being trained
+        """
+        self._start_training_time = time.time()
 
-                batched_losses = numpy.array_split(losses, num_splits)
+    def on_train_epoch_start(
+        self,
+        trainer: lightning.pytorch.Trainer,
+        pl_module: lightning.pytorch.LightningModule,
+    ) -> None:
+        """Callback to be executed **before** every training batch starts.
 
-                summed_batch_losses = []
+        This method is executed whenever a training epoch starts.  Presumably,
+        epochs happen as often as possible.  You want to make this code
+        relatively fast to avoid significant runtime slow-downs.
 
-                for b in batched_losses:
-                    summed_batch_losses.append(numpy.average(b))
+        .. warning::
 
-                return summed_batch_losses
+           This is executed **while** you are training.  Be very succinct or
+           face the consequences of slow training!
 
-            # No gradient accumulation, we already have the batch losses
-            else:
-                return losses
 
-        # Do not log during sanity check as results are not relevant
-        if not trainer.sanity_checking:
-            # We get partial loses when using gradient accumulation
-            self.training_loss = _compute_batch_loss(
-                self.training_loss, trainer.accumulate_grad_batches
-            )
-            self.validation_loss = _compute_batch_loss(
-                self.validation_loss, trainer.accumulate_grad_batches
-            )
+        Parameters
+        ----------
 
-            self.log("total_time", current_time)
-            self.log("eta", eta_seconds)
-            self.log("loss", numpy.average(self.training_loss))
-            self.log(
-                "learning_rate", pl_module.hparams["optimizer_configs"]["lr"]
-            )
-            self.log("validation_loss", numpy.sum(self.validation_loss))
-
-            if len(self.extra_validation_loss) > 0:
-                for (
-                    extra_valid_loss_key,
-                    extra_valid_loss_values,
-                ) in self.extra_validation_loss.items:
-                    self.log(
-                        extra_valid_loss_key, numpy.sum(extra_valid_loss_values)
-                    )
+        trainer
+            The Lightning trainer object
 
-        queue_retries = 0
-        # In case the resource monitor takes longer to fetch data from the queue, we wait
-        # Give up after self.resource_monitor.interval * self.max_queue_retries if cannot retrieve metrics from queue
-        while (
-            self.resource_monitor.data is None
-            and queue_retries < self.max_queue_retries
-        ):
-            queue_retries = queue_retries + 1
-            print(
-                f"Monitor queue is empty, retrying in {self.resource_monitor.interval}s"
+        pl_module
+            The lightning module that is being trained
+        """
+        self._start_training_epoch_time = time.time()
+        self._training_epoch_loss = ([], [])
+
+    def on_train_epoch_end(
+        self,
+        trainer: lightning.pytorch.Trainer,
+        pl_module: lightning.pytorch.LightningModule,
+    ):
+        """Callback to be executed **after** every training epoch ends.
+
+        This method is executed whenever a training epoch ends.  Presumably,
+        epochs happen as often as possible.  You want to make this code
+        relatively fast to avoid significant runtime slow-downs.
+
+
+        Parameters
+        ----------
+
+        trainer
+            The Lightning trainer object
+
+        pl_module
+            The lightning module that is being trained
+        """
+
+        # summarizes resource usage since the last checkpoint
+        # clears internal buffers and starts accumulating again.
+        self._resource_monitor.checkpoint()
+
+        # evaluates this training epoch total time, and log it
+        epoch_time = time.time() - self._start_training_epoch_time
+
+        # Compute overall training loss considering batches and sizes.  Each
+        # recorded loss is already an average over one batch, so we weight it
+        # by the batch size to get the correct epoch average, even if batch
+        # sizes differ (accumulate_grad_batches is disregarded).
+        sizes = torch.tensor(self._training_epoch_loss[0])
+        losses = torch.tensor(self._training_epoch_loss[1])
+        self._to_log["train_loss"] = (
+            (sizes * losses).sum() / sizes.sum()
+        ).item()
+
+        self._to_log["train_epoch_time"] = epoch_time
+        self._to_log["learning_rate"] = pl_module.optimizers().defaults["lr"]
+
+        metrics = self._resource_monitor.data
+        if metrics is not None:
+            for metric_name, metric_value in metrics.items():
+                self._to_log[f"train_{metric_name}"] = float(metric_value)
+        else:
+            logger.warning(
+                "Unable to fetch monitoring information from "
+                "resource monitor. CPU/GPU utilisation will be "
+                "missing."
             )
-            time.sleep(self.resource_monitor.interval)
 
-        if queue_retries >= self.max_queue_retries:
-            print(
-                f"Unable to fetch monitoring information from queue after {queue_retries} retries"
+        # if there are no validation dataloaders, the cycle completes at the
+        # end of the training epoch; log all accumulated values to the logger
+        self.on_cycle_end(trainer, pl_module)
+
+    def on_train_batch_end(
+        self,
+        trainer: lightning.pytorch.Trainer,
+        pl_module: lightning.pytorch.LightningModule,
+        outputs: typing.Mapping[str, torch.Tensor],
+        batch: tuple[torch.Tensor, typing.Mapping[str, torch.Tensor]],
+        batch_idx: int,
+    ) -> None:
+        """Callback to be executed **after** every training batch ends.
+
+        This method is executed whenever a training batch ends.  Presumably,
+        batches happen as often as possible.  You want to make this code very
+        fast.  Do not log things to the terminal or the like, or do complicated
+        (lengthy) calculations.
+
+        .. warning::
+
+           This is executed **while** you are training.  Be very succinct or
+           face the consequences of slow training!
+
+
+        Parameters
+        ----------
+
+        trainer
+            The Lightning trainer object
+
+        pl_module
+            The lightning module that is being trained
+
+        outputs
+            The outputs of the module's ``training_step``
+
+        batch
+            The data that the training step received
+
+        batch_idx
+            The relative number of the batch
+        """
+        self._training_epoch_loss[0].append(batch[0].shape[0])
+        self._training_epoch_loss[1].append(outputs["loss"].item())
+
+    def on_validation_epoch_start(
+        self,
+        trainer: lightning.pytorch.Trainer,
+        pl_module: lightning.pytorch.LightningModule,
+    ) -> None:
+        """Callback to be executed **before** every validation batch starts.
+
+        This method is executed whenever a validation batch starts.  Presumably,
+        batches happen as often as possible.  You want to make this code very
+        fast.  Do not log things to the terminal or the such, or do complicated
+        (lengthy) calculations.
+
+        .. warning::
+
+           This is executed **while** you are training.  Be very succinct or
+           face the consequences of slow training!
+
+
+        Parameters
+        ----------
+
+        trainer
+            The Lightning trainer object
+
+        pl_module
+            The lightning module that is being trained
+        """
+        self._start_validation_epoch_time = time.time()
+        self._validation_epoch_loss = {}
+
+    def on_validation_epoch_end(
+        self,
+        trainer: lightning.pytorch.Trainer,
+        pl_module: lightning.pytorch.LightningModule,
+    ) -> None:
+        """Callback to be executed **after** every validation epoch ends.
+
+        This method is executed whenever a validation epoch ends.  Presumably,
+        epochs happen as often as possible.  You want to make this code
+        relatively fast to avoid significant runtime slow-downs.
+
+
+        Parameters
+        ----------
+
+        trainer
+            The Lightning trainer object
+
+        pl_module
+            The lightning module that is being trained
+        """
+
+        # summarizes resource usage since the last checkpoint
+        # clears internal buffers and starts accumulating again.
+        self._resource_monitor.checkpoint()
+
+        epoch_time = time.time() - self._start_validation_epoch_time
+        self._to_log["validation_epoch_time"] = epoch_time
+
+        metrics = self._resource_monitor.data
+        if metrics is not None:
+            for metric_name, metric_value in metrics.items():
+                self._to_log[f"validation_{metric_name}"] = float(metric_value)
+        else:
+            logger.warning(
+                "Unable to fetch monitoring information from "
+                "resource monitor. CPU/GPU utilisation will be "
+                "missing."
             )
 
-        assert self.resource_monitor.q.empty()
+        # Compute overall validation losses considering batches and sizes.
+        # Each recorded loss is already an average over one batch, so we
+        # weight it by the batch size to get the correct epoch average
+        # (accumulate_grad_batches is disregarded).
+        for key in sorted(self._validation_epoch_loss.keys()):
+            if key == 0:
+                name = "validation_loss"
+            else:
+                name = f"validation_loss_{key}"
 
-        # Do not log during sanity check as results are not relevant
-        if not trainer.sanity_checking:
-            for metric_name, metric_value in self.resource_monitor.data:
-                self.log(metric_name, float(metric_value))
+            sizes = torch.tensor(self._validation_epoch_loss[key][0])
+            losses = torch.tensor(self._validation_epoch_loss[key][1])
+            self._to_log[name] = (
+                (sizes * losses).sum() / sizes.sum()
+            ).item()
+
+    def on_validation_batch_end(
+        self,
+        trainer: lightning.pytorch.Trainer,
+        pl_module: lightning.pytorch.LightningModule,
+        outputs: torch.Tensor,
+        batch: tuple[torch.Tensor, typing.Mapping[str, torch.Tensor]],
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> None:
+        """Callback to be executed **after** every validation batch ends.
+
+        This method is executed whenever a validation batch ends.  Presumably,
+        batches happen as often as possible.  You want to make this code very
+        fast.  Do not log things to the terminal or the like, or do complicated
+        (lengthy) calculations.
 
-        self.resource_monitor.data = None
+        .. warning::
 
-        self.training_loss = []
-        self.validation_loss = []
+           This is executed **while** you are training.  Be very succinct or
+           face the consequences of slow training!
 
 
-class PredictionsWriter(BasePredictionWriter):
+        Parameters
+        ----------
+
+        trainer
+            The Lightning trainer object
+
+        pl_module
+            The lightning module that is being trained
+
+        outputs
+            The outputs of the module's ``validation_step``
+
+        batch
+            The data that the validation step received
+
+        batch_idx
+            The relative number of the batch
+
+        dataloader_idx
+            Index of the dataloader used during validation.  Use this to figure
+            out which dataset was used for this validation epoch.
+        """
+        size, value = self._validation_epoch_loss.setdefault(
+            dataloader_idx, ([], [])
+        )
+        size.append(batch[0].shape[0])
+        value.append(outputs.item())
+
+    def on_cycle_end(
+        self,
+        trainer: lightning.pytorch.Trainer,
+        pl_module: lightning.pytorch.LightningModule,
+    ) -> None:
+        """Called when the training/validation cycle has ended.
+
+        This function will log all relevant values to the various loggers.  It
+        is supposed to be called at the end of the training cycle (consisting
+        of a training and a validation step).
+
+
+        Parameters
+        ----------
+
+        trainer
+            The Lightning trainer object
+
+        pl_module
+            The lightning module that is being trained
+        """
+
+        # collect some final time for the whole training cycle
+        # Note: logging should happen at on_validation_end(), but
+        # apparently you can't log from there
+        overall_cycle_time = time.time() - self._start_training_epoch_time
+        self._to_log["train_cycle_time"] = overall_cycle_time
+        self._to_log["total_time"] = time.time() - self._start_training_time
+        self._to_log["eta"] = overall_cycle_time * (
+            trainer.max_epochs - trainer.current_epoch  # type: ignore
+        )
+
+        # Do not log during sanity check as results are not relevant
+        if not trainer.sanity_checking:
+            for k in sorted(self._to_log.keys()):
+                pl_module.log(k, self._to_log[k])
+            self._to_log = {}
+
+
+class PredictionsWriter(lightning.pytorch.callbacks.BasePredictionWriter):
     """Lightning callback to write predictions to a file."""
 
-    def __init__(self, output_dir, logfile_fields, write_interval):
+    def __init__(
+        self,
+        output_dir: str | pathlib.Path,
+        logfile_fields: typing.Sequence[str],
+        write_interval: typing.Literal["batch", "epoch", "batch_and_epoch"],
+    ):
         super().__init__(write_interval)
         self.output_dir = output_dir
         self.logfile_fields = logfile_fields
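+
+    # A hedged usage sketch (the field names below are illustrative, not
+    # mandated by this class):
+    #
+    #   writer = PredictionsWriter(
+    #       output_dir="predictions",
+    #       logfile_fields=("filename", "likelihood", "ground_truth"),
+    #       write_interval="epoch",
+    #   )
+    #   trainer = lightning.pytorch.Trainer(callbacks=[writer])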
 
     def write_on_epoch_end(
-        self, trainer, pl_module, predictions, batch_indices
-    ):
+        self,
+        trainer: lightning.pytorch.Trainer,
+        pl_module: lightning.pytorch.LightningModule,
+        predictions: typing.Sequence[typing.Any],
+        batch_indices: typing.Sequence[typing.Any] | None,
+    ) -> None:
         for dataloader_idx, dataloader_results in enumerate(predictions):
             dataloader_name = list(
                 trainer.datamodule.predict_dataloader().keys()
diff --git a/src/ptbench/engine/device.py b/src/ptbench/engine/device.py
new file mode 100644
index 0000000000000000000000000000000000000000..253bba0d9da3bedca6010b2ad87937b9da4f08e0
--- /dev/null
+++ b/src/ptbench/engine/device.py
@@ -0,0 +1,150 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import logging
+import os
+
+import torch
+import torch.backends
+
+logger = logging.getLogger(__name__)
+
+
+def _split_int_list(s: str) -> list[int]:
+    """Splits a list of integers encoded in a string (e.g. "1,2,3") into a
+    Python list of integers (e.g. ``[1, 2, 3]``)."""
+    return [int(k.strip()) for k in s.split(",")]
+
+
+class DeviceManager:
+    """This class is used to manage the Lightning Accelerator and Pytorch
+    Devices.
+
+    It takes the user input, in the form of a string defined by
+    ``[\\S+][:\\d[,\\d]?]?`` (e.g.: ``cpu``, ``mps``, or ``cuda:3``), and can
+    translate to the right incarnation of Pytorch devices or Lightning
+    Accelerators to interface with the various frameworks.
+
+    Instances of this class also manage the environment variable
+    ``$CUDA_VISIBLE_DEVICES`` if necessary.
+
+
+    Parameters
+    ----------
+
+    name
+        The name of the device to use, in the form of a string defined by
+        ``[\\S+][:\\d[,\\d]?]?`` (e.g.: ``cpu``, ``mps``, or ``cuda:3``).  In
+        the specific case of ``cuda``, one can also specify a device to use
+        either by adding ``:N``, where N is the zero-indexed board number on
+        the computer, or by setting the environment variable
+        ``$CUDA_VISIBLE_DEVICES`` with the devices that are usable by the
+        current process.
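+
+    Example
+    -------
+
+    A minimal usage sketch, assuming a machine with at least two CUDA boards:
+
+    .. code:: python
+
+       manager = DeviceManager("cuda:1")
+       manager.torch_device()           # -> torch.device("cuda", 1)
+       manager.lightning_accelerator()  # -> ("cuda", [1])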
+    """
+
+    SUPPORTED = ("cpu", "cuda", "mps")
+
+    def __init__(self, name: str):
+        parts = name.split(":", 1)
+        self.device_type = parts[0]
+        self.device_ids: list[int] = []
+        if len(parts) > 1:
+            self.device_ids = _split_int_list(parts[1])
+
+        if self.device_type == "cuda":
+            visible_env = os.environ.get("CUDA_VISIBLE_DEVICES")
+            if visible_env:
+                visible = _split_int_list(visible_env)
+                if self.device_ids and visible != self.device_ids:
+                    logger.warning(
+                        f"${{CUDA_VISIBLE_DEVICES}}={visible} and name={name} "
+                        f"- overriding environment with value set on `name`"
+                    )
+                else:
+                    self.device_ids = visible
+
+            # make sure that it is consistent with the environment
+            if self.device_ids:
+                os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
+                    [str(k) for k in self.device_ids]
+                )
+
+        if self.device_type not in DeviceManager.SUPPORTED:
+            raise RuntimeError(
+                f"Unsupported device type `{self.device_type}`. "
+                f"Supported devices types are `{', '.join(DeviceManager.SUPPORTED)}`"
+            )
+
+        if self.device_ids and self.device_type in ("cpu", "mps"):
+            logger.warning(
+                f"Cannot pin device ids if using cpu or mps backend. "
+                f"Setting `name` to {name} is non-sensical.  Ignoring..."
+            )
+
+        # check if the device_type that was set has support compiled in
+        if self.device_type == "cuda":
+            assert hasattr(torch, "cuda") and torch.cuda.is_available(), (
+                f"User asked for device = `{name}`, but CUDA support is "
+                f"not compiled into pytorch!"
+            )
+
+        if self.device_type == "mps":
+            assert (
+                hasattr(torch.backends, "mps")
+                and torch.backends.mps.is_available()  # type:ignore
+            ), (
+                f"User asked for device = `{name}`, but MPS support is "
+                f"not compiled into pytorch!"
+            )
+
+    def torch_device(self) -> torch.device:
+        """Returns a representation of the torch device to use by default.
+
+        .. warning::
+
+           If a list of devices is set, then this method only returns the first
+           device.  This may impact Nvidia GPU logging when multiple GPU
+           cards are used.
+
+
+        Returns
+        -------
+
+        device
+            The **first** torch device (if a list of ids is set).
+        """
+
+        if self.device_type in ("cpu", "mps"):
+            return torch.device(self.device_type)
+        elif self.device_type == "cuda":
+            if not self.device_ids:
+                return torch.device(self.device_type)
+            else:
+                return torch.device(self.device_type, self.device_ids[0])
+
+        # if you get to this point, this is an unexpected RuntimeError
+        raise RuntimeError(
+            f"Device type `{self.device_type}` is not supported"
+        )
+
+    def lightning_accelerator(self) -> tuple[str, int | list[int] | str | None]:
+        """Returns the lightning accelerator setup.
+
+        Returns
+        -------
+
+        accelerator
+            The lightning accelerator to use
+
+        devices
+            The lightning devices to use
+        """
+
+        devices: int | list[int] | str = self.device_ids
+        if not devices:
+            devices = "auto"
+        elif self.device_type == "mps":
+            devices = 1
+
+        return self.device_type, devices
diff --git a/src/ptbench/engine/trainer.py b/src/ptbench/engine/trainer.py
index 2c8bdc55f161ce567ddd7cb6641ac728ce3b7dbd..10121af1091a0e3a186efd4c338f4af0ad4b9cf5 100644
--- a/src/ptbench/engine/trainer.py
+++ b/src/ptbench/engine/trainer.py
@@ -7,70 +7,73 @@ import logging
 import os
 import shutil
 
-from lightning.pytorch import Trainer
-from lightning.pytorch.callbacks import ModelCheckpoint
-from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
-from lightning.pytorch.utilities.model_summary import ModelSummary
+import lightning.pytorch
+import lightning.pytorch.callbacks
+import lightning.pytorch.loggers
+import torch.nn
 
-from ..utils.accelerator import AcceleratorProcessor
 from ..utils.resources import ResourceMonitor, cpu_constants, gpu_constants
 from .callbacks import LoggingCallback
 
 logger = logging.getLogger(__name__)
 
 
-def check_gpu(device):
-    """Check the device type and the availability of GPU.
-
-    Parameters
-    ----------
-
-    device : :py:class:`torch.device`
-        device to use
-    """
-    if device == "cuda":
-        # asserts we do have a GPU
-        assert bool(
-            gpu_constants()
-        ), f"Device set to '{device}', but nvidia-smi is not installed"
-
-
-def save_model_summary(output_folder, model):
+def save_model_summary(
+    output_folder: str, model: torch.nn.Module
+) -> tuple[lightning.pytorch.callbacks.ModelSummary, int]:
     """Save a little summary of the model in a txt file.
 
     Parameters
     ----------
 
-    output_folder : str
+    output_folder
         output path
 
-    model : :py:class:`torch.nn.Module`
+    model
         Network (e.g. driu, hed, unet)
 
     Returns
     -------
-    r : str
+    summary
         The model summary in a text format.
 
-    n : int
+    total_parameters
         The number of parameters of the model.
     """
     summary_path = os.path.join(output_folder, "model_summary.txt")
     logger.info(f"Saving model summary at {summary_path}...")
     with open(summary_path, "w") as f:
-        summary = ModelSummary(model, max_depth=-1)
+        summary = lightning.pytorch.utilities.model_summary.ModelSummary(
+            model, max_depth=-1
+        )
         f.write(str(summary))
-    return summary, ModelSummary(model).total_parameters
+    return (
+        summary,
+        lightning.pytorch.utilities.model_summary.ModelSummary(
+            model
+        ).total_parameters,
+    )
 
 
-def static_information_to_csv(static_logfile_name, device, n):
-    """Save the static information in a csv file.
+def static_information_to_csv(
+    static_logfile_name: str,
+    device_type: str,
+    model_size: int,
+) -> None:
+    """Saves the static information in a CSV file.
 
     Parameters
     ----------
 
-    static_logfile_name : str
-        The static file name which is a join between the output folder and "constant.csv"
+    static_logfile_name
+        The static file name which is a join between the output folder and
+        "constants.csv"
+
+    device_type
+        The type of device we are using
+
+    model_size
+        The number of parameters in the model that will be trained
     """
     if os.path.exists(static_logfile_name):
         backup = static_logfile_name + "~"
@@ -78,80 +81,23 @@ def static_information_to_csv(static_logfile_name, device, n):
             os.unlink(backup)
         shutil.move(static_logfile_name, backup)
     with open(static_logfile_name, "w", newline="") as f:
-        logdata = cpu_constants()
-        if device == "cuda":
-            logdata += gpu_constants()
-        logdata += (("model_size", n),)
-        logwriter = csv.DictWriter(f, fieldnames=[k[0] for k in logdata])
+        logdata: dict[str, int | float | str] = {}
+        logdata.update(cpu_constants())
+        if device_type == "cuda":
+            results = gpu_constants()
+            if results is not None:
+                logdata.update(results)
+        logdata["model_size"] = model_size
+        logwriter = csv.DictWriter(f, fieldnames=logdata.keys())
         logwriter.writeheader()
-        logwriter.writerow(dict(k for k in logdata))
-
-
-def check_exist_logfile(logfile_name, arguments):
-    """Check existance of logfile (trainlog.csv), If the logfile exist the and
-    the epochs number are still 0, The logfile will be replaced.
-
-    Parameters
-    ----------
-
-    logfile_name : str
-        The logfile_name which is a join between the output_folder and trainlog.csv
-
-    arguments : dict
-        start and end epochs
-    """
-    if arguments["epoch"] == 0 and os.path.exists(logfile_name):
-        backup = logfile_name + "~"
-        if os.path.exists(backup):
-            os.unlink(backup)
-        shutil.move(logfile_name, backup)
-
-
-def create_logfile_fields(valid_loader, extra_valid_loaders, device):
-    """Creation of the logfile fields that will appear in the logfile.
-
-    Parameters
-    ----------
-
-    valid_loader : :py:class:`torch.utils.data.DataLoader`
-        To be used to validate the model and enable automatic checkpointing.
-        If set to ``None``, then do not validate it.
-
-    extra_valid_loaders : :py:class:`list` of :py:class:`torch.utils.data.DataLoader`
-        To be used to validate the model, however **does not affect** automatic
-        checkpointing. If set to ``None``, or empty, then does not log anything
-        else.  Otherwise, an extra column with the loss of every dataset in
-        this list is kept on the final training log.
-
-    device : :py:class:`torch.device`
-        device to use
-
-    Returns
-    -------
-
-    logfile_fields: tuple
-        The fields that will appear in trainlog.csv
-    """
-    logfile_fields = (
-        "epoch",
-        "total_time",
-        "eta",
-        "loss",
-        "learning_rate",
-    )
-    if valid_loader is not None:
-        logfile_fields += ("validation_loss",)
-    if extra_valid_loaders:
-        logfile_fields += ("extra_validation_losses",)
-    logfile_fields += tuple(ResourceMonitor.monitored_keys(device == "cuda"))
-    return logfile_fields
+        logwriter.writerow(logdata)
 
 
 def run(
     model,
     datamodule,
     checkpoint_period,
-    accelerator,
+    device_manager,
     arguments,
     output_folder,
     monitoring_interval,
@@ -187,8 +133,8 @@ def run(
         Save a checkpoint every ``n`` epochs.  If set to ``0`` (zero), then do
         not save intermediary checkpoints.
 
-    accelerator : str
-        A string indicating the accelerator to use (e.g. "cpu" or "gpu"). The device can also be specified (gpu:0)
+    device_manager : DeviceManager
+        A prepared DeviceManager, used to select the Lightning accelerator
+        and devices employed for training.
 
     arguments : dict
         Start and end epochs:
@@ -210,30 +156,30 @@ def run(
 
     max_epoch = arguments["max_epoch"]
 
-    accelerator_processor = AcceleratorProcessor(accelerator)
-
     os.makedirs(output_folder, exist_ok=True)
 
     # Save model summary
-    r, n = save_model_summary(output_folder, model)
+    _, no_of_parameters = save_model_summary(output_folder, model)
 
-    csv_logger = CSVLogger(output_folder, "logs_csv")
-    tensorboard_logger = TensorBoardLogger(output_folder, "logs_tensorboard")
+    csv_logger = lightning.pytorch.loggers.CSVLogger(output_folder, "logs_csv")
+    tensorboard_logger = lightning.pytorch.loggers.TensorBoardLogger(
+        output_folder, "logs_tensorboard"
+    )
 
     resource_monitor = ResourceMonitor(
         interval=monitoring_interval,
-        has_gpu=(accelerator_processor.accelerator == "gpu"),
+        has_gpu=device_manager.device_type == "cuda",
         main_pid=os.getpid(),
         logging_level=logging.ERROR,
     )
 
-    checkpoint_callback = ModelCheckpoint(
+    checkpoint_callback = lightning.pytorch.callbacks.ModelCheckpoint(
         output_folder,
         "model_lowest_valid_loss",
         save_last=True,
         monitor="validation_loss",
         mode="min",
-        save_on_train_epoch_end=False,
+        save_on_train_epoch_end=True,
         every_n_epochs=checkpoint_period,
     )
 
@@ -242,17 +188,15 @@ def run(
     # write static information to a CSV file
     static_logfile_name = os.path.join(output_folder, "constants.csv")
     static_information_to_csv(
-        static_logfile_name, accelerator_processor.to_torch(), n
+        static_logfile_name,
+        device_manager.device_type,
+        no_of_parameters,
     )
 
-    if accelerator_processor.device is None:
-        devices = "auto"
-    else:
-        devices = accelerator_processor.device
-
     with resource_monitor:
-        trainer = Trainer(
-            accelerator=accelerator_processor.accelerator,
+        accelerator, devices = device_manager.lightning_accelerator()
+        trainer = lightning.pytorch.Trainer(
+            accelerator=accelerator,
             devices=devices,
             max_epochs=max_epoch,
             accumulate_grad_batches=batch_chunk_count,
diff --git a/src/ptbench/models/alexnet.py b/src/ptbench/models/alexnet.py
index ba9bf05f7428d759489bd744f8ec35c3b43bab02..a878a076037925879c072bffb87d23a3e1ce7b0d 100644
--- a/src/ptbench/models/alexnet.py
+++ b/src/ptbench/models/alexnet.py
@@ -2,55 +2,185 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
+import logging
+import typing
+
 import lightning.pytorch as pl
 import torch
-import torch.nn as nn
+import torch.nn
+import torch.nn.functional as F
+import torch.optim.optimizer
+import torch.utils.data
 import torchvision.models as models
+import torchvision.transforms
+
+from ..data.typing import DataLoader, TransformSequence
 
-from .normalizer import TorchVisionNormalizer
+logger = logging.getLogger(__name__)
 
 
 class Alexnet(pl.LightningModule):
     """Alexnet module.
 
     Note: only usable with a normalized dataset
+
+    Parameters
+    ----------
+
+    train_loss
+        The loss to be used during the training.
+
+        .. warning::
+
+           The loss should be set to always return batch averages (as opposed
+           to the batch sum), as our logging system expects it so.
+
+    validation_loss
+        The loss to be used for validation (may be different from the training
+        loss).  If extra-validation sets are provided, the same loss will be
+        used throughout.
+
+        .. warning::
+
+           The loss should be set to always return batch averages (as opposed
+           to the batch sum), as our logging system expects it so.
+
+    optimizer_type
+        The type of optimizer to use for training
+
+    optimizer_arguments
+        Arguments to the optimizer after ``params``.
+
+    augmentation_transforms
+        An optional sequence of torch modules containing transforms to be
+        applied on the input **before** it is fed into the network.
+
+    pretrained
+        If set to True, loads pretrained model weights during initialization,
+        else trains a new model.
     """
 
     def __init__(
         self,
-        criterion,
-        criterion_valid,
-        optimizer,
-        optimizer_configs,
-        pretrained=False,
+        train_loss: torch.nn.Module,
+        validation_loss: torch.nn.Module | None,
+        optimizer_type: type[torch.optim.Optimizer],
+        optimizer_arguments: dict[str, typing.Any],
+        augmentation_transforms: TransformSequence = [],
+        pretrained: bool = False,
     ):
         super().__init__()
 
-        self.save_hyperparameters(ignore=["criterion", "criterion_valid"])
+        self.name = "alexnet"
 
-        self.name = "AlexNet"
+        self._train_loss = train_loss
+        self._validation_loss = (
+            validation_loss if validation_loss is not None else train_loss
+        )
+        self._optimizer_type = optimizer_type
+        self._optimizer_arguments = optimizer_arguments
 
-        # Load pretrained model
-        weights = (
-            None if pretrained is False else models.AlexNet_Weights.DEFAULT
+        self._augmentation_transforms = torchvision.transforms.Compose(
+            augmentation_transforms
         )
-        self.model_ft = models.alexnet(weights=weights)
 
-        self.normalizer = TorchVisionNormalizer(nb_channels=1)
+        self.pretrained = pretrained
+
+        # Load pretrained model
+        if not pretrained:
+            weights = None
+        else:
+            logger.info(f"Loading pretrained {self.name} model weights")
+            weights = models.AlexNet_Weights.DEFAULT
+
+        self.model_ft = models.alexnet(weights=weights)
 
         # Adapt output features
-        self.model_ft.classifier[4] = nn.Linear(4096, 512)
-        self.model_ft.classifier[6] = nn.Linear(512, 1)
+        self.model_ft.classifier[4] = torch.nn.Linear(4096, 512)
+        self.model_ft.classifier[6] = torch.nn.Linear(512, 1)
 
     def forward(self, x):
-        x = self.normalizer(x)
+        x = self.normalizer(x)  # type: ignore
+
         x = self.model_ft(x)
 
         return x
 
-    def training_step(self, batch, batch_idx):
-        images = batch[1]
-        labels = batch[2]
+    def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
+        """Initializes the normalizer for the current model.
+
+        If ``pretrained = True``, the provided dataloader is ignored and the
+        normalizer is set to the preset ImageNet factors instead.
+
+        Parameters
+        ----------
+
+        dataloader: :py:class:`torch.utils.data.DataLoader`
+            A torch Dataloader from which to compute the mean and std.
+            Will not be used if the model is pretrained.
+        """
+        if self.pretrained:
+            from .normalizer import make_imagenet_normalizer
+
+            logger.warning(
+                f"ImageNet pre-trained {self.name} model - NOT "
+                f"computing z-norm factors from train dataloader. "
+                f"Using preset factors from torchvision."
+            )
+            self.normalizer = make_imagenet_normalizer()
+        else:
+            from .normalizer import make_z_normalizer
+
+            logger.info(
+                f"Uninitialised {self.name} model - "
+                f"computing z-norm factors from train dataloader."
+            )
+            self.normalizer = make_z_normalizer(dataloader)
+
+    def balance_losses_by_class(
+        self, train_dataloader: DataLoader, valid_dataloader: DataLoader
+    ):
+        """Reweights loss weights if possible.
+
+        Parameters
+        ----------
+
+        train_dataloader
+            The data loader to use for training
+
+        valid_dataloader
+            The data loader to use for validation
+
+
+        Raises
+        ------
+
+        RuntimeError
+            If train or validation losses are not of type
+            :py:class:`torch.nn.BCEWithLogitsLoss`.
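+
+        Example
+        -------
+
+        A hedged usage sketch, assuming ``datamodule`` exposes the usual
+        Lightning loaders:
+
+        .. code:: python
+
+           model.balance_losses_by_class(
+               datamodule.train_dataloader(),
+               datamodule.val_dataloader(),
+           )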
+        """
+        from .loss_weights import get_label_weights
+
+        if isinstance(self._train_loss, torch.nn.BCEWithLogitsLoss):
+            logger.info("Reweighting BCEWithLogitsLoss training loss.")
+            weights = get_label_weights(train_dataloader)
+            self._train_loss = torch.nn.BCEWithLogitsLoss(pos_weight=weights)
+        else:
+            raise RuntimeError(
+                "Training loss is not BCEWithLogitsLoss - dunno how to balance"
+            )
+
+        if isinstance(self._validation_loss, torch.nn.BCEWithLogitsLoss):
+            logger.info("Reweighting BCEWithLogitsLoss validation loss.")
+            weights = get_label_weights(valid_dataloader)
+            self._validation_loss = torch.nn.BCEWithLogitsLoss(
+                pos_weight=weights
+            )
+        else:
+            raise RuntimeError(
+                "Validation loss is not BCEWithLogitsLoss - cannot balance it"
+            )
+
+    def training_step(self, batch, _):
+        images = batch[0]
+        labels = batch[1]["label"]
 
         # Increase label dimension if too low
         # Allows single and multiclass usage
@@ -58,17 +188,18 @@ class Alexnet(pl.LightningModule):
             labels = torch.reshape(labels, (labels.shape[0], 1))
 
         # Forward pass on the network
-        outputs = self(images)
-
-        # Manually move criterion to selected device, since not part of the model.
-        self.hparams.criterion = self.hparams.criterion.to(self.device)
-        training_loss = self.hparams.criterion(outputs, labels.float())
+        augmented_images = [
+            self._augmentation_transforms(img).to(self.device) for img in images
+        ]
+        # Combine list of augmented images back into a tensor
+        augmented_images = torch.cat(augmented_images, 0).view(images.shape)
+        outputs = self(augmented_images)
 
-        return {"loss": training_loss}
+        return self._train_loss(outputs, labels.float())
 
     def validation_step(self, batch, batch_idx, dataloader_idx=0):
-        images = batch[1]
-        labels = batch[2]
+        images = batch[0]
+        labels = batch[1]["label"]
 
         # Increase label dimension if too low
         # Allows single and multiclass usage
@@ -78,34 +209,23 @@ class Alexnet(pl.LightningModule):
         # data forwarding on the existing network
         outputs = self(images)
 
-        # Manually move criterion to selected device, since not part of the model.
-        self.hparams.criterion_valid = self.hparams.criterion_valid.to(
-            self.device
-        )
-        validation_loss = self.hparams.criterion_valid(outputs, labels.float())
-
-        if dataloader_idx == 0:
-            return {"validation_loss": validation_loss}
-        else:
-            return {f"extra_validation_loss_{dataloader_idx}": validation_loss}
+        return self._validation_loss(outputs, labels.float())
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
-        names = batch[0]
-        images = batch[1]
+        images = batch[0]
+        labels = batch[1]["label"]
+        names = batch[1]["names"]
 
         outputs = self(images)
         probabilities = torch.sigmoid(outputs)
 
-        # necessary check for HED architecture that uses several outputs
-        # for loss calculation instead of just the last concatfuse block
-        if isinstance(outputs, list):
-            outputs = outputs[-1]
-
-        return names[0], torch.flatten(probabilities), torch.flatten(batch[2])
+        return (
+            names[0],
+            torch.flatten(probabilities),
+            torch.flatten(labels),
+        )
 
     def configure_optimizers(self):
-        optimizer = getattr(torch.optim, self.hparams.optimizer)(
-            self.parameters(), **self.hparams.optimizer_configs
+        return self._optimizer_type(
+            self.parameters(), **self._optimizer_arguments
         )
-
-        return optimizer
diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py
index a7cf9d567946899c874efce1664d0e7f65d5ace2..8eba3b53410da4874d8b59c48c73e13bec1cb703 100644
--- a/src/ptbench/models/densenet.py
+++ b/src/ptbench/models/densenet.py
@@ -2,55 +2,187 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
+import logging
+import typing
+
 import lightning.pytorch as pl
 import torch
-import torch.nn as nn
+import torch.nn
+import torch.nn.functional as F
+import torch.optim.optimizer
+import torch.utils.data
 import torchvision.models as models
+import torchvision.transforms
+
+from ..data.typing import DataLoader, TransformSequence
 
-from .normalizer import TorchVisionNormalizer
+logger = logging.getLogger(__name__)
 
 
 class Densenet(pl.LightningModule):
-    """Densenet module.
+    """Densenet-121 module.
+
+    Parameters
+    ----------
+
+    train_loss
+        The loss to be used during the training.
+
+        .. warning::
+
+           The loss should be set to always return batch averages (as opposed
+           to the batch sum), as our logging system expects it so.
+
+    validation_loss
+        The loss to be used for validation (may be different from the training
+        loss).  If extra-validation sets are provided, the same loss will be
+        used throughout.
+
+        .. warning::
+
+           The loss should be set to always return batch averages (as opposed
+           to the batch sum), as our logging system expects it so.
+
+    optimizer_type
+        The type of optimizer to use for training
+
+    optimizer_arguments
+        Arguments to the optimizer after ``params``.
+
+    augmentation_transforms
+        An optional sequence of torch modules containing transforms to be
+        applied on the input **before** it is fed into the network.
 
-    Note: only usable with a normalized dataset
+    pretrained
+        If set to True, loads pretrained model weights during initialization,
+        else trains a new model.
     """
 
     def __init__(
         self,
-        criterion,
-        criterion_valid,
-        optimizer,
-        optimizer_configs,
-        pretrained=False,
-        nb_channels=3,
+        train_loss: torch.nn.Module,
+        validation_loss: torch.nn.Module | None,
+        optimizer_type: type[torch.optim.Optimizer],
+        optimizer_arguments: dict[str, typing.Any],
+        augmentation_transforms: TransformSequence = [],
+        pretrained: bool = False,
     ):
         super().__init__()
 
-        self.save_hyperparameters(ignore=["criterion", "criterion_valid"])
+        self.name = "densenet-121"
 
-        self.name = "Densenet"
+        self._train_loss = train_loss
+        self._validation_loss = (
+            validation_loss if validation_loss is not None else train_loss
+        )
+        self._optimizer_type = optimizer_type
+        self._optimizer_arguments = optimizer_arguments
+
+        self._augmentation_transforms = torchvision.transforms.Compose(
+            augmentation_transforms
+        )
 
-        self.normalizer = TorchVisionNormalizer(nb_channels=nb_channels)
+        self.pretrained = pretrained
 
         # Load pretrained model
-        weights = None if not pretrained else models.DenseNet121_Weights.DEFAULT
+        if not pretrained:
+            weights = None
+        else:
+            logger.info(f"Loading pretrained {self.name} model weights")
+            weights = models.DenseNet121_Weights.DEFAULT
+
         self.model_ft = models.densenet121(weights=weights)
 
         # Adapt output features
-        self.model_ft.classifier = nn.Sequential(
-            nn.Linear(1024, 256), nn.Linear(256, 1)
+        self.model_ft.classifier = torch.nn.Sequential(
+            torch.nn.Linear(1024, 256), torch.nn.Linear(256, 1)
         )
 
     def forward(self, x):
-        x = self.normalizer(x)
+        x = self.normalizer(x)  # type: ignore
+
         x = self.model_ft(x)
 
         return x
 
-    def training_step(self, batch, batch_idx):
-        images = batch[1]
-        labels = batch[2]
+    def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
+        """Initializes the normalizer for the current model.
+
+        If ``pretrained = True``, the provided dataloader is ignored and the
+        normalizer is set to the preset ImageNet factors instead.
+
+        Parameters
+        ----------
+
+        dataloader: :py:class:`torch.utils.data.DataLoader`
+            A torch Dataloader from which to compute the mean and std.
+            Will not be used if the model is pretrained.
+        """
+        if self.pretrained:
+            from .normalizer import make_imagenet_normalizer
+
+            logger.warning(
+                f"ImageNet pre-trained {self.name} model - NOT "
+                f"computing z-norm factors from train dataloader. "
+                f"Using preset factors from torchvision."
+            )
+            self.normalizer = make_imagenet_normalizer()
+        else:
+            from .normalizer import make_z_normalizer
+
+            logger.info(
+                f"Uninitialised {self.name} model - "
+                f"computing z-norm factors from train dataloader."
+            )
+            self.normalizer = make_z_normalizer(dataloader)
+
+    def balance_losses_by_class(
+        self,
+        train_dataloader: DataLoader,
+        valid_dataloader: dict[str, DataLoader],
+    ):
+        """Reweights loss weights if possible.
+
+        Parameters
+        ----------
+
+        train_dataloader
+            The data loader to use for training
+
+        valid_dataloader
+            The data loaders to use for each of the validation sets
+
+
+        Raises
+        ------
+
+        RuntimeError
+            If train or validation losses are not of type
+            :py:class:`torch.nn.BCEWithLogitsLoss`.
+        """
+        from .loss_weights import get_label_weights
+
+        if isinstance(self._train_loss, torch.nn.BCEWithLogitsLoss):
+            logger.info("Reweighting BCEWithLogitsLoss training loss.")
+            weights = get_label_weights(train_dataloader)
+            self._train_loss = torch.nn.BCEWithLogitsLoss(pos_weight=weights)
+        else:
+            raise RuntimeError(
+                "Training loss is not BCEWithLogitsLoss - dunno how to balance"
+            )
+
+        if isinstance(self._validation_loss, torch.nn.BCEWithLogitsLoss):
+            logger.info("Reweighting BCEWithLogitsLoss validation loss.")
+            weights = get_label_weights(valid_dataloader)
+            self._validation_loss = torch.nn.BCEWithLogitsLoss(
+                pos_weight=weights
+            )
+        else:
+            raise RuntimeError(
+                "Validation loss is not BCEWithLogitsLoss - cannot balance it"
+            )
+
+    def training_step(self, batch, _):
+        images = batch[0]
+        labels = batch[1]["label"]
 
         # Increase label dimension if too low
         # Allows single and multiclass usage
@@ -58,17 +190,18 @@ class Densenet(pl.LightningModule):
             labels = torch.reshape(labels, (labels.shape[0], 1))
 
         # Forward pass on the network
-        outputs = self(images)
-
-        # Manually move criterion to selected device, since not part of the model.
-        self.hparams.criterion = self.hparams.criterion.to(self.device)
-        training_loss = self.hparams.criterion(outputs, labels.float())
+        augmented_images = [
+            self._augmentation_transforms(img).to(self.device) for img in images
+        ]
+        # Combine list of augmented images back into a tensor
+        augmented_images = torch.cat(augmented_images, 0).view(images.shape)
+        outputs = self(augmented_images)
 
-        return {"loss": training_loss}
+        return self._train_loss(outputs, labels.float())
 
     def validation_step(self, batch, batch_idx, dataloader_idx=0):
-        images = batch[1]
-        labels = batch[2]
+        images = batch[0]
+        labels = batch[1]["label"]
 
         # Increase label dimension if too low
         # Allows single and multiclass usage
@@ -78,35 +211,23 @@ class Densenet(pl.LightningModule):
         # data forwarding on the existing network
         outputs = self(images)
 
-        # Manually move criterion to selected device, since not part of the model.
-        self.hparams.criterion_valid = self.hparams.criterion_valid.to(
-            self.device
-        )
-        validation_loss = self.hparams.criterion_valid(outputs, labels.float())
-
-        if dataloader_idx == 0:
-            return {"validation_loss": validation_loss}
-        else:
-            return {f"extra_validation_loss_{dataloader_idx}": validation_loss}
+        return self._validation_loss(outputs, labels.float())
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
-        names = batch[0]
-        images = batch[1]
+        images = batch[0]
+        labels = batch[1]["label"]
+        names = batch[1]["names"]
 
         outputs = self(images)
         probabilities = torch.sigmoid(outputs)
 
-        # necessary check for HED architecture that uses several outputs
-        # for loss calculation instead of just the last concatfuse block
-        if isinstance(outputs, list):
-            outputs = outputs[-1]
-
-        return names[0], torch.flatten(probabilities), torch.flatten(batch[2])
+        return (
+            names[0],
+            torch.flatten(probabilities),
+            torch.flatten(labels),
+        )
 
     def configure_optimizers(self):
-        # Dynamically instantiates the optimizer given the configs
-        optimizer = getattr(torch.optim, self.hparams.optimizer)(
-            self.parameters(), **self.hparams.optimizer_configs
+        return self._optimizer_type(
+            self.parameters(), **self._optimizer_arguments
         )
-
-        return optimizer
diff --git a/src/ptbench/models/loss_weights.py b/src/ptbench/models/loss_weights.py
new file mode 100644
index 0000000000000000000000000000000000000000..6889b2539fb1071e228c4487540ad9272aad808c
--- /dev/null
+++ b/src/ptbench/models/loss_weights.py
@@ -0,0 +1,70 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import logging
+
+import torch
+import torch.utils.data
+
+logger = logging.getLogger(__name__)
+
+
+def get_label_weights(
+    dataloader: torch.utils.data.DataLoader,
+) -> torch.Tensor:
+    """Computes the weights of each class of a DataLoader.
+
+    This function inputs a pytorch DataLoader and computes the ratio between
+    the number of negative and positive samples (scalar).  The weight can be
+    used to adjust minimisation criteria in cases where there is a huge data
+    imbalance.
+
+    It returns a vector with weights (inverse counts) for each label.
+
+
+    Parameters
+    ----------
+
+    dataloader
+        A DataLoader from which to compute the positive weights.  Entries must
+        be a dictionary which must contain a ``label`` key.
+
+
+    Returns
+    -------
+
+    positive_weights
+        the positive weight of each class in the dataset given as input
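+
+
+    Example
+    -------
+
+    For a binary dataset with 120 negative and 30 positive samples, the
+    returned tensor is ``[4.0]`` (120/30), which can be passed directly as
+    ``pos_weight`` to :py:class:`torch.nn.BCEWithLogitsLoss`.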
+    """
+
+    targets = torch.tensor(
+        [sample for batch in dataloader for sample in batch[1]["label"]]
+    )
+
+    # Binary labels
+    if len(list(targets.shape)) == 1:
+        class_sample_count = [
+            float((targets == t).sum().item())
+            for t in torch.unique(targets, sorted=True)
+        ]
+
+        # Divide negatives by positives
+        positive_weights = torch.tensor(
+            [class_sample_count[0] / class_sample_count[1]]
+        ).reshape(-1)
+
+    # Multiclass labels
+    else:
+        class_sample_count = torch.sum(targets, dim=0)
+        negative_class_sample_count = (
+            torch.full((targets.size()[1],), float(targets.size()[0]))
+            - class_sample_count
+        )
+
+        positive_weights = negative_class_sample_count / (
+            class_sample_count + negative_class_sample_count
+        )
+
+    return positive_weights
diff --git a/src/ptbench/models/normalizer.py b/src/ptbench/models/normalizer.py
index aa2142a6bb01de4dfc011cb8b49c1521b82a5e29..ce68f4b558b2812d17a8f54954e197a086233a52 100644
--- a/src/ptbench/models/normalizer.py
+++ b/src/ptbench/models/normalizer.py
@@ -2,37 +2,74 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-"""A network model that prefixes a z-normalization step to any other module."""
+"""A network model that prefixes a subtract/divide step to any other module."""
 
 import torch
 import torch.nn
+import torch.utils.data
+import torchvision.transforms
+import tqdm
 
 
-class TorchVisionNormalizer(torch.nn.Module):
-    """A simple normalizer that applies the standard torchvision normalization.
+def make_z_normalizer(
+    dataloader: torch.utils.data.DataLoader,
+) -> torchvision.transforms.Normalize:
+    """Computes mean and standard deviation from a dataloader.
+
+    This function inputs a dataloader, and computes the mean and standard
+    deviation by image channel.  It works for both monochromatic and color
+    inputs with 2, 3 or more color planes.
 
-    This module does not learn.
 
     Parameters
     ----------
 
-    nb_channels : :py:class:`int`, Optional
-        Number of images channels fed to the model
+    dataloader:
+        A torch Dataloader from which to compute the mean and std
+
+
+    Returns
+    -------
+        An initialized normalizer
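+
+    Example
+    -------
+
+    A hedged usage sketch (``datamodule`` is assumed to provide the training
+    dataloader):
+
+    .. code:: python
+
+       normalizer = make_z_normalizer(datamodule.train_dataloader())
+       output = normalizer(batch)  # (batch - mean) / std, per channel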
+    """
+
+    # Peek the number of channels of batches in the data loader
+    batch = next(iter(dataloader))
+    channels = batch[0].shape[1]
+
+    # Initialises accumulators
+    mean = torch.zeros(channels, dtype=batch[0].dtype)
+    var = torch.zeros(channels, dtype=batch[0].dtype)
+    num_images = 0
+
+    # Evaluates mean and standard deviation
+    for batch in tqdm.tqdm(dataloader, unit="batch"):
+        data = batch[0]
+        data = data.view(data.size(0), data.size(1), -1)
+
+        num_images += data.size(0)
+        mean += data.mean(2).sum(0)
+        var += data.var(2).sum(0)
+
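+    # Averages the accumulated per-image statistics.  Note: taking the square
+    # root of the averaged per-image variances approximates the global std
+    # (it ignores the spread of the per-image means).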
+    mean /= num_images
+    var /= num_images
+    std = torch.sqrt(var)
+
+    return torchvision.transforms.Normalize(mean, std)
+
+
+def make_imagenet_normalizer() -> torchvision.transforms.Normalize:
+    """Returns the stock ImageNet normalisation weights from torchvision.
+
+    The weights are wrapped in a torch module.  This normalizer only works for
+    **RGB (color) images**.
+
+
+    Returns
+    -------
+        An initialized normalizer
     """
 
-    def __init__(self, nb_channels=3):
-        super().__init__()
-        mean = torch.zeros(nb_channels)[None, :, None, None]
-        std = torch.ones(nb_channels)[None, :, None, None]
-        self.register_buffer("mean", mean)
-        self.register_buffer("std", std)
-        self.name = "torchvision-normalizer"
-
-    def set_mean_std(self, mean, std):
-        mean = torch.as_tensor(mean)[None, :, None, None]
-        std = torch.as_tensor(std)[None, :, None, None]
-        self.register_buffer("mean", mean)
-        self.register_buffer("std", std)
-
-    def forward(self, inputs):
-        return inputs.sub(self.mean).div(self.std)
+    return torchvision.transforms.Normalize(
+        (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
+    )
diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py
index 9da9702f31436df441cdb56d9415cc06e6829623..20bbb0dd9d06a122c81187762b8bbca0d04470fd 100644
--- a/src/ptbench/models/pasa.py
+++ b/src/ptbench/models/pasa.py
@@ -2,85 +2,143 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
+import logging
+import typing
+
 import lightning.pytorch as pl
 import torch
-import torch.nn as nn
+import torch.nn
 import torch.nn.functional as F
+import torch.optim.optimizer
+import torch.utils.data
+import torchvision.transforms
+
+from ..data.typing import DataLoader, TransformSequence
+
+logger = logging.getLogger(__name__)
+
+
+class Pasa(pl.LightningModule):
+    """Implementation of CNN by Pasa.
+
+    Simple CNN for classification, based on the paper by [PASA-2019]_.
+
+
+    Parameters
+    ----------
+
+    train_loss
+        The loss to be used during the training.
+
+        .. warning::
+
+           The loss should be set to always return batch averages (as opposed
+           to the batch sum), as our logging system expects it so.
+
+    validation_loss
+        The loss to be used for validation (may be different from the training
+        loss).  If extra-validation sets are provided, the same loss will be
+        used throughout.
 
-from .normalizer import TorchVisionNormalizer
+        .. warning::
 
+           The loss should be set to always return batch averages (as opposed
+           to the batch sum), as our logging system expects it so.
 
-class PASA(pl.LightningModule):
-    """PASA module.
+    optimizer_type
+        The type of optimizer to use for training
 
-    Based on paper by [PASA-2019]_.
+    optimizer_arguments
+        Arguments to the optimizer after ``params``.
+
+    augmentation_transforms
+        An optional sequence of torch modules containing transforms to be
+        applied on the input **before** it is fed into the network.
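+
+    Example
+    -------
+
+    A minimal construction sketch (the parameter values below are
+    illustrative only, not recommended settings):
+
+    .. code:: python
+
+       import torch
+
+       model = Pasa(
+           train_loss=torch.nn.BCEWithLogitsLoss(),
+           validation_loss=None,  # re-uses the training loss
+           optimizer_type=torch.optim.Adam,
+           optimizer_arguments={"lr": 8e-5},
+       )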
     """
 
     def __init__(
         self,
-        criterion,
-        criterion_valid,
-        optimizer,
-        optimizer_configs,
+        train_loss: torch.nn.Module,
+        validation_loss: torch.nn.Module | None,
+        optimizer_type: type[torch.optim.Optimizer],
+        optimizer_arguments: dict[str, typing.Any],
+        augmentation_transforms: TransformSequence = [],
     ):
         super().__init__()
 
-        self.save_hyperparameters()
-
         self.name = "pasa"
 
-        self.normalizer = TorchVisionNormalizer(nb_channels=1)
+        self._train_loss = train_loss
+        self._validation_loss = (
+            validation_loss if validation_loss is not None else train_loss
+        )
+        self._optimizer_type = optimizer_type
+        self._optimizer_arguments = optimizer_arguments
+
+        self._augmentation_transforms = torchvision.transforms.Compose(
+            augmentation_transforms
+        )
 
         # First convolution block
-        self.fc1 = nn.Conv2d(1, 4, (3, 3), (2, 2), (1, 1))
-        self.fc2 = nn.Conv2d(4, 16, (3, 3), (2, 2), (1, 1))
-        self.fc3 = nn.Conv2d(1, 16, (1, 1), (4, 4))
+        self.fc1 = torch.nn.Conv2d(1, 4, (3, 3), (2, 2), (1, 1))
+        self.fc2 = torch.nn.Conv2d(4, 16, (3, 3), (2, 2), (1, 1))
+        self.fc3 = torch.nn.Conv2d(1, 16, (1, 1), (4, 4))
 
-        self.batchNorm2d_4 = nn.BatchNorm2d(4)
-        self.batchNorm2d_16 = nn.BatchNorm2d(16)
-        self.batchNorm2d_16_2 = nn.BatchNorm2d(16)
+        self.batchNorm2d_4 = torch.nn.BatchNorm2d(4)
+        self.batchNorm2d_16 = torch.nn.BatchNorm2d(16)
+        self.batchNorm2d_16_2 = torch.nn.BatchNorm2d(16)
 
         # Second convolution block
-        self.fc4 = nn.Conv2d(16, 24, (3, 3), (1, 1), (1, 1))
-        self.fc5 = nn.Conv2d(24, 32, (3, 3), (1, 1), (1, 1))
-        self.fc6 = nn.Conv2d(16, 32, (1, 1), (1, 1))  # Original stride (2, 2)
+        self.fc4 = torch.nn.Conv2d(16, 24, (3, 3), (1, 1), (1, 1))
+        self.fc5 = torch.nn.Conv2d(24, 32, (3, 3), (1, 1), (1, 1))
+        self.fc6 = torch.nn.Conv2d(
+            16, 32, (1, 1), (1, 1)
+        )  # Original stride (2, 2)
 
-        self.batchNorm2d_24 = nn.BatchNorm2d(24)
-        self.batchNorm2d_32 = nn.BatchNorm2d(32)
-        self.batchNorm2d_32_2 = nn.BatchNorm2d(32)
+        self.batchNorm2d_24 = torch.nn.BatchNorm2d(24)
+        self.batchNorm2d_32 = torch.nn.BatchNorm2d(32)
+        self.batchNorm2d_32_2 = torch.nn.BatchNorm2d(32)
 
         # Third convolution block
-        self.fc7 = nn.Conv2d(32, 40, (3, 3), (1, 1), (1, 1))
-        self.fc8 = nn.Conv2d(40, 48, (3, 3), (1, 1), (1, 1))
-        self.fc9 = nn.Conv2d(32, 48, (1, 1), (1, 1))  # Original stride (2, 2)
+        self.fc7 = torch.nn.Conv2d(32, 40, (3, 3), (1, 1), (1, 1))
+        self.fc8 = torch.nn.Conv2d(40, 48, (3, 3), (1, 1), (1, 1))
+        self.fc9 = torch.nn.Conv2d(
+            32, 48, (1, 1), (1, 1)
+        )  # Original stride (2, 2)
 
-        self.batchNorm2d_40 = nn.BatchNorm2d(40)
-        self.batchNorm2d_48 = nn.BatchNorm2d(48)
-        self.batchNorm2d_48_2 = nn.BatchNorm2d(48)
+        self.batchNorm2d_40 = torch.nn.BatchNorm2d(40)
+        self.batchNorm2d_48 = torch.nn.BatchNorm2d(48)
+        self.batchNorm2d_48_2 = torch.nn.BatchNorm2d(48)
 
         # Fourth convolution block
-        self.fc10 = nn.Conv2d(48, 56, (3, 3), (1, 1), (1, 1))
-        self.fc11 = nn.Conv2d(56, 64, (3, 3), (1, 1), (1, 1))
-        self.fc12 = nn.Conv2d(48, 64, (1, 1), (1, 1))  # Original stride (2, 2)
+        self.fc10 = torch.nn.Conv2d(48, 56, (3, 3), (1, 1), (1, 1))
+        self.fc11 = torch.nn.Conv2d(56, 64, (3, 3), (1, 1), (1, 1))
+        self.fc12 = torch.nn.Conv2d(
+            48, 64, (1, 1), (1, 1)
+        )  # Original stride (2, 2)
 
-        self.batchNorm2d_56 = nn.BatchNorm2d(56)
-        self.batchNorm2d_64 = nn.BatchNorm2d(64)
-        self.batchNorm2d_64_2 = nn.BatchNorm2d(64)
+        self.batchNorm2d_56 = torch.nn.BatchNorm2d(56)
+        self.batchNorm2d_64 = torch.nn.BatchNorm2d(64)
+        self.batchNorm2d_64_2 = torch.nn.BatchNorm2d(64)
 
         # Fifth convolution block
-        self.fc13 = nn.Conv2d(64, 72, (3, 3), (1, 1), (1, 1))
-        self.fc14 = nn.Conv2d(72, 80, (3, 3), (1, 1), (1, 1))
-        self.fc15 = nn.Conv2d(64, 80, (1, 1), (1, 1))  # Original stride (2, 2)
+        self.fc13 = torch.nn.Conv2d(64, 72, (3, 3), (1, 1), (1, 1))
+        self.fc14 = torch.nn.Conv2d(72, 80, (3, 3), (1, 1), (1, 1))
+        self.fc15 = torch.nn.Conv2d(
+            64, 80, (1, 1), (1, 1)
+        )  # Original stride (2, 2)
 
-        self.batchNorm2d_72 = nn.BatchNorm2d(72)
-        self.batchNorm2d_80 = nn.BatchNorm2d(80)
-        self.batchNorm2d_80_2 = nn.BatchNorm2d(80)
+        self.batchNorm2d_72 = torch.nn.BatchNorm2d(72)
+        self.batchNorm2d_80 = torch.nn.BatchNorm2d(80)
+        self.batchNorm2d_80_2 = torch.nn.BatchNorm2d(80)
 
-        self.pool2d = nn.MaxPool2d((3, 3), (2, 2))  # Pool after conv. block
-        self.dense = nn.Linear(80, 1)  # Fully connected layer
+        self.pool2d = torch.nn.MaxPool2d(
+            (3, 3), (2, 2)
+        )  # Pool after conv. block
+        self.dense = torch.nn.Linear(80, 1)  # Fully connected layer
 
     def forward(self, x):
-        x = self.normalizer(x)
+        x = self.normalizer(x)  # type: ignore
 
         # First convolution block
         _x = x
@@ -127,9 +185,70 @@ class PASA(pl.LightningModule):
 
         return x
 
-    def training_step(self, batch, batch_idx):
-        images = batch["data"]
-        labels = batch["label"]
+    def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
+        """Initializes the input normalizer for the current model.
+
+        Parameters
+        ----------
+
+        dataloader
+            A torch Dataloader from which to compute the mean and std
+        """
+        from .normalizer import make_z_normalizer
+
+        logger.info(
+            f"Uninitialised {self.name} model - "
+            f"computing z-norm factors from train dataloader."
+        )
+        self.normalizer = make_z_normalizer(dataloader)
+
+    def balance_losses_by_class(
+        self,
+        train_dataloader: DataLoader,
+        valid_dataloader: dict[str, DataLoader],
+    ):
+        """Reweights loss weights if possible.
+
+        Parameters
+        ----------
+
+        train_dataloader
+            The data loader to use for training
+
+        valid_dataloader
+            The data loaders to use for each of the validation sets
+
+
+        Raises
+        ------
+
+        RuntimeError
+            If train or validation losses are not of type
+            :py:class:`torch.nn.BCEWithLogitsLoss`.
+        """
+        from .loss_weights import get_label_weights
+
+        if isinstance(self._train_loss, torch.nn.BCEWithLogitsLoss):
+            logger.info("Reweighting BCEWithLogitsLoss training loss.")
+            weights = get_label_weights(train_dataloader)
+            self._train_loss = torch.nn.BCEWithLogitsLoss(pos_weight=weights)
+        else:
+            raise RuntimeError(
+                "Training loss is not BCEWithLogitsLoss - dunno how to balance"
+            )
+
+        if isinstance(self._validation_loss, torch.nn.BCEWithLogitsLoss):
+            logger.info("Reweighting BCEWithLogitsLoss validation loss.")
+            weights = get_label_weights(valid_dataloader)
+            self._validation_loss = torch.nn.BCEWithLogitsLoss(
+                pos_weight=weights
+            )
+        else:
+            raise RuntimeError(
+                "Validation loss is not BCEWithLogitsLoss - cannot balance it"
+            )
+
+    def training_step(self, batch, _):
+        images = batch[0]
+        labels = batch[1]["label"]
 
         # Increase label dimension if too low
         # Allows single and multiclass usage
@@ -137,17 +256,18 @@ class PASA(pl.LightningModule):
             labels = torch.reshape(labels, (labels.shape[0], 1))
 
         # Forward pass on the network
-        outputs = self(images)
+        augmented_images = [
+            self._augmentation_transforms(img).to(self.device) for img in images
+        ]
+        # Combine list of augmented images back into a tensor
+        augmented_images = torch.cat(augmented_images, 0).view(images.shape)
+        outputs = self(augmented_images)
 
-        # Manually move criterion to selected device, since not part of the model.
-        self.hparams.criterion = self.hparams.criterion.to(self.device)
-        training_loss = self.hparams.criterion(outputs, labels.double())
-
-        return {"loss": training_loss}
+        return self._train_loss(outputs, labels.float())
 
     def validation_step(self, batch, batch_idx, dataloader_idx=0):
-        images = batch["data"]
-        labels = batch["label"]
+        images = batch[0]
+        labels = batch[1]["label"]
 
         # Increase label dimension if too low
         # Allows single and multiclass usage
@@ -157,63 +277,23 @@ class PASA(pl.LightningModule):
         # data forwarding on the existing network
         outputs = self(images)
 
-        # Manually move criterion to selected device, since not part of the model.
-        self.hparams.criterion_valid = self.hparams.criterion_valid.to(
-            self.device
-        )
-        validation_loss = self.hparams.criterion_valid(outputs, labels.double())
-
-        if dataloader_idx == 0:
-            return {"validation_loss": validation_loss}
-        else:
-            return {f"extra_validation_loss_{dataloader_idx}": validation_loss}
+        return self._validation_loss(outputs, labels.float())
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
-        names = batch["name"]
-        images = batch["data"]
-        labels = batch["label"]
+        images = batch[0]
+        labels = batch[1]["label"]
+        names = batch[1]["names"]
 
         outputs = self(images)
         probabilities = torch.sigmoid(outputs)
 
-        # necessary check for HED architecture that uses several outputs
-        # for loss calculation instead of just the last concatfuse block
-        if isinstance(outputs, list):
-            outputs = outputs[-1]
-
-        results = (
+        return (
             names[0],
             torch.flatten(probabilities),
             torch.flatten(labels),
         )
 
-        return results
-        # {
-        # f"dataloader_{dataloader_idx}_predictions": (
-        #    names[0],
-        #    torch.flatten(probabilities),
-        #    torch.flatten(labels),
-        # )
-        # }
-
-    # def on_predict_epoch_end(self):
-
-    #    retval = defaultdict(list)
-
-    #    for dataloader_name, predictions in self.predictions_cache.items():
-    #        for prediction in predictions:
-    #            retval[dataloader_name]["name"].append(prediction[0])
-    #            retval[dataloader_name]["prediction"].append(prediction[1])
-    #            retval[dataloader_name]["label"].append(prediction[2])
-
-    # Need to cache predictions in the predict step, then reorder by key
-    # Clear prediction dict
-    # raise NotImplementedError
-
     def configure_optimizers(self):
-        # Dynamically instantiates the optimizer given the configs
-        optimizer = getattr(torch.optim, self.hparams.optimizer)(
-            self.parameters(), **self.hparams.optimizer_configs
+        return self._optimizer_type(
+            self.parameters(), **self._optimizer_arguments
         )
-
-        return optimizer
diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py
index 4d2a226b5b0b479b3f84c748d24aebd43d8d6dea..664b8b1ad1ae38625a22af4a1092b15da20b2727 100644
--- a/src/ptbench/scripts/train.py
+++ b/src/ptbench/scripts/train.py
@@ -6,9 +6,6 @@ import click
 
 from clapper.click import ConfigCommand, ResourceOption, verbosity_option
 from clapper.logging import setup
-from lightning.pytorch import seed_everything
-
-from ..utils.checkpointer import get_checkpoint
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
@@ -19,7 +16,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     epilog="""Examples:
 
 \b
-    1. Trains PASA model with Montgomery dataset, on a GPU (``cuda:0``):
+    1. Trains Pasa's model with the Montgomery dataset, on a GPU (``cuda:0``):
 
        .. code:: sh
 
@@ -39,47 +36,17 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 @click.option(
     "--model",
     "-m",
-    help="A torch.nn.Module instance implementing the network to be trained",
+    help="A lightining module instance implementing the network to be trained",
     required=True,
     cls=ResourceOption,
 )
 @click.option(
     "--datamodule",
     "-d",
-    help="A dictionary mapping string keys to "
-    "torch.utils.data.dataset.Dataset instances implementing datasets "
-    "to be used for training and validating the model, possibly including all "
-    "pre-processing pipelines required or, optionally, a dictionary mapping "
-    "string keys to torch.utils.data.dataset.Dataset instances.  At least "
-    "one key named ``train`` must be available.  This dataset will be used for "
-    "training the network model.  The dataset description must include all "
-    "required pre-processing, including eventual data augmentation.  If a "
-    "dataset named ``__train__`` is available, it is used prioritarily for "
-    "training instead of ``train``.  If a dataset named ``__valid__`` is "
-    "available, it is used for model validation (and automatic "
-    "check-pointing) at each epoch.  If a dataset list named "
-    "``__extra_valid__`` is available, then it will be tracked during the "
-    "validation process and its loss output at the training log as well, "
-    "in the format of an array occupying a single column.  All other keys "
-    "are considered test datasets and are ignored during training",
-    required=True,
-    cls=ResourceOption,
-)
-@click.option(
-    "--criterion",
-    help="A loss function to compute the CNN error for every sample "
-    "respecting the PyTorch API for loss functions (see torch.nn.modules.loss)",
+    help="A lighting data module containing the training and validation sets.",
     required=True,
     cls=ResourceOption,
 )
-@click.option(
-    "--criterion-valid",
-    help="A specific loss function for the validation set to compute the CNN"
-    "error for every sample respecting the PyTorch API for loss functions"
-    "(see torch.nn.modules.loss)",
-    required=False,
-    cls=ResourceOption,
-)
 @click.option(
     "--batch-size",
     "-b",
@@ -157,17 +124,18 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     cls=ResourceOption,
 )
 @click.option(
-    "--accelerator",
-    "-a",
-    help='A string indicating the accelerator to use (e.g. "cpu" or "gpu"). The device can also be specified (gpu:0)',
+    "--device",
+    "-d",
+    help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
     show_default=True,
     required=True,
     default="cpu",
     cls=ResourceOption,
 )
 @click.option(
-    "--cache-samples",
-    help="If set to True, loads the sample into memory, otherwise loads them at runtime.",
+    "--cache-samples/--no-cache-samples",
+    help="If set to True, loads the sample into memory, "
+    "otherwise loads them at runtime.",
     required=True,
     show_default=True,
     default=False,
@@ -196,16 +164,6 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     default=-1,
     cls=ResourceOption,
 )
-@click.option(
-    "--normalization",
-    "-n",
-    help="Z-Normalization of input images: 'imagenet' for ImageNet parameters,"
-    " 'current' for parameters of the current trainset, "
-    "'none' for no normalization.",
-    required=False,
-    default="none",
-    cls=ResourceOption,
-)
 @click.option(
     "--monitoring-interval",
     "-I",
@@ -224,12 +182,25 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 )
 @click.option(
     "--resume-from",
-    help="Which checkpoint to resume training from. Can be one of 'None', 'best', 'last', or a path to a model checkpoint.",
+    help="Which checkpoint to resume training from. If set, can be one of "
+    "`best`, `last`, or a path to a model checkpoint.",
     type=str,
     required=False,
     default=None,
     cls=ResourceOption,
 )
+@click.option(
+    "--balance-classes/--no-balance-classes",
+    "-B/-N",
+    help="""If set, then balances weights of the random sampler during
+    training, so that samples from all sample classes are picked picked
+    equitably.  It also sets the training (and validation) losses to account
+    for the populations of each class.""",
+    required=True,
+    show_default=True,
+    default=True,
+    cls=ResourceOption,
+)
 @verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
 def train(
     model,
@@ -238,20 +209,18 @@ def train(
     batch_size,
     batch_chunk_count,
     drop_incomplete_batch,
-    criterion,
-    criterion_valid,
     datamodule,
     checkpoint_period,
-    accelerator,
+    device,
     cache_samples,
     seed,
     parallel,
-    normalization,
     monitoring_interval,
     resume_from,
+    balance_classes,
     **_,
-):
-    """Trains an CNN to perform tuberculosis detection.
+) -> None:
+    """Trains an CNN to perform image classification.
 
     Training is performed for a configurable number of epochs, and
     generates at least a final_model.pth.  It may also generate a number
@@ -263,32 +232,56 @@ def train(
     import torch.cuda
     import torch.nn
 
-    from ..data.dataset import normalize_data, reweight_BCEWithLogitsLoss
+    from lightning.pytorch import seed_everything
+
+    from ..engine.device import DeviceManager
     from ..engine.trainer import run
+    from ..utils.checkpointer import get_checkpoint
+    from .utils import save_sh_command
 
+    save_sh_command(output_folder)
     seed_everything(seed)
 
     checkpoint_file = get_checkpoint(output_folder, resume_from)
 
-    datamodule.update_module_properties(
-        batch_size=batch_size,
-        batch_chunk_count=batch_chunk_count,
-        drop_incomplete_batch=drop_incomplete_batch,
-        cache_samples=cache_samples,
-        parallel=parallel,
-    )
+    # reset datamodule with user configurable options
+    datamodule.set_chunk_size(batch_size, batch_chunk_count)
+    datamodule.drop_incomplete_batch = drop_incomplete_batch
+    datamodule.cache_samples = cache_samples
+    datamodule.parallel = parallel
 
     datamodule.prepare_data()
     datamodule.setup(stage="fit")
 
-    reweight_BCEWithLogitsLoss(datamodule, model, criterion, criterion_valid)
-    normalize_data(normalization, model, datamodule)
+    # Sets the model normalizer using the unaugmented train subset.  This
+    # call may be a no-op, if the model is pre-trained and expects different
+    # weights for the normalisation layer.
+    if hasattr(model, "set_normalizer"):
+        model.set_normalizer(datamodule.unshuffled_train_dataloader())
+    else:
+        logger.warning(
+            f"Model {model.name} has no 'set_normalizer' method. Skipping."
+        )
+
+    # If asked, rebalances the loss criterion based on the relative proportion
+    # of class examples available in the training set.  Also affects the
+    # validation loss if a validation set is available on the data module.
+    if balance_classes:
+        logger.info("Applying datamodule train sampler balancing...")
+        datamodule.balance_sampler_by_class = True
+        # logger.info("Applying train/valid loss balancing...")
+        # model.balance_losses_by_class(datamodule)
+    else:
+        logger.info(
+            "Skipping sample class/dataset ownership balancing on user request"
+        )
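+
+    # For reference, a loss-level rebalancing for a BCEWithLogitsLoss-style
+    # criterion (a sketch only; the commented-out
+    # ``model.balance_losses_by_class`` call above is the intended hook)
+    # would weigh positives by the negative/positive ratio of the training
+    # set:
+    #
+    #   pos_weight = torch.tensor([n_negatives / n_positives])
+    #   criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)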
 
     arguments = {}
     arguments["max_epoch"] = epochs
     arguments["epoch"] = 0
 
-    # We only load the checkpoint to get some information about its state. The actual loading of the model is done in trainer.fit()
+    # We only load the checkpoint to get some information about its state. The
+    # actual loading of the model is done in trainer.fit()
     if checkpoint_file is not None:
         checkpoint = torch.load(checkpoint_file)
         arguments["epoch"] = checkpoint["epoch"]
@@ -300,7 +293,7 @@ def train(
         model=model,
         datamodule=datamodule,
         checkpoint_period=checkpoint_period,
-        accelerator=accelerator,
+        device_manager=DeviceManager(device),
         arguments=arguments,
         output_folder=output_folder,
         monitoring_interval=monitoring_interval,
diff --git a/src/ptbench/scripts/utils.py b/src/ptbench/scripts/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5553a6ed1f4d79d1eba0ea6923c9821ce918dcd1
--- /dev/null
+++ b/src/ptbench/scripts/utils.py
@@ -0,0 +1,57 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import importlib.metadata
+import logging
+import os
+import pathlib
+import sys
+import time
+
+logger = logging.getLogger(__name__)
+
+
+def save_sh_command(output_folder: str | pathlib.Path) -> None:
+    """Records command-line to reproduce this script.
+
+    This function can record the current command-line used to call the script
+    being run.  It creates an executable ``bash`` script setting up the current
+    working directory and activating a conda environment, if needed.  It
+    records further information on the date and time the script was run and the
+    version of the package.
+
+
+    Parameters
+    ----------
+
+    output_folder : str | pathlib.Path
+        Path leading to the directory where the command to reproduce the
+        current run will be recorded (in a file named ``command.sh``).
+    """
+
+    if isinstance(output_folder, str):
+        output_folder = pathlib.Path(output_folder)
+
+    destfile = output_folder / "command.sh"
+
+    logger.info(f"Writing command-line for reproduction at '{destfile}'...")
+    os.makedirs(output_folder, exist_ok=True)
+
+    package = __name__.split(".", 1)[0]
+    version = importlib.metadata.version(package)
+
+    with destfile.open("w") as f:
+        f.write("#!/usr/bin/env sh\n")
+        f.write(f"# date: {time.asctime()}\n")
+        f.write(f"# version: {version} ({package})\n")
+        f.write(f"# platform: {sys.platform}\n")
+        f.write("\n")
+        args = []
+        for k in sys.argv:
+            if " " in k:
+                args.append(f'"{k}"')
+            else:
+                args.append(k)
+        if os.environ.get("CONDA_DEFAULT_ENV") is not None:
+            f.write(f"# conda activate {os.environ['CONDA_DEFAULT_ENV']}\n")
+        f.write(f"# cd {os.path.realpath(os.curdir)}\n")
+        f.write(" ".join(args) + "\n")
+    os.chmod(destfile, 0o755)
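+
+
+# Illustration only (hypothetical invocation and paths): calling
+# ``save_sh_command("results")`` from a process started as
+# ``ptbench train -vv pasa montgomery`` would write ``results/command.sh``
+# along the lines of:
+#
+#   #!/usr/bin/env sh
+#   # date: <current date and time>
+#   # version: <package version> (ptbench)
+#   # platform: linux
+#
+#   # conda activate <current conda environment, if any>
+#   # cd <current working directory>
+#   ptbench train -vv pasa montgomery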
diff --git a/src/ptbench/utils/accelerator.py b/src/ptbench/utils/accelerator.py
deleted file mode 100644
index dcfa2f733e1d091c5bb9a4e5785ee47f8e49497c..0000000000000000000000000000000000000000
--- a/src/ptbench/utils/accelerator.py
+++ /dev/null
@@ -1,127 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-import logging
-import os
-
-import torch
-
-logger = logging.getLogger(__name__)
-
-
-class AcceleratorProcessor:
-    """This class is used to convert the torch device naming convention to
-    lightning's device convention and vice versa.
-
-    It also sets the CUDA_VISIBLE_DEVICES if a gpu accelerator is used.
-    """
-
-    def __init__(self, name):
-        # Note: "auto" is a valid accelerator in lightning, but there doesn't seem to be a way to check which accelerator it will actually use so we don't take it into account for now.
-        self.torch_to_lightning = {"cpu": "cpu", "cuda": "gpu"}
-
-        self.lightning_to_torch = {
-            v: k for k, v in self.torch_to_lightning.items()
-        }
-
-        self.valid_accelerators = set(
-            list(self.torch_to_lightning.keys())
-            + list(self.lightning_to_torch.keys())
-        )
-
-        self.accelerator, self.device = self._split_accelerator_name(name)
-
-        if self.accelerator not in self.valid_accelerators:
-            raise ValueError(f"Unknown accelerator {self.accelerator}")
-
-        # Keep lightning's convention by default
-        self.accelerator = self.to_lightning()
-        self.setup_accelerator()
-
-    def setup_accelerator(self):
-        """If a gpu accelerator is chosen, checks the CUDA_VISIBLE_DEVICES
-        environment variable exists or sets its value if specified."""
-        if self.accelerator == "gpu":
-            if not torch.cuda.is_available():
-                raise RuntimeError(
-                    f"CUDA is not currently available, but "
-                    f"you set accelerator to '{self.accelerator}'"
-                )
-
-            if self.device is not None:
-                os.environ["CUDA_VISIBLE_DEVICES"] = str(self.device[0])
-            else:
-                if os.environ.get("CUDA_VISIBLE_DEVICES") is None:
-                    raise ValueError(
-                        "Environment variable 'CUDA_VISIBLE_DEVICES' is not set."
-                        "Please set 'CUDA_VISIBLE_DEVICES' of specify a device to use, e.g. cuda:0"
-                    )
-        else:
-            # No need to check the CUDA_VISIBLE_DEVICES environment variable if cpu
-            pass
-
-        logger.info(
-            f"Accelerator set to {self.accelerator} and device to {self.device}"
-        )
-
-    def _split_accelerator_name(self, accelerator_name):
-        """Splits an accelerator string into accelerator and device components.
-
-        Parameters
-        ----------
-
-        accelerator_name: str
-            The accelerator (or device in pytorch convention) string (e.g. cuda:0)
-
-        Returns
-        -------
-
-        accelerator: str
-            The accelerator name
-        device: dict[int]
-            The selected devices
-        """
-
-        split_accelerator = accelerator_name.split(":")
-        accelerator = split_accelerator[0]
-
-        if len(split_accelerator) > 1:
-            device = split_accelerator[1]
-            device = [int(device)]
-        else:
-            device = None
-
-        return accelerator, device
-
-    def to_torch(self):
-        """Converts the accelerator string to torch convention.
-
-        Returns
-        -------
-
-        accelerator: str
-            The accelerator name in pytorch convention
-        """
-        if self.accelerator in self.lightning_to_torch:
-            return self.lightning_to_torch[self.accelerator]
-        elif self.accelerator in self.torch_to_lightning:
-            return self.accelerator
-        else:
-            raise ValueError("Unknown accelerator.")
-
-    def to_lightning(self):
-        """Converts the accelerator string to lightning convention.
-
-        Returns
-        -------
-
-        accelerator: str
-            The accelerator name in lightning convention
-        """
-        if self.accelerator in self.torch_to_lightning:
-            return self.torch_to_lightning[self.accelerator]
-        elif self.accelerator in self.lightning_to_torch:
-            return self.accelerator
-        else:
-            raise ValueError("Unknown accelerator.")
diff --git a/src/ptbench/utils/image.py b/src/ptbench/utils/image.py
new file mode 100644
index 0000000000000000000000000000000000000000..363a8309f4581dfdda42124c0562d3bf840904af
--- /dev/null
+++ b/src/ptbench/utils/image.py
@@ -0,0 +1,36 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import os
+
+import torch
+
+from PIL.Image import Image
+from torchvision import transforms
+
+
+def save_image(img: torch.Tensor | Image, filepath: str) -> None:
+    """Saves a PIL image or a tensor as an image at the specified destination.
+
+    Parameters
+    ----------
+
+    img:
+        A torch.Tensor or PIL.Image to save
+
+    filepath:
+        The file in which to save the image. The format is inferred from the
+        file extension, or defaults to PNG if no extension is given.
+    """
+
+    if isinstance(img, torch.Tensor):
+        img = transforms.ToPILImage()(img)
+
+    _, ext = os.path.splitext(filepath)
+
+    if not ext:
+        filepath = filepath + ".png"
+
+    img.save(filepath)
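+
+
+# Usage sketch (hypothetical tensor shape and path): a CHW float tensor in
+# [0, 1] is converted through ``transforms.ToPILImage()`` before saving:
+#
+#   save_image(torch.rand(1, 224, 224), "outputs/sample")  # -> outputs/sample.png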
diff --git a/src/ptbench/utils/resources.py b/src/ptbench/utils/resources.py
index ebad7794a6975932926ab2287801bb3cb052cd30..f7c7f6b958093a770bdc9c448b79e04a0b3dc704 100644
--- a/src/ptbench/utils/resources.py
+++ b/src/ptbench/utils/resources.py
@@ -6,11 +6,13 @@
 
 import logging
 import multiprocessing
+import multiprocessing.synchronize
 import os
 import queue
 import shutil
 import subprocess
 import time
+import typing
 
 import numpy
 import psutil
@@ -25,7 +27,9 @@ GB = float(2**30)
 """The number of bytes in a gigabyte."""
 
 
-def run_nvidia_smi(query, rename=None):
+def run_nvidia_smi(
+    query: typing.Sequence[str],
+) -> dict[str, str | float] | None:
     """Returns GPU information from query.
 
     For a comprehensive list of options and help, execute ``nvidia-smi
@@ -35,52 +39,43 @@ def run_nvidia_smi(query, rename=None):
     Parameters
     ----------
 
-    query : list
+    query
         A list of query strings as defined by ``nvidia-smi --help-query-gpu``
 
-    rename : :py:class:`list`, Optional
-        A list of keys to yield in the return value for each entry above.  It
-        gives you the opportunity to rewrite some key names for convenience.
-        This list, if provided, must be of the same length as ``query``.
-
 
     Returns
     -------
 
-    data : :py:class:`tuple`, None
-        An ordered dictionary (organized as 2-tuples) containing the queried
-        parameters (``rename`` versions).  If ``nvidia-smi`` is not available,
-        returns ``None``.  Percentage information is left alone,
-        memory information is transformed to gigabytes (floating-point).
+    data
+        A dictionary containing the queried parameters.
+        If ``nvidia-smi`` is not available, returns ``None``.  Percentage
+        information is left alone, memory information is transformed to
+        gigabytes (floating-point).
     """
-    if _nvidia_smi is not None:
-        if rename is None:
-            rename = query
-        else:
-            assert len(rename) == len(query)
-
-        # Get GPU information based on GPU ID.
-        values = subprocess.getoutput(
-            "%s --query-gpu=%s --format=csv,noheader --id=%s"
-            % (
-                _nvidia_smi,
-                ",".join(query),
-                os.environ.get("CUDA_VISIBLE_DEVICES"),
-            )
-        )
-        values = [k.strip() for k in values.split(",")]
-        t_values = []
-        for k in values:
-            if k.endswith("%"):
-                t_values.append(float(k[:-1].strip()))
-            elif k.endswith("MiB"):
-                t_values.append(float(k[:-3].strip()) / 1024)
-            else:
-                t_values.append(k)  # unchanged
-        return tuple(zip(rename, t_values))
-
-
-def gpu_constants():
+    if _nvidia_smi is None:
+        return None
+
+    # Gets GPU information, based on a GPU device if that is set. Returns
+    # ordered results.
+    query_str = (
+        f"{_nvidia_smi} --query-gpu={','.join(query)} --format=csv,noheader"
+    )
+    visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
+    if visible_devices:
+        query_str += f" --id={visible_devices}"
+    values = subprocess.getoutput(query_str)
+
+    retval: dict[str, str | float] = {}
+    for key, value in zip(query, [k.strip() for k in values.split(",")]):
+        if value.endswith("%"):
+            retval[key] = float(value[:-1].strip())
+        elif value.endswith("MiB"):
+            retval[key] = float(value[:-3].strip()) / 1024
+        else:
+            retval[key] = value
+    return retval
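+
+# Illustration (hypothetical nvidia-smi output): querying
+# ("memory.used", "utilization.gpu") against an output line of
+# "1536 MiB, 37 %" would yield {"memory.used": 1.5, "utilization.gpu": 37.0}.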
+
+
+def gpu_constants() -> dict[str, str | int | float] | None:
     """Returns GPU (static) information using nvidia-smi.
 
     See :py:func:`run_nvidia_smi` for operational details.
@@ -90,21 +85,25 @@ def gpu_constants():
 
-    data : :py:class:`tuple`, None
+    data
         If ``nvidia-smi`` is not available, returns ``None``, otherwise, we
-        return an ordered dictionary (organized as 2-tuples) containing the
-        following ``nvidia-smi`` query information:
+        return a dictionary containing the following ``nvidia-smi`` query
+        information, in this order:
 
         * ``gpu_name``, as ``gpu_name`` (:py:class:`str`)
         * ``driver_version``, as ``gpu_driver_version`` (:py:class:`str`)
-        * ``memory.total``, as ``gpu_memory_total`` (transformed to gigabytes,
+        * ``memory.total``, as ``gpu_memory_total_GB`` (transformed to gigabytes,
           :py:class:`float`)
     """
-    return run_nvidia_smi(
-        ("gpu_name", "driver_version", "memory.total"),
-        ("gpu_name", "gpu_driver_version", "gpu_memory_total_GB"),
-    )
+    retval = run_nvidia_smi(("gpu_name", "driver_version", "memory.total"))
+    if retval is None:
+        return retval
+
+    # else, just update with more generic names
+    retval["gpu_driver_version"] = retval.pop("driver_version")
+    retval["gpu_memory_used_GB"] = retval.pop("memory.total")
+    return retval
 
 
-def gpu_log():
+def gpu_log() -> dict[str, float] | None:
     """Returns GPU information about current non-static status using nvidia-
     smi.
 
@@ -113,10 +112,10 @@ def gpu_log():
     Returns
     -------
 
-    data : :py:class:`tuple`, None
+    data
         If ``nvidia-smi`` is not available, returns ``None``, otherwise, we
-        return an ordered dictionary (organized as 2-tuples) containing the
-        following ``nvidia-smi`` query information:
+        return a dictionary containing the following ``nvidia-smi`` query
+        information, in this order:
 
-        * ``memory.used``, as ``gpu_memory_used`` (transformed to gigabytes,
+        * ``memory.used``, as ``gpu_memory_used_GB`` (transformed to gigabytes,
           :py:class:`float`)
@@ -127,47 +126,41 @@ def gpu_log():
         * ``utilization.gpu``, as ``gpu_percent``,
           (:py:class:`float`, in percent)
     """
-    retval = run_nvidia_smi(
-        (
-            "memory.total",
-            "memory.used",
-            "memory.free",
-            "utilization.gpu",
-        ),
-        (
-            "gpu_memory_total_GB",
-            "gpu_memory_used_GB",
-            "gpu_memory_free_percent",
-            "gpu_usage_percent",
-        ),
-    )
 
-    # re-compose the output to generate expected values
-    return (
-        retval[1],  # gpu_memory_used
-        retval[2],  # gpu_memory_free
-        ("gpu_memory_percent", 100 * (retval[1][1] / retval[0][1])),
-        retval[3],  # gpu_percent
+    result = run_nvidia_smi(
+        ("memory.total", "memory.used", "memory.free", "utilization.gpu")
     )
 
+    if result is None:
+        return result
 
-def cpu_constants():
+    return {
+        "gpu_memory_used_GB": float(result["memory.used"]),
+        "gpu_memory_free_GB": float(result["memory.free"]),
+        "gpu_memory_percent": 100
+        * float(result["memory.used"])
+        / float(result["memory.total"]),
+        "gpu_percent": float(result["utilization.gpu"]),
+    }
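+
+
+# gpu_log() then returns, e.g. (illustrative values only):
+# {"gpu_memory_used_GB": 1.5, "gpu_memory_free_GB": 10.5,
+#  "gpu_memory_percent": 12.5, "gpu_percent": 37.0}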
+
+
+def cpu_constants() -> dict[str, int | float]:
     """Returns static CPU information about the current system.
 
     Returns
     -------
 
-    data : tuple
+    data
-        An ordered dictionary (organized as 2-tuples) containing these entries:
+        A dictionary containing these entries:
 
         0. ``cpu_memory_total`` (:py:class:`float`): total memory available,
            in gigabytes
         1. ``cpu_count`` (:py:class:`int`): number of logical CPUs available
     """
-    return (
-        ("cpu_memory_total_GB", psutil.virtual_memory().total / GB),
-        ("cpu_count", psutil.cpu_count(logical=True)),
-    )
+    return {
+        "cpu_memory_total_GB": psutil.virtual_memory().total / GB,
+        "cpu_count": psutil.cpu_count(logical=True),
+    }
 
 
 class CPULogger:
@@ -176,24 +169,24 @@ class CPULogger:
     Parameters
     ----------
 
-    pid : :py:class:`int`, Optional
+    pid
         Process identifier of the main process (parent process) to observe
     """
 
-    def __init__(self, pid=None):
+    def __init__(self, pid: int | None = None):
         this = psutil.Process(pid=pid)
         self.cluster = [this] + this.children(recursive=True)
         # touch cpu_percent() at least once for all processes in the cluster
         [k.cpu_percent(interval=None) for k in self.cluster]
 
-    def log(self):
-        """Returns current process cluster information.
+    def log(self) -> dict[str, int | float]:
+        """Returns current process cluster iformation.
 
         Returns
         -------
 
-        data : tuple
-            An ordered dictionary (organized as 2-tuples) containing these entries:
+        data
+            An ordered dictionary containing these entries:
 
             0. ``cpu_memory_used`` (:py:class:`float`): total memory used from
                the system, in gigabytes
@@ -244,14 +237,14 @@ class CPULogger:
                 # it is too late to update any intermediate list
                 # at this point, but ensures to update counts later on
                 gone.add(k)
-        return (
-            ("cpu_memory_used_GB", psutil.virtual_memory().used / GB),
-            ("cpu_rss_GB", sum([k.rss for k in memory_info]) / GB),
-            ("cpu_vms_GB", sum([k.vms for k in memory_info]) / GB),
-            ("cpu_percent", sum(cpu_percent)),
-            ("cpu_processes", len(self.cluster) - len(gone)),
-            ("cpu_open_files", sum(open_files)),
-        )
+        return {
+            "cpu_memory_used_GB": psutil.virtual_memory().used / GB,
+            "cpu_rss_GB": sum([k.rss for k in memory_info]) / GB,
+            "cpu_vms_GB": sum([k.vms for k in memory_info]) / GB,
+            "cpu_percent": sum(cpu_percent),
+            "cpu_processes": len(self.cluster) - len(gone),
+            "cpu_open_files": sum(open_files),
+        }
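+
+
+# Usage sketch: CPULogger().log() might return (illustrative values only):
+# {"cpu_memory_used_GB": 4.2, "cpu_rss_GB": 1.1, "cpu_vms_GB": 2.3,
+#  "cpu_percent": 23.5, "cpu_processes": 2, "cpu_open_files": 16}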
 
 
 class _InformationGatherer:
@@ -260,73 +253,85 @@ class _InformationGatherer:
     Parameters
     ----------
 
-    has_gpu : bool
+    has_gpu
         A flag indicating if we have a GPU installed on the platform or not
 
-    main_pid : int
+    main_pid
         The main process identifier to monitor
 
-    logger : logging.Logger
+    logger
         A logger to be used for logging messages
     """
 
-    def __init__(self, has_gpu, main_pid, logger):
+    def __init__(
+        self, has_gpu: bool, main_pid: int | None, logger: logging.Logger
+    ):
+        self.logger: logging.Logger = logger
         self.cpu_logger = CPULogger(main_pid)
-        self.keys = [k[0] for k in self.cpu_logger.log()]
-        self.cpu_keys_len = len(self.keys)
-        self.has_gpu = has_gpu
-        self.logger = logger
+        keys: list[str] = list(self.cpu_logger.log().keys())
+        self.has_gpu: bool = has_gpu
         if self.has_gpu:
-            self.keys += [k[0] for k in gpu_log()]
-        self.data = [[] for _ in self.keys]
+            example = gpu_log()
+            if example is not None:
+                keys += list(example.keys())
+        self.data: dict[str, list[int | float]] = {k: [] for k in keys}
 
-    def acc(self):
+    def acc(self) -> None:
         """Accumulates another measurement."""
-        for i, k in enumerate(self.cpu_logger.log()):
-            self.data[i].append(k[1])
+        for k, v in self.cpu_logger.log().items():
+            self.data[k].append(v)
         if self.has_gpu:
-            for i, k in enumerate(gpu_log()):
-                self.data[i + self.cpu_keys_len].append(k[1])
+            sample = gpu_log()
+            if sample is not None:
+                for k, v in sample.items():
+                    self.data[k].append(v)
 
-    def clear(self):
+    def clear(self) -> None:
         """Clears accumulated data."""
-        self.data = [[] for _ in self.keys]
+        for k in self.data.keys():
+            self.data[k] = []
 
-    def summary(self):
+    def summary(self) -> dict[str, list[int | float]]:
         """Returns the current data."""
-        if len(self.data[0]) == 0:
+        if len(next(iter(self.data.values()))) == 0:
             self.logger.error("CPU/GPU logger was not able to collect any data")
-        retval = []
-        for k, values in zip(self.keys, self.data):
-            retval.append((k, values))
-        return tuple(retval)
+        return self.data
 
 
 def _monitor_worker(
-    interval, has_gpu, main_pid, stop, summary_event, queue, logging_level
+    interval: int | float,
+    has_gpu: bool,
+    main_pid: int,
+    stop: multiprocessing.synchronize.Event,
+    summary_event: multiprocessing.synchronize.Event,
+    queue: queue.Queue,
+    logging_level: int,
 ):
     """A monitoring worker that measures resources and returns lists.
 
     Parameters
     ==========
 
-    interval : int, float
+    interval
-        Number of seconds to wait between each measurement (maybe a floating
+        Number of seconds to wait between each measurement (may be a floating
         point number as accepted by :py:func:`time.sleep`)
 
-    has_gpu : bool
+    has_gpu
         A flag indicating if we have a GPU installed on the platform or not
 
-    main_pid : int
+    main_pid
         The main process identifier to monitor
 
-    stop : :py:class:`multiprocessing.Event`
-        Indicates if we should continue running or stop
+    stop
+        Event that indicates if we should continue running or stop
 
-    queue : :py:class:`queue.Queue`
+    summary_event
+        Event that indicates if we should produce a summary
+
+    queue
         A queue, to send monitoring information back to the spawner
 
-    logging_level: int
+    logging_level
         The logging level to use for logging from launched processes
     """
     logger = multiprocessing.log_to_stderr(level=logging_level)
@@ -343,9 +348,9 @@ def _monitor_worker(
 
             time.sleep(interval)
         except Exception:
-            logger.warning(
-                "Iterative CPU/GPU logging did not work properly " "this once",
-                exc_info=True,
+            logger.exception(
+                "Iterative CPU/GPU logging did not work properly."
+                " Exception follows.  Retrying..."
             )
             time.sleep(0.5)  # wait half a second, and try again!
 
@@ -356,27 +361,35 @@ class ResourceMonitor:
     Parameters
     ----------
 
-    interval : int, float
+    interval
         Number of seconds to wait between each measurement (maybe a floating
         point number as accepted by :py:func:`time.sleep`)
 
-    has_gpu : bool
+    has_gpu
         A flag indicating if we have a GPU installed on the platform or not
 
-    main_pid : int
+    main_pid
         The main process identifier to monitor
 
-    logging_level: int
+    logging_level
         The logging level to use for logging from launched processes
     """
 
-    def __init__(self, interval, has_gpu, main_pid, logging_level):
+    def __init__(
+        self,
+        interval: int | float,
+        has_gpu: bool,
+        main_pid: int,
+        logging_level: int,
+    ):
         self.interval = interval
         self.has_gpu = has_gpu
         self.main_pid = main_pid
         self.stop_event = multiprocessing.Event()
         self.summary_event = multiprocessing.Event()
-        self.q = multiprocessing.Queue()
+        self.q: multiprocessing.Queue[
+            dict[str, list[int | float]]
+        ] = multiprocessing.Queue()
         self.logging_level = logging_level
 
         self.monitor = multiprocessing.Process(
@@ -393,23 +406,23 @@ class ResourceMonitor:
             ),
         )
 
-        self.data = None
-
-    @staticmethod
-    def monitored_keys(has_gpu):
-        return _InformationGatherer(has_gpu, None, logger).keys
+        self.data: dict[str, int | float] | None = None
 
-    def __enter__(self):
+    def __enter__(self) -> None:
         """Starts the monitoring process."""
         self.monitor.start()
 
-    def trigger_summary(self):
+    def checkpoint(self) -> None:
+        """Forces the monitoring process to yield data and clear the internal
+        accumlator."""
         self.summary_event.set()
 
         try:
-            data = self.q.get(timeout=2 * self.interval)
+            data: dict[str, list[int | float]] = self.q.get(
+                timeout=2 * self.interval
+            )
         except queue.Empty:
-            logger.warn(
+            logger.warning(
                 f"CPU/GPU resource monitor did not provide anything when "
                 f"joined (even after a {2*self.interval}-second timeout - "
                 f"this is normally due to exceptions on the monitoring process. "
@@ -417,19 +430,18 @@ class ResourceMonitor:
             )
             self.data = None
         else:
-            # summarize the returned data by creating means
-            summary = []
-            for k, values in data:
+            # summarize the returned data by creating averages
+            self.data = {}
+            for k, values in data.items():
                 if values:
                     if k in ("cpu_processes", "cpu_open_files"):
-                        summary.append((k, numpy.max(values)))
+                        self.data[k] = int(numpy.max(values))
                     else:
-                        summary.append((k, numpy.mean(values)))
+                        self.data[k] = float(numpy.mean(values))
                 else:
-                    summary.append((k, 0.0))
-            self.data = tuple(summary)
+                    self.data[k] = 0.0
 
-    def __exit__(self, *exc):
+    def __exit__(self, *_) -> None:
         """Stops the monitoring process and returns the summary of
         observations."""