diff --git a/src/ptbench/configs/datasets/__init__.py b/src/ptbench/configs/datasets/__init__.py
deleted file mode 100644
index 88ac7a33e7051a69ff6d8ecfdb55aeb94b13b72f..0000000000000000000000000000000000000000
--- a/src/ptbench/configs/datasets/__init__.py
+++ /dev/null
@@ -1,158 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-import torch
-
-from torchvision.transforms import RandomRotation
-
-"""Standard configurations for dataset setup"""
-
-RANDOM_ROTATION = [RandomRotation(15)]
-"""Shared data augmentation based on random rotation only."""
-
-
-def get_samples_weights(dataset):
-    """Compute the weights of all the samples of the dataset to balance it
-    using the sampler of the dataloader.
-
-    This function takes as input a :py:class:`torch.utils.data.dataset.Dataset`
-    and computes the weights to balance each class in the dataset and the
-    datasets themselves if we have a ConcatDataset.
-
-
-    Parameters
-    ----------
-
-    dataset : torch.utils.data.dataset.Dataset
-        An instance of torch.utils.data.dataset.Dataset
-        ConcatDataset are supported
-
-
-    Returns
-    -------
-
-    samples_weights : :py:class:`torch.Tensor`
-        the weights for all the samples in the dataset given as input
-    """
-    samples_weights = []
-
-    if isinstance(dataset, torch.utils.data.ConcatDataset):
-        for ds in dataset.datasets:
-            # Weighting only for binary labels
-            if isinstance(ds._samples[0].label, int):
-                # Groundtruth
-                targets = []
-                for s in ds._samples:
-                    targets.append(s.label)
-                targets = torch.tensor(targets)
-
-                # Count number of samples per class
-                class_sample_count = torch.tensor(
-                    [
-                        (targets == t).sum()
-                        for t in torch.unique(targets, sorted=True)
-                    ]
-                )
-
-                weight = 1.0 / class_sample_count.float()
-
-                samples_weights.append(
-                    torch.tensor([weight[t] for t in targets])
-                )
-
-            else:
-                # We only weight to sample equally from each dataset
-                samples_weights.append(torch.full((len(ds),), 1.0 / len(ds)))
-
-        # Concatenate sample weights from all the datasets
-        samples_weights = torch.cat(samples_weights)
-
-    else:
-        # Weighting only for binary labels
-        if isinstance(dataset._samples[0].label, int):
-            # Groundtruth
-            targets = []
-            for s in dataset._samples:
-                targets.append(s.label)
-            targets = torch.tensor(targets)
-
-            # Count number of samples per class
-            class_sample_count = torch.tensor(
-                [
-                    (targets == t).sum()
-                    for t in torch.unique(targets, sorted=True)
-                ]
-            )
-
-            weight = 1.0 / class_sample_count.float()
-
-            samples_weights = torch.tensor([weight[t] for t in targets])
-
-        else:
-            # Equal weights for non-binary labels
-            samples_weights = torch.ones(len(dataset._samples))
-
-    return samples_weights
-
-
-def get_positive_weights(dataset):
-    """Compute the positive weights of each class of the dataset to balance the
-    BCEWithLogitsLoss criterion.
-
-    This function takes as input a :py:class:`torch.utils.data.dataset.Dataset`
-    and computes the positive weights of each class to use them to have
-    a balanced loss.
-
-
-    Parameters
-    ----------
-
-    dataset : torch.utils.data.dataset.Dataset
-        An instance of torch.utils.data.dataset.Dataset
-        ConcatDataset are supported
-
-
-    Returns
-    -------
-
-    positive_weights : :py:class:`torch.Tensor`
-        the positive weight of each class in the dataset given as input
-    """
-    targets = []
-    if isinstance(dataset, torch.utils.data.ConcatDataset):
-        for ds in dataset.datasets:
-            for s in ds._samples:
-                targets.append(s.label)
-
-    else:
-        for s in dataset._samples:
-            targets.append(s.label)
-
-    targets = torch.tensor(targets)
-
-    # Binary labels
-    if len(list(targets.shape)) == 1:
-        class_sample_count = [
-            float((targets == t).sum().item())
-            for t in torch.unique(targets, sorted=True)
-        ]
-
-        # Divide negatives by positives
-        positive_weights = torch.tensor(
-            [class_sample_count[0] / class_sample_count[1]]
-        ).reshape(-1)
-
-    # Multiclass labels
-    else:
-        class_sample_count = torch.sum(targets, dim=0)
-        negative_class_sample_count = (
-            torch.full((targets.size()[1],), float(targets.size()[0]))
-            - class_sample_count
-        )
-
-        positive_weights = negative_class_sample_count / (
-            class_sample_count + negative_class_sample_count
-        )
-
-    return positive_weights
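The sample-weighting helper removed here survives in `ptbench.data.dataset` (updated further below to work on dictionary samples). A toy sketch of what it computes, feeding the `WeightedRandomSampler` used by the datamodule; `_ToyDataset` is illustrative only and not part of this patch:

```python
from ptbench.data.dataset import get_samples_weights


class _ToyDataset:
    """Illustration only: mimics the ``_samples`` layout the helper expects."""

    def __init__(self, labels):
        self._samples = [
            {"data": f"img_{i}.png", "label": label}
            for i, label in enumerate(labels)
        ]


# 3 negatives and 1 positive: each negative gets weight 1/3, the positive 1/1,
# so a WeightedRandomSampler draws both classes with equal probability
dataset = _ToyDataset([0, 0, 0, 1])
print(get_samples_weights(dataset))  # tensor([0.3333, 0.3333, 0.3333, 1.0000])
```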
diff --git a/src/ptbench/configs/datasets/shenzhen/__init__.py b/src/ptbench/configs/datasets/shenzhen/__init__.py
deleted file mode 100644
index 84b9088ea60cbbf9ddee2fdf1bfc14203beda01f..0000000000000000000000000000000000000000
--- a/src/ptbench/configs/datasets/shenzhen/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
diff --git a/src/ptbench/configs/datasets/shenzhen/default.py b/src/ptbench/configs/datasets/shenzhen/default.py
deleted file mode 100644
index 09aa54748d382122fd1f67a3b3e1871f3f0aa132..0000000000000000000000000000000000000000
--- a/src/ptbench/configs/datasets/shenzhen/default.py
+++ /dev/null
@@ -1,79 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-"""Shenzhen dataset for TB detection (default protocol)
-
-* Split reference: first 64% of TB and healthy CXR for "train" 16% for
-* "validation", 20% for "test"
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.shenzhen` for dataset details
-"""
-
-from clapper.logging import setup
-
-from ....data import return_subsets
-from ....data.base_datamodule import BaseDataModule
-from ....data.dataset import JSONProtocol
-from ....data.shenzhen import _cached_loader, _delayed_loader, _protocols
-
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        cache_samples=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-        self.cache_samples = cache_samples
-        self.has_setup_fit = False
-        self.has_setup_predict = False
-
-    def setup(self, stage: str):
-        if self.cache_samples:
-            logger.info(
-                "Argument cache_samples set to True. Samples will be loaded in memory."
-            )
-            samples_loader = _cached_loader
-        else:
-            logger.info(
-                "Argument cache_samples set to False. Samples will be loaded at runtime."
-            )
-            samples_loader = _delayed_loader
-
-        json_protocol = JSONProtocol(
-            protocols=_protocols,
-            fieldnames=("data", "label"),
-            loader=samples_loader,
-        )
-
-        if not self.has_setup_fit and stage == "fit":
-            (
-                self.train_dataset,
-                self.validation_dataset,
-                self.extra_validation_datasets,
-            ) = return_subsets(json_protocol, "default", stage)
-            self.has_setup_fit = True
-
-        if not self.has_setup_predict and stage == "predict":
-            (
-                self.train_dataset,
-                self.validation_dataset,
-                self.extra_validation_datasets,
-            ) = return_subsets(json_protocol, "default", stage)
-
-            self.has_setup_predict = True
-
-
-datamodule = DefaultModule()
diff --git a/src/ptbench/configs/models/pasa.py b/src/ptbench/configs/models/pasa.py
index cda0540fe0d7ea01c67df571cbeeecaa8199c53f..3ee0b92164b5531b65049b94e71b01b07e2ad27e 100644
--- a/src/ptbench/configs/models/pasa.py
+++ b/src/ptbench/configs/models/pasa.py
@@ -13,9 +13,7 @@ Reference: [PASA-2019]_
 
 from torch import empty
 from torch.nn import BCEWithLogitsLoss
-from torchvision import transforms
 
-from ...data.transforms import ElasticDeformation
 from ...models.pasa import PASA
 
 # config
@@ -28,9 +26,5 @@ optimizer = "Adam"
 criterion = BCEWithLogitsLoss(pos_weight=empty(1))
 criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
 
-train_transforms = transforms.Compose([ElasticDeformation(p=0.8)])
-
 # model
-model = PASA(
-    train_transforms, criterion, criterion_valid, optimizer, optimizer_configs
-)
+model = PASA(criterion, criterion_valid, optimizer, optimizer_configs)
diff --git a/src/ptbench/data/__init__.py b/src/ptbench/data/__init__.py
index 682d5d1ad8e88ae4f8dd72dccb888200bc3d161a..84b9088ea60cbbf9ddee2fdf1bfc14203beda01f 100644
--- a/src/ptbench/data/__init__.py
+++ b/src/ptbench/data/__init__.py
@@ -1,356 +1,3 @@
-"""Data manipulation and raw dataset definitions."""
-
-import random
-
-import torch
-
-from clapper.logging import setup
-
-from .utils import SampleListDataset
-
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-def make_subset(samples, transforms=[], prefixes=[], suffixes=[]):
-    """Creates a new data set, applying transforms.
-
-    .. note::
-
-       This is a convenience function for our own dataset definitions inside
-       this module, guaranteeting homogenity between dataset definitions
-       provided in this package.  It assumes certain strategies for data
-       augmentation that may not be translatable to other applications.
-
-
-    Parameters
-    ----------
-
-    samples : list
-        List of delayed samples
-
-    transforms : list
-        A list of transforms that needs to be applied to all samples in the set
-
-    prefixes : list
-        A list of data augmentation operations that needs to be applied
-        **before** the transforms above
-
-    suffixes : list
-        A list of data augmentation operations that needs to be applied
-        **after** the transforms above
-
-
-    Returns
-    -------
-
-    subset : :py:class:`ptbench.data.utils.SampleListDataset`
-        A pre-formatted dataset that can be fed to one of our engines
-    """
-    from .utils import SampleListDataset as wrapper
-
-    return wrapper(samples, prefixes + transforms + suffixes)
-
-
-def make_dataset(
-    subsets_groups, transforms=[], t_transforms=[], post_transforms=[]
-):
-    """Creates a new configuration dataset from a list of dictionaries and
-    transforms.
-
-    This function takes as input a list of dictionaries as those that can be
-    returned by :py:meth:`ptbench.data.dataset.JSONDataset.subsets`
-    mapping protocol names (such as ``train``, ``dev`` and ``test``) to
-    :py:class:`ptbench.data.sample.DelayedSample` lists, and a set of
-    transforms, and returns a dictionary applying
-    :py:class:`ptbench.data.utils.SampleListDataset` to these
-    lists, and our standard data augmentation if a ``train`` set exists.
-
-    For example, if ``subsets`` is composed of two sets named ``train`` and
-    ``test``, this function will yield a dictionary with the following entries:
-
-    * ``__train__``: Wraps the ``train`` subset, includes data augmentation
-      (note: datasets with names starting with ``_`` (underscore) are excluded
-      from prediction and evaluation by default, as they contain data
-      augmentation transformations.)
-    * ``train``: Wraps the ``train`` subset, **without** data augmentation
-    * ``test``: Wraps the ``test`` subset, **without** data augmentation
-
-    .. note::
-
-       This is a convenience function for our own dataset definitions inside
-       this module, guaranteeting homogenity between dataset definitions
-       provided in this package.  It assumes certain strategies for data
-       augmentation that may not be translatable to other applications.
-
-
-    Parameters
-    ----------
-
-    subsets : list
-        A list of dictionaries that contains the delayed sample lists
-        for a number of named lists. The subsets will be aggregated in one
-        final subset. If one of the keys is ``train``, our standard dataset
-        augmentation transforms are appended to the definition of that subset.
-        All other subsets remain un-augmented.
-
-    transforms : list
-        A list of transforms that needs to be applied to all samples in the set
-
-    t_transforms : list
-        A list of transforms that needs to be applied to the train samples
-
-    post_transforms : list
-        A list of transforms that needs to be applied to all samples in the set
-        after all the other transforms
-
-
-    Returns
-    -------
-
-    dataset : dict
-        A pre-formatted dataset that can be fed to one of our engines. It maps
-        string names to :py:class:`ptbench.data.utils.SampleListDataset`'s.
-    """
-
-    retval = {}
-
-    if len(subsets_groups) == 1:
-        subsets = subsets_groups[0]
-    else:
-        # If multiple subsets groups: aggregation
-        aggregated_subsets = {}
-        for subsets in subsets_groups:
-            for key in subsets.keys():
-                if key in aggregated_subsets:
-                    aggregated_subsets[key] += subsets[key]
-                    # Shuffle if data comes from multiple datasets
-                    random.shuffle(aggregated_subsets[key])
-                else:
-                    aggregated_subsets[key] = subsets[key]
-        subsets = aggregated_subsets
-
-    # Add post_transforms after t_transforms for the train set
-    t_transforms += post_transforms
-
-    for key in subsets.keys():
-        retval[key] = make_subset(
-            subsets[key], transforms=transforms, suffixes=post_transforms
-        )
-        if key == "train":
-            retval["__train__"] = make_subset(
-                subsets[key], transforms=transforms, suffixes=(t_transforms)
-            )
-        if key == "validation":
-            # also use it for validation during training
-            retval["__valid__"] = retval[key]
-
-    if (
-        ("__train__" in retval)
-        and ("train" in retval)
-        and ("__valid__" not in retval)
-    ):
-        # if the dataset does not have a validation set, we use the unaugmented
-        # training set as validation set
-        retval["__valid__"] = retval["train"]
-
-    return retval
-
-
-def get_samples_weights(dataset):
-    """Compute the weights of all the samples of the dataset to balance it
-    using the sampler of the dataloader.
-
-    This function takes as input a :py:class:`torch.utils.data.dataset.Dataset`
-    and computes the weights to balance each class in the dataset and the
-    datasets themselves if we have a ConcatDataset.
-
-
-    Parameters
-    ----------
-
-    dataset : torch.utils.data.dataset.Dataset
-        An instance of torch.utils.data.dataset.Dataset
-        ConcatDataset are supported
-
-
-    Returns
-    -------
-
-    samples_weights : :py:class:`torch.Tensor`
-        the weights for all the samples in the dataset given as input
-    """
-    samples_weights = []
-
-    if isinstance(dataset, torch.utils.data.ConcatDataset):
-        for ds in dataset.datasets:
-            # Weighting only for binary labels
-            if isinstance(ds._samples[0].label, int):
-                # Groundtruth
-                targets = []
-                for s in ds._samples:
-                    targets.append(s.label)
-                targets = torch.tensor(targets)
-
-                # Count number of samples per class
-                class_sample_count = torch.tensor(
-                    [
-                        (targets == t).sum()
-                        for t in torch.unique(targets, sorted=True)
-                    ]
-                )
-
-                weight = 1.0 / class_sample_count.float()
-
-                samples_weights.append(
-                    torch.tensor([weight[t] for t in targets])
-                )
-
-            else:
-                # We only weight to sample equally from each dataset
-                samples_weights.append(torch.full((len(ds),), 1.0 / len(ds)))
-
-        # Concatenate sample weights from all the datasets
-        samples_weights = torch.cat(samples_weights)
-
-    else:
-        # Weighting only for binary labels
-        if isinstance(dataset._samples[0].label, int):
-            # Groundtruth
-            targets = []
-            for s in dataset._samples:
-                targets.append(s.label)
-            targets = torch.tensor(targets)
-
-            # Count number of samples per class
-            class_sample_count = torch.tensor(
-                [
-                    (targets == t).sum()
-                    for t in torch.unique(targets, sorted=True)
-                ]
-            )
-
-            weight = 1.0 / class_sample_count.float()
-
-            samples_weights = torch.tensor([weight[t] for t in targets])
-
-        else:
-            # Equal weights for non-binary labels
-            samples_weights = torch.ones(len(dataset._samples))
-
-    return samples_weights
-
-
-def get_positive_weights(dataset):
-    """Compute the positive weights of each class of the dataset to balance the
-    BCEWithLogitsLoss criterion.
-
-    This function takes as input a :py:class:`torch.utils.data.dataset.Dataset`
-    and computes the positive weights of each class to use them to have
-    a balanced loss.
-
-
-    Parameters
-    ----------
-
-    dataset : torch.utils.data.dataset.Dataset
-        An instance of torch.utils.data.dataset.Dataset
-        ConcatDataset are supported
-
-
-    Returns
-    -------
-
-    positive_weights : :py:class:`torch.Tensor`
-        the positive weight of each class in the dataset given as input
-    """
-    targets = []
-
-    if isinstance(dataset, torch.utils.data.ConcatDataset):
-        for ds in dataset.datasets:
-            for s in ds._samples:
-                targets.append(s.label)
-
-    else:
-        for s in dataset._samples:
-            targets.append(s.label)
-
-    targets = torch.tensor(targets)
-
-    # Binary labels
-    if len(list(targets.shape)) == 1:
-        class_sample_count = [
-            float((targets == t).sum().item())
-            for t in torch.unique(targets, sorted=True)
-        ]
-
-        # Divide negatives by positives
-        positive_weights = torch.tensor(
-            [class_sample_count[0] / class_sample_count[1]]
-        ).reshape(-1)
-
-    # Multiclass labels
-    else:
-        class_sample_count = torch.sum(targets, dim=0)
-        negative_class_sample_count = (
-            torch.full((targets.size()[1],), float(targets.size()[0]))
-            - class_sample_count
-        )
-
-        positive_weights = negative_class_sample_count / (
-            class_sample_count + negative_class_sample_count
-        )
-
-    return positive_weights
-
-
-def return_subsets(dataset, protocol, stage):
-    train_set = None
-    valid_set = None
-    extra_valid_sets = None
-
-    subsets = dataset.subsets(protocol)
-
-    def get_train_subset():
-        if "train" in subsets.keys():
-            nonlocal train_set
-            train_set = SampleListDataset(subsets["train"], [])
-
-    def get_valid_subset():
-        if "validation" in subsets.keys():
-            nonlocal valid_set
-            valid_set = SampleListDataset(subsets["validation"], [])
-        else:
-            logger.warning(
-                "No validation dataset found, using training set instead."
-            )
-            if train_set is None:
-                get_train_subset()
-
-            valid_set = train_set
-
-    def get_extra_valid_subset():
-        if "__extra_valid__" in subsets.keys():
-            if not isinstance(subsets["__extra_valid__"], list):
-                raise RuntimeError(
-                    f"If present, dataset['__extra_valid__'] must be a list, "
-                    f"but you passed a {type(subsets['__extra_valid__'])}, "
-                    f"which is invalid."
-                )
-            logger.info(
-                f"Found {len(subsets['__extra_valid__'])} extra validation "
-                f"set(s) to be tracked during training"
-            )
-            logger.info(
-                "Extra validation sets are NOT used for model checkpointing!"
-            )
-            nonlocal extra_valid_sets
-            extra_valid_sets = SampleListDataset(subsets["__extra_valid__"], [])
-
-    if stage == "fit":
-        get_train_subset()
-        get_valid_subset()
-        get_extra_valid_subset()
-
-        return train_set, valid_set, extra_valid_sets
-    else:
-        raise ValueError(f"Stage {stage} is unknown.")
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
diff --git a/src/ptbench/data/base_datamodule.py b/src/ptbench/data/base_datamodule.py
index 1c51a1055dc226efc72292a3e439041c67a6d37b..5cfa1620d14a186ec034821b95c3d1a2a0fdbb8b 100644
--- a/src/ptbench/data/base_datamodule.py
+++ b/src/ptbench/data/base_datamodule.py
@@ -2,13 +2,17 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
+import multiprocessing
+import sys
+
 import lightning as pl
 import torch
 
 from clapper.logging import setup
 from torch.utils.data import DataLoader, WeightedRandomSampler
+from torchvision import transforms
 
-from ptbench.configs.datasets import get_samples_weights
+from .dataset import get_samples_weights
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
@@ -16,24 +20,30 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 class BaseDataModule(pl.LightningDataModule):
     def __init__(
         self,
-        train_batch_size=1,
-        predict_batch_size=1,
+        batch_size=1,
         batch_chunk_count=1,
         drop_incomplete_batch=False,
-        multiproc_kwargs={},
+        parallel=-1,
     ):
         super().__init__()
 
-        self.train_batch_size = train_batch_size
-        self.predict_batch_size = predict_batch_size
+        self.batch_size = batch_size
         self.batch_chunk_count = batch_chunk_count
 
+        self.train_dataset = None
+        self.validation_dataset = None
+        self.extra_validation_datasets = None
+
+        self._raw_data_transforms = []
+        self._model_transforms = []
+        self._augmentation_transforms = []
+
         self.drop_incomplete_batch = drop_incomplete_batch
         self.pin_memory = (
             torch.cuda.is_available()
         )  # should only be true if GPU available and using it
 
-        self.multiproc_kwargs = multiproc_kwargs
+        self.parallel = parallel
 
     def setup(self, stage: str):
         # Implemented by user
@@ -43,29 +53,32 @@ class BaseDataModule(pl.LightningDataModule):
     def train_dataloader(self):
         train_samples_weights = get_samples_weights(self.train_dataset)
 
+        multiproc_kwargs = self._setup_multiproc(self.parallel)
         train_sampler = WeightedRandomSampler(
             train_samples_weights, len(train_samples_weights), replacement=True
         )
 
         return DataLoader(
             self.train_dataset,
-            batch_size=self.compute_chunk_size(self.train_batch_size),
+            batch_size=self._compute_chunk_size(self.batch_size),
             drop_last=self.drop_incomplete_batch,
             pin_memory=self.pin_memory,
             sampler=train_sampler,
-            **self.multiproc_kwargs,
+            **multiproc_kwargs,
         )
 
     def val_dataloader(self):
         loaders_dict = {}
 
+        multiproc_kwargs = self._setup_multiproc(self.parallel)
+
         val_loader = DataLoader(
             dataset=self.validation_dataset,
-            batch_size=self.compute_chunk_size(self.train_batch_size),
+            batch_size=self._compute_chunk_size(self.batch_size),
             shuffle=False,
             drop_last=False,
             pin_memory=self.pin_memory,
-            **self.multiproc_kwargs,
+            **multiproc_kwargs,
         )
 
         loaders_dict["validation_loader"] = val_loader
@@ -74,11 +87,11 @@ class BaseDataModule(pl.LightningDataModule):
             for set_idx, extra_set in enumerate(self.extra_validation_datasets):
                 extra_val_loader = DataLoader(
                     dataset=extra_set,
-                    batch_size=self.train_batch_size,
+                    batch_size=self._compute_chunk_size(self.batch_size),
                     shuffle=False,
                     drop_last=False,
                     pin_memory=self.pin_memory,
-                    **self.multiproc_kwargs,
+                    **multiproc_kwargs,
                 )
 
                 loaders_dict[
@@ -96,7 +109,7 @@ class BaseDataModule(pl.LightningDataModule):
 
         return loaders_dict
 
-    def compute_chunk_size(self, batch_size):
+    def _compute_chunk_size(self, batch_size):
         batch_chunk_size = batch_size
         if batch_size % self.batch_chunk_count != 0:
             # batch_size must be divisible by batch_chunk_count.
@@ -109,6 +122,56 @@ class BaseDataModule(pl.LightningDataModule):
 
         return batch_chunk_size
 
+    def _setup_multiproc(self, parallel):
+        multiproc_kwargs = dict()
+        if parallel < 0:
+            multiproc_kwargs["num_workers"] = 0
+        else:
+            multiproc_kwargs["num_workers"] = (
+                parallel or multiprocessing.cpu_count()
+            )
+
+        if multiproc_kwargs["num_workers"] > 0 and sys.platform == "darwin":
+            multiproc_kwargs[
+                "multiprocessing_context"
+            ] = multiprocessing.get_context("spawn")
+
+        return multiproc_kwargs
+
+    def _build_transforms(self, is_train):
+        all_transforms = self.raw_data_transforms + self.model_transforms
+
+        if is_train:
+            all_transforms = all_transforms + self.augmentation_transforms
+
+        # all_transforms.append(transforms.ToTensor())
+
+        return transforms.Compose(all_transforms)
+
+    @property
+    def raw_data_transforms(self):
+        return self._raw_data_transforms
+
+    @raw_data_transforms.setter
+    def raw_data_transforms(self, transforms):
+        self._raw_data_transforms = transforms
+
+    @property
+    def model_transforms(self):
+        return self._model_transforms
+
+    @model_transforms.setter
+    def model_transforms(self, transforms):
+        self._model_transforms = transforms
+
+    @property
+    def augmentation_transforms(self):
+        return self._augmentation_transforms
+
+    @augmentation_transforms.setter
+    def augmentation_transforms(self, transforms):
+        self._augmentation_transforms = transforms
+
 
 def get_dataset_from_module(module, stage, **module_args):
     """Instantiates a DataModule and retrieves the corresponding dataset.
diff --git a/src/ptbench/data/dataset.py b/src/ptbench/data/dataset.py
index b568e78a225dc515c5e0b53d7d6b9a5b64f95cd1..fad353a13cfd2b38279e94193a565dc5abbb337f 100644
--- a/src/ptbench/data/dataset.py
+++ b/src/ptbench/data/dataset.py
@@ -7,16 +7,11 @@ import json
 import logging
 import os
 import pathlib
-import random
 
 import torch
 
-from torchvision.transforms import RandomRotation
 from tqdm import tqdm
 
-RANDOM_ROTATION = [RandomRotation(15)]
-"""Shared data augmentation based on random rotation only."""
-
 logger = logging.getLogger(__name__)
 
 
@@ -75,7 +70,7 @@ class JSONProtocol:
         * ``data``: which contains the data associated with this sample
     """
 
-    def __init__(self, protocols, fieldnames, loader, post_transforms=[]):
+    def __init__(self, protocols, fieldnames):
         if isinstance(protocols, dict):
             self._protocols = protocols
         else:
@@ -86,8 +81,6 @@ class JSONProtocol:
                 for k in protocols
             }
         self.fieldnames = fieldnames
-        self._loader = loader
-        self.post_transforms = post_transforms
 
     def check(self, limit=0):
         """For each protocol, check if all data can be correctly accessed.
@@ -174,11 +167,7 @@ class JSONProtocol:
             logger.info(f"Loading subset {subset} samples.")
 
             retval[subset] = [
-                self._loader(
-                    dict(protocol=protocol, subset=subset, order=n),
-                    dict(zip(self.fieldnames, k)),
-                    self.post_transforms,
-                )
+                dict(zip(self.fieldnames, k))
                 for n, k in enumerate(tqdm(samples))
             ]
 
@@ -328,149 +317,98 @@ class CSVDataset:
         ]
 
 
-def make_subset(samples, transforms=[], prefixes=[], suffixes=[]):
-    """Creates a new data set, applying transforms.
-
-    .. note::
-
-       This is a convenience function for our own dataset definitions inside
-       this module, guaranteeting homogenity between dataset definitions
-       provided in this package.  It assumes certain strategies for data
-       augmentation that may not be translatable to other applications.
-
-
-    Parameters
-    ----------
-
-    samples : list
-        List of delayed samples
-
-    transforms : list
-        A list of transforms that needs to be applied to all samples in the set
-
-    prefixes : list
-        A list of data augmentation operations that needs to be applied
-        **before** the transforms above
-
-    suffixes : list
-        A list of data augmentation operations that needs to be applied
-        **after** the transforms above
-
-
-    Returns
-    -------
-
-    subset : :py:class:`ptbench.data.utils.SampleListDataset`
-        A pre-formatted dataset that can be fed to one of our engines
-    """
-    from .utils import SampleListDataset as wrapper
-
-    return wrapper(samples, prefixes + transforms + suffixes)
-
-
-def make_dataset(
-    subsets_groups, transforms=[], t_transforms=[], post_transforms=[]
-):
-    """Creates a new configuration dataset from a list of dictionaries and
-    transforms.
-
-    This function takes as input a list of dictionaries as those that can be
-    returned by :py:meth:`ptbench.data.dataset.JSONDataset.subsets`
-    mapping protocol names (such as ``train``, ``dev`` and ``test``) to
-    :py:class:`ptbench.data.sample.DelayedSample` lists, and a set of
-    transforms, and returns a dictionary applying
-    :py:class:`ptbench.data.utils.SampleListDataset` to these
-    lists, and our standard data augmentation if a ``train`` set exists.
-
-    For example, if ``subsets`` is composed of two sets named ``train`` and
-    ``test``, this function will yield a dictionary with the following entries:
-
-    * ``__train__``: Wraps the ``train`` subset, includes data augmentation
-      (note: datasets with names starting with ``_`` (underscore) are excluded
-      from prediction and evaluation by default, as they contain data
-      augmentation transformations.)
-    * ``train``: Wraps the ``train`` subset, **without** data augmentation
-    * ``test``: Wraps the ``test`` subset, **without** data augmentation
-
-    .. note::
-
-       This is a convenience function for our own dataset definitions inside
-       this module, guaranteeting homogenity between dataset definitions
-       provided in this package.  It assumes certain strategies for data
-       augmentation that may not be translatable to other applications.
-
-
-    Parameters
-    ----------
+class CachedDataset(torch.utils.data.Dataset):
+    def __init__(
+        self, json_protocol, protocol, subset, raw_data_loader, transforms
+    ):
+        self.json_protocol = json_protocol
+        self.subset = subset
+        self.raw_data_loader = raw_data_loader
+        self.transforms = transforms
+
+        self._samples = json_protocol.subsets(protocol)[self.subset]
+        # Dict entry with relative path to files, used during prediction
+        for s in self._samples:
+            s["name"] = s["data"]
+
+        logger.info(f"Caching {self.subset} samples")
+        for sample in tqdm(self._samples):
+            sample["data"] = self.transforms(
+                self.raw_data_loader(sample["data"])
+            )
 
-    subsets : list
-        A list of dictionaries that contains the delayed sample lists
-        for a number of named lists. The subsets will be aggregated in one
-        final subset. If one of the keys is ``train``, our standard dataset
-        augmentation transforms are appended to the definition of that subset.
-        All other subsets remain un-augmented.
+    def __getitem__(self, idx):
+        return self._samples[idx]
 
-    transforms : list
-        A list of transforms that needs to be applied to all samples in the set
+    def __len__(self):
+        return len(self._samples)
 
-    t_transforms : list
-        A list of transforms that needs to be applied to the train samples
 
-    post_transforms : list
-        A list of transforms that needs to be applied to all samples in the set
-        after all the other transforms
+class RuntimeDataset(torch.utils.data.Dataset):
+    def __init__(
+        self, json_protocol, protocol, subset, raw_data_loader, transforms
+    ):
+        self.json_protocol = json_protocol
+        self.subset = subset
+        self.raw_data_loader = raw_data_loader
+        self.transforms = transforms
+
+        self._samples = json_protocol.subsets(protocol)[self.subset]
+        # Dict entry with relative path to files
+        for s in self._samples:
+            s["name"] = s["data"]
+
+    def __getitem__(self, idx):
+        sample = self._samples[idx].copy()
+        sample["data"] = self.transforms(self.raw_data_loader(sample["data"]))
+        return sample
+
+    def __len__(self):
+        return len(self._samples)
+
+
+class TBDataset(torch.utils.data.Dataset):
+    def __init__(
+        self,
+        json_protocol,
+        protocol,
+        subset,
+        raw_data_loader,
+        transforms,
+        cache_samples=False,
+    ):
+        self.json_protocol = json_protocol
+        self.subset = subset
+        self.raw_data_loader = raw_data_loader
+        self.transforms = transforms
 
+        self.cache_samples = cache_samples
 
-    Returns
-    -------
+        self._samples = json_protocol.subsets(protocol)[self.subset]
 
-    dataset : dict
-        A pre-formatted dataset that can be fed to one of our engines. It maps
-        string names to :py:class:`ptbench.data.utils.SampleListDataset`'s.
-    """
+        # Dict entry with relative path to files
+        for s in self._samples:
+            s["name"] = s["data"]
 
-    retval = {}
+        if self.cache_samples:
+            logger.info(f"Caching {self.subset} samples")
+            for sample in tqdm(self._samples):
+                sample["data"] = self.transforms(
+                    self.raw_data_loader(sample["data"])
+                )
 
-    if len(subsets_groups) == 1:
-        subsets = subsets_groups[0]
-    else:
-        # If multiple subsets groups: aggregation
-        aggregated_subsets = {}
-        for subsets in subsets_groups:
-            for key in subsets.keys():
-                if key in aggregated_subsets:
-                    aggregated_subsets[key] += subsets[key]
-                    # Shuffle if data comes from multiple datasets
-                    random.shuffle(aggregated_subsets[key])
-                else:
-                    aggregated_subsets[key] = subsets[key]
-        subsets = aggregated_subsets
-
-    # Add post_transforms after t_transforms for the train set
-    t_transforms += post_transforms
-
-    for key in subsets.keys():
-        retval[key] = make_subset(
-            subsets[key], transforms=transforms, suffixes=post_transforms
-        )
-        if key == "train":
-            retval["__train__"] = make_subset(
-                subsets[key], transforms=transforms, suffixes=(t_transforms)
+    def __getitem__(self, idx):
+        if self.cache_samples:
+            return self._samples[idx]
+        else:
+            sample = self._samples[idx].copy()
+            sample["data"] = self.transforms(
+                self.raw_data_loader(sample["data"])
             )
-        if key == "validation":
-            # also use it for validation during training
-            retval["__valid__"] = retval[key]
-
-    if (
-        ("__train__" in retval)
-        and ("train" in retval)
-        and ("__valid__" not in retval)
-    ):
-        # if the dataset does not have a validation set, we use the unaugmented
-        # training set as validation set
-        retval["__valid__"] = retval["train"]
+            return sample
 
-    return retval
+    def __len__(self):
+        return len(self._samples)
 
 
 def get_samples_weights(dataset):
@@ -501,11 +439,11 @@ def get_samples_weights(dataset):
     if isinstance(dataset, torch.utils.data.ConcatDataset):
         for ds in dataset.datasets:
             # Weighting only for binary labels
-            if isinstance(ds._samples[0].label, int):
+            if isinstance(ds._samples[0]["label"], int):
                 # Groundtruth
                 targets = []
                 for s in ds._samples:
-                    targets.append(s.label)
+                    targets.append(s["label"])
                 targets = torch.tensor(targets)
 
                 # Count number of samples per class
@@ -531,11 +469,11 @@ def get_samples_weights(dataset):
 
     else:
         # Weighting only for binary labels
-        if isinstance(dataset._samples[0].label, int):
+        if isinstance(dataset._samples[0]["label"], int):
             # Groundtruth
             targets = []
             for s in dataset._samples:
-                targets.append(s.label)
+                targets.append(s["label"])
             targets = torch.tensor(targets)
 
             # Count number of samples per class
@@ -585,11 +523,11 @@ def get_positive_weights(dataset):
     if isinstance(dataset, torch.utils.data.ConcatDataset):
         for ds in dataset.datasets:
             for s in ds._samples:
-                targets.append(s.label)
+                targets.append(s["label"])
 
     else:
         for s in dataset._samples:
-            targets.append(s.label)
+            targets.append(s["label"])
 
     targets = torch.tensor(targets)
 
@@ -618,3 +556,75 @@ def get_positive_weights(dataset):
         )
 
     return positive_weights
+
+
+def reweight_BCEWithLogitsLoss(datamodule, model, criterion, criterion_valid):
+    from torch.nn import BCEWithLogitsLoss
+
+    datamodule.prepare_data()
+    datamodule.setup(stage="fit")
+
+    train_dataset = datamodule.train_dataset
+    validation_dataset = datamodule.validation_dataset
+
+    # Redefine a weighted criterion if possible
+    if isinstance(criterion, torch.nn.BCEWithLogitsLoss):
+        positive_weights = get_positive_weights(train_dataset)
+        model.hparams.criterion = BCEWithLogitsLoss(pos_weight=positive_weights)
+    else:
+        logger.warning("Weighted criterion not supported")
+
+    if validation_dataset is not None:
+        # Redefine a weighted valid criterion if possible
+        if (
+            isinstance(criterion_valid, torch.nn.BCEWithLogitsLoss)
+            or criterion_valid is None
+        ):
+            positive_weights = get_positive_weights(validation_dataset)
+            model.hparams.criterion_valid = BCEWithLogitsLoss(
+                pos_weight=positive_weights
+            )
+        else:
+            logger.warning("Weighted valid criterion not supported")
+
+
+def normalize_data(normalization, model, datamodule):
+    from torch.utils.data import DataLoader
+
+    datamodule.prepare_data()
+    datamodule.setup(stage="fit")
+
+    train_dataset = datamodule.train_dataset
+
+    # Create z-normalization model layer if needed
+    if normalization == "imagenet":
+        model.normalizer.set_mean_std(
+            [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
+        )
+        logger.info("Z-normalization with ImageNet mean and std")
+    elif normalization == "current":
+        # Compute mean/std of current train subset
+        temp_dl = DataLoader(
+            dataset=train_dataset, batch_size=len(train_dataset)
+        )
+
+        data = next(iter(temp_dl))
+        mean = data["data"].mean(dim=[0, 2, 3])  # samples are dicts: images under "data"
+        std = data["data"].std(dim=[0, 2, 3])
+
+        model.normalizer.set_mean_std(mean, std)
+
+        # Format mean and std for logging
+        mean = str(
+            [
+                round(x, 3)
+                for x in ((mean * 10**3).round() / (10**3)).tolist()
+            ]
+        )
+        std = str(
+            [
+                round(x, 3)
+                for x in ((std * 10**3).round() / (10**3)).tolist()
+            ]
+        )
+        logger.info(f"Z-normalization with mean {mean} and std {std}")
diff --git a/src/ptbench/data/loader.py b/src/ptbench/data/loader.py
index a11aefee77becbbbb07e15a25d5404594c16999a..d6e86ed06dd04abb230830ebc8b13e2ea9720cfa 100644
--- a/src/ptbench/data/loader.py
+++ b/src/ptbench/data/loader.py
@@ -5,13 +5,8 @@
 
 """Data loading code."""
 
-
-import functools
-
 import PIL.Image
 
-from .sample import DelayedSample, Sample
-
 
 def load_pil(path):
     """Loads a sample data.
@@ -68,78 +63,3 @@ def load_pil_rgb(path):
         A PIL image in RGB mode
     """
     return load_pil(path).convert("RGB")
-
-
-def make_cached(sample, loader, additional_transforms=[], key=None):
-    return Sample(
-        loader(sample, additional_transforms),
-        key=key or sample["data"],
-        label=sample["label"],
-    )
-
-
-def make_delayed(sample, loader, additional_transforms=[], key=None):
-    """Returns a delayed-loading Sample object.
-
-    Parameters
-    ----------
-
-    sample : dict
-        A dictionary that maps field names to sample data values (e.g. paths)
-
-    loader : object
-        A function that inputs ``sample`` dictionaries and returns the loaded
-        data.
-
-    key : str
-        A unique key identifier for this sample.  If not provided, assumes
-        ``sample`` is a dictionary with a ``data`` entry and uses its path as
-        key.
-
-
-    Returns
-    -------
-
-    sample : ptbench.data.sample.DelayedSample
-        In which ``key`` is as provided and ``data`` can be accessed to trigger
-        sample loading.
-    """
-    return DelayedSample(
-        functools.partial(loader, sample, additional_transforms),
-        key=key or sample["data"],
-        label=sample["label"],
-    )
-
-
-def make_delayed_bbox(sample, loader, key=None):
-    """Returns a delayed-loading Sample object.
-
-    Parameters
-    ----------
-
-    sample : dict
-        A dictionary that maps field names to sample data values (e.g. paths)
-
-    loader : object
-        A function that inputs ``sample`` dictionaries and returns the loaded
-        data.
-
-    key : str
-        A unique key identifier for this sample.  If not provided, assumes
-        ``sample`` is a dictionary with a ``data`` entry and uses its path as
-        key.
-
-
-    Returns
-    -------
-
-    sample : ptbench.data.sample.DelayedSample
-        In which ``key`` is as provided and ``data`` can be accessed to trigger
-        sample loading.
-    """
-    return DelayedSample(
-        functools.partial(loader, sample),
-        key=key or sample["data"],
-        label=sample["label"],
-        bboxes=sample["bboxes"],
-    )
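With the Sample wrappers gone, the loader module keeps only the plain PIL helpers. A minimal usage sketch (the file path is a placeholder):

```python
from ptbench.data.loader import load_pil_rgb

# Loads the image with PIL and converts it to RGB mode
image = load_pil_rgb("path/to/chest-xray.png")
print(image.mode, image.size)  # "RGB", (width, height)
```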
diff --git a/src/ptbench/data/sample.py b/src/ptbench/data/sample.py
deleted file mode 100644
index 8d5102af091977fb4d746ca2f83de7fdd351cd91..0000000000000000000000000000000000000000
--- a/src/ptbench/data/sample.py
+++ /dev/null
@@ -1,74 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-"""Base definition of sample."""
-
-
-def _copy_attributes(s, d):
-    """Copies attributes from a dictionary to self."""
-    s.__dict__.update(
-        {k: v for k, v in d.items() if k not in ("data", "load", "samples")}
-    )
-
-
-class DelayedSample:
-    """Representation of sample that can be loaded via a callable.
-
-    The optional ``**kwargs`` argument allows you to attach more attributes to
-    this sample instance.
-
-
-    Parameters
-    ----------
-
-        load : object
-            A python function that can be called parameterlessly, to load the
-            sample in question from whatever medium
-
-        parent : :py:class:`DelayedSample`, :py:class:`Sample`, None
-            If passed, consider this as a parent of this sample, to copy
-            information
-
-        kwargs : dict
-            Further attributes of this sample, to be stored and eventually
-            transmitted to transformed versions of the sample
-    """
-
-    def __init__(self, load, parent=None, **kwargs):
-        self.load = load
-        if parent is not None:
-            _copy_attributes(self, parent.__dict__)
-        _copy_attributes(self, kwargs)
-
-    @property
-    def data(self):
-        """Loads the data from the disk file."""
-        return self.load()
-
-
-class Sample:
-    """Representation of sample that is sufficient for the blocks in this
-    module.
-
-    Each sample must have the following attributes:
-
-        * attribute ``data``: Contains the data for this sample
-
-
-    Parameters
-    ----------
-
-        data : object
-            Object representing the data to initialize this sample with.
-
-        parent : object
-            A parent object from which to inherit all other attributes (except
-            ``data``)
-    """
-
-    def __init__(self, data, parent=None, **kwargs):
-        self.data = data
-        if parent is not None:
-            _copy_attributes(self, parent.__dict__)
-        _copy_attributes(self, kwargs)
diff --git a/src/ptbench/data/shenzhen/__init__.py b/src/ptbench/data/shenzhen/__init__.py
index d284b1b28905594e0c9c7fd20cb31f7c566468e3..54eb632684a83e16ac28481f8499137f594b662a 100644
--- a/src/ptbench/data/shenzhen/__init__.py
+++ b/src/ptbench/data/shenzhen/__init__.py
@@ -23,11 +23,9 @@ import importlib.resources
 import os
 
 from clapper.logging import setup
-from torchvision import transforms
 
 from ...utils.rc import load_rc
-from ..loader import load_pil_baw, make_cached, make_delayed
-from ..transforms import RemoveBlackBorders
+from ..loader import load_pil_baw
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
@@ -48,34 +46,8 @@ _protocols = [
 
 _datadir = load_rc().get("datadir.shenzhen", os.path.realpath(os.curdir))
 
-_resize_size = 512
-_cc_size = 512
-
-_data_transforms = [
-    RemoveBlackBorders(),
-    transforms.Resize(_resize_size),
-    transforms.CenterCrop(_cc_size),
-    transforms.ToTensor(),
-]
-
-
-def _raw_data_loader(sample, additional_transforms=[]):
-    raw_data = load_pil_baw(os.path.join(_datadir, sample["data"]))
-
-    base_transforms = transforms.Compose(
-        _data_transforms + additional_transforms
-    )
-    return dict(
-        data=base_transforms(raw_data),
-        label=sample["label"],
-    )
-
-
-def _cached_loader(context, sample, additional_transforms=[]):
-    return make_cached(sample, _raw_data_loader, additional_transforms)
 
+def _raw_data_loader(img_path):
+    raw_data = load_pil_baw(os.path.join(_datadir, img_path))
 
-def _delayed_loader(context, sample, additional_transforms=[]):
-    # "context" is ignored in this case - database is homogeneous
-    # we returned delayed samples to avoid loading all images at once
-    return make_delayed(sample, _raw_data_loader, additional_transforms)
+    return raw_data
diff --git a/src/ptbench/data/shenzhen/default.py b/src/ptbench/data/shenzhen/default.py
new file mode 100644
index 0000000000000000000000000000000000000000..f471d426e20e25c3517050b68f54650ce3074314
--- /dev/null
+++ b/src/ptbench/data/shenzhen/default.py
@@ -0,0 +1,103 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+"""Shenzhen dataset for TB detection (default protocol)
+
+* Split reference: first 64% of TB and healthy CXR for "train", 16% for
+  "validation", 20% for "test"
+* This configuration resolution: 512 x 512 (default)
+* See :py:mod:`ptbench.data.shenzhen` for dataset details
+"""
+
+from clapper.logging import setup
+from torchvision import transforms
+
+from ..base_datamodule import BaseDataModule
+from ..dataset import JSONProtocol, TBDataset
+from ..shenzhen import _protocols, _raw_data_loader
+from ..transforms import ElasticDeformation, RemoveBlackBorders
+
+logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+
+
+class DefaultModule(BaseDataModule):
+    def __init__(
+        self,
+        batch_size=1,
+        batch_chunk_count=1,
+        drop_incomplete_batch=False,
+        cache_samples=False,
+        parallel=-1,
+    ):
+        super().__init__(
+            batch_size=batch_size,
+            drop_incomplete_batch=drop_incomplete_batch,
+            batch_chunk_count=batch_chunk_count,
+            parallel=parallel,
+        )
+
+        self._cache_samples = cache_samples
+        self._has_setup_fit = False
+        self._has_setup_predict = False
+        self._protocol = "default"
+
+        self.raw_data_transforms = [
+            RemoveBlackBorders(),
+            transforms.Resize(512),
+            transforms.CenterCrop(512),
+            transforms.ToTensor(),
+        ]
+
+        self.model_transforms = []
+
+        self.augmentation_transforms = [ElasticDeformation(p=0.8)]
+
+    def setup(self, stage: str):
+        json_protocol = JSONProtocol(
+            protocols=_protocols,
+            fieldnames=("data", "label"),
+        )
+
+        if not self._has_setup_fit and stage == "fit":
+            self.train_dataset = TBDataset(
+                json_protocol,
+                self._protocol,
+                "train",
+                _raw_data_loader,
+                self._build_transforms(is_train=True),
+                cache_samples=self._cache_samples,
+            )
+            self.validation_dataset = TBDataset(
+                json_protocol,
+                self._protocol,
+                "validation",
+                _raw_data_loader,
+                self._build_transforms(is_train=False),
+                cache_samples=self._cache_samples,
+            )
+
+            self._has_setup_fit = True
+
+        if not self._has_setup_predict and stage == "predict":
+            self.train_dataset = TBDataset(
+                json_protocol,
+                self._protocol,
+                "train",
+                _raw_data_loader,
+                self._build_transforms(is_train=False),
+                cache_samples=self._cache_samples,
+            )
+            self.validation_dataset = TBDataset(
+                json_protocol,
+                self._protocol,
+                "validation",
+                _raw_data_loader,
+                self._build_transforms(is_train=False),
+                cache_samples=self._cache_samples,
+            )
+
+            self._has_setup_predict = True
+
+
+datamodule = DefaultModule
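A usage sketch for the module defined above, assuming the Shenzhen images are available under the configured ``datadir.shenzhen`` entry; the trainer wiring is indicative only:

```python
from ptbench.data.shenzhen.default import DefaultModule

datamodule = DefaultModule(batch_size=8, cache_samples=False, parallel=4)

# "fit" builds the "train" and "validation" TBDatasets; the training pipeline is
# RemoveBlackBorders -> Resize(512) -> CenterCrop(512) -> ToTensor -> ElasticDeformation
datamodule.setup(stage="fit")

sample = datamodule.train_dataset[0]
print(sample["name"], sample["label"])  # relative image path and its TB label

# from lightning import Trainer
# Trainer(max_epochs=1).fit(model, datamodule)  # "model": a PASA LightningModule
```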
diff --git a/src/ptbench/configs/datasets/shenzhen/rgb.py b/src/ptbench/data/shenzhen/rgb.py
similarity index 90%
rename from src/ptbench/configs/datasets/shenzhen/rgb.py
rename to src/ptbench/data/shenzhen/rgb.py
index 7cf77faa1406e6b34b0db06415002c2e23de49ab..2d93c07edb4c0824caa8149737ec42783966ad4c 100644
--- a/src/ptbench/configs/datasets/shenzhen/rgb.py
+++ b/src/ptbench/data/shenzhen/rgb.py
@@ -10,7 +10,6 @@
 """
 
 from clapper.logging import setup
-from torchvision import transforms
 
 from ....data import return_subsets
 from ....data.base_datamodule import BaseDataModule
@@ -28,6 +27,9 @@ class DefaultModule(BaseDataModule):
         drop_incomplete_batch=False,
         cache_samples=False,
         multiproc_kwargs=None,
+        data_transforms=[],
+        model_transforms=[],
+        train_transforms=[],
     ):
         super().__init__(
             train_batch_size=train_batch_size,
@@ -39,11 +41,15 @@ class DefaultModule(BaseDataModule):
         self.cache_samples = cache_samples
         self.has_setup_fit = False
 
-        self.post_transforms = [
+        self.data_transforms = data_transforms
+        self.model_transforms = model_transforms
+        self.train_transforms = train_transforms
+
+        """[
             transforms.ToPILImage(),
             transforms.Lambda(lambda x: x.convert("RGB")),
             transforms.ToTensor(),
-        ]
+        ]"""
 
     def setup(self, stage: str):
         if self.cache_samples:
diff --git a/src/ptbench/data/utils.py b/src/ptbench/data/utils.py
deleted file mode 100644
index bc5bdbd4b5fbda35361018e40479199435209fe0..0000000000000000000000000000000000000000
--- a/src/ptbench/data/utils.py
+++ /dev/null
@@ -1,136 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-
-"""Common utilities."""
-
-import numpy as np
-import PIL
-import torch
-import torch.utils.data
-
-from torchvision.transforms import Compose, ToTensor
-
-
-class SampleListDataset(torch.utils.data.Dataset):
-    """PyTorch dataset wrapper around Sample lists.
-
-    A transform object can be passed that will be applied to the image, ground
-    truth and mask (if present).
-
-    It supports indexing such that dataset[i] can be used to get ith sample.
-
-    Parameters
-    ----------
-
-    samples : list
-        A list of :py:class:`ptbench.data.sample.Sample` objects
-
-    transforms : :py:class:`list`, Optional
-        a list of transformations to be applied to **both** image and
-        ground-truth data.  Notice a last transform
-        (:py:class:`torchvision.transforms.transforms.ToTensor`) is always
-        applied - you do not need to add that.
-    """
-
-    def __init__(self, samples, transforms=[]):
-        self._samples = samples
-        self.transforms = transforms
-
-    @property
-    def transforms(self):
-        return self._transforms.transforms[:-1]
-
-    @transforms.setter
-    def transforms(self, data):
-        if any([isinstance(t, ToTensor) for t in data]):
-            self._transforms = Compose(data)
-        else:
-            self._transforms = Compose(data + [ToTensor()])
-
-    def copy(self, transforms=None):
-        """Returns a deep copy of itself, optionally resetting transforms.
-
-        Parameters
-        ----------
-
-        transforms : :py:class:`list`, Optional
-            An optional list of transforms to set in the copy.  If not
-            specified, use ``self.transforms``.
-        """
-        return SampleListDataset(self._samples, transforms or self.transforms)
-
-    def random_permute(self, feature):
-        """Randomly permute feature values from all samples.
-
-        Useful for permutation feature importance computation
-
-        Parameters
-        ----------
-
-        feature : int
-            The position of the feature
-        """
-        feature_values = np.zeros(len(self))
-
-        for k, s in enumerate(self._samples):
-            features = s.data["data"]
-            if isinstance(features, list):
-                feature_values[k] = features[feature]
-
-        np.random.shuffle(feature_values)
-
-        for k, s in enumerate(self._samples):
-            features = s.data["data"]
-            features[feature] = feature_values[k]
-
-    def __len__(self):
-        """
-
-        Returns
-        -------
-
-        size : int
-            size of the dataset
-
-        """
-        return len(self._samples)
-
-    def __getitem__(self, key):
-        """
-
-        Parameters
-        ----------
-
-        key : int, slice
-
-        Returns
-        -------
-
-        sample : list
-            The sample data: ``[key, image, label]``
-
-        """
-        if isinstance(key, slice):
-            return [self[k] for k in range(*key.indices(len(self)))]
-        else:  # we try it as an int
-            item = data = self._samples[key]
-            if not isinstance(data, dict):
-                key = item.key
-                data = item.data  # triggers data loading
-
-            retval = data["data"]
-
-            if self._transforms and isinstance(retval, PIL.Image.Image):
-                retval = self._transforms(retval)
-            elif isinstance(retval, list):
-                retval = torch.FloatTensor(retval)
-
-            if "label" in data:
-                if isinstance(data["label"], list):
-                    return [key, retval, torch.FloatTensor(data["label"])]
-                else:
-                    return [key, retval, data["label"]]
-
-            return [item.key, retval]
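For reference, the wrapper removed above exposed plain list-style indexing with transforms applied on access; a rough reconstruction of its usage from the deleted docstrings (kept as comments since the class no longer exists in the tree):

# dataset = SampleListDataset(samples, transforms=[])  # ToTensor appended last
# key, image, label = dataset[0]   # transforms applied when the item is read
# head = dataset[0:10]             # slicing returns a plain list of samples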
diff --git a/src/ptbench/engine/predictor.py b/src/ptbench/engine/predictor.py
index dc037af0679d3493a718a26dc393a61d00659a27..4535b368bca54daf97fa350b91516a126eff319a 100644
--- a/src/ptbench/engine/predictor.py
+++ b/src/ptbench/engine/predictor.py
@@ -13,7 +13,7 @@ from .callbacks import PredictionsWriter
 logger = logging.getLogger(__name__)
 
 
-def run(model, data_loader, name, accelerator, output_folder, grad_cams=False):
+def run(model, datamodule, name, accelerator, output_folder, grad_cams=False):
     """Runs inference on input data, outputs csv files with predictions.
 
     Parameters
@@ -73,6 +73,6 @@ def run(model, data_loader, name, accelerator, output_folder, grad_cams=False):
         ],
     )
 
-    all_predictions = trainer.predict(model, data_loader)
+    all_predictions = trainer.predict(model, datamodule)
 
     return all_predictions
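With this change the predictor hands a whole datamodule to Lightning rather than a single dataloader. A hedged sketch of a call site for the updated run(); the argument values are placeholders, not taken from this diff:

# Sketch only: driving the updated run() with a LightningDataModule.
from ptbench.engine.predictor import run


def predict(model, datamodule, output_folder):
    """Forwards a datamodule (not a DataLoader) to the prediction engine."""
    return run(
        model,
        datamodule,
        name="predictions",       # placeholder subset name
        accelerator="cpu",        # placeholder accelerator
        output_folder=output_folder,
    )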
diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py
index ed8c4b30a9412ebd2afb78c0b21f5a0ecb26fd91..d9e86b7055ed1ca3a06d807ab5d02b175b2f0cef 100644
--- a/src/ptbench/models/pasa.py
+++ b/src/ptbench/models/pasa.py
@@ -18,7 +18,6 @@ class PASA(pl.LightningModule):
 
     def __init__(
         self,
-        train_transforms,
         criterion,
         criterion_valid,
         optimizer,
@@ -26,14 +25,12 @@ class PASA(pl.LightningModule):
     ):
         super().__init__()
 
-        self.save_hyperparameters(ignore=["train_transforms"])
+        self.save_hyperparameters()
 
         self.name = "pasa"
 
         self.normalizer = TorchVisionNormalizer(nb_channels=1)
 
-        self.train_transforms = train_transforms
-
         # First convolution block
         self.fc1 = nn.Conv2d(1, 4, (3, 3), (2, 2), (1, 1))
         self.fc2 = nn.Conv2d(4, 16, (3, 3), (2, 2), (1, 1))
@@ -131,11 +128,8 @@ class PASA(pl.LightningModule):
         return x
 
     def training_step(self, batch, batch_idx):
-        images = batch[1]
-        labels = batch[2]
-
-        for img in images:
-            img = self.train_transforms(img)
+        images = batch["data"]
+        labels = batch["label"]
 
         # Increase label dimension if too low
         # Allows single and multiclass usage
@@ -152,8 +146,8 @@ class PASA(pl.LightningModule):
         return {"loss": training_loss}
 
     def validation_step(self, batch, batch_idx, dataloader_idx=0):
-        images = batch[1]
-        labels = batch[2]
+        images = batch["data"]
+        labels = batch["label"]
 
         # Increase label dimension if too low
         # Allows single and multiclass usage
@@ -175,8 +169,8 @@ class PASA(pl.LightningModule):
             return {f"extra_validation_loss_{dataloader_idx}": validation_loss}
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
-        names = batch[0]
-        images = batch[1]
+        names = batch["names"]
+        images = batch["data"]
 
         outputs = self(images)
         probabilities = torch.sigmoid(outputs)
@@ -186,7 +180,18 @@ class PASA(pl.LightningModule):
         if isinstance(outputs, list):
             outputs = outputs[-1]
 
-        return names[0], torch.flatten(probabilities), torch.flatten(batch[2])
+        return {
+            f"dataloader_{dataloader_idx}_predictions": (
+                names[0],
+                torch.flatten(probabilities),
+                torch.flatten(batch["label"]),
+            )
+        }
+
+    def on_predict_epoch_end(self):
+        # TODO: cache predictions during predict_step, reorder them by key,
+        # then clear the prediction cache.
+        raise NotImplementedError
 
     def configure_optimizers(self):
         # Dynamically instantiates the optimizer given the configs
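The model steps now index batches by key rather than by position. A minimal sketch of the batch layout they expect; shapes and the collate step are assumptions, only the "data"/"label"/"names" keys come from this diff:

import torch

# Assumed keyed batch layout after the switch from positional indexing.
batch = {
    "data": torch.rand(4, 1, 512, 512),   # images (N, C, H, W), shape illustrative
    "label": torch.tensor([0, 1, 1, 0]),  # one ground-truth label per sample
    "names": ["a", "b", "c", "d"],        # sample keys used by predict_step
}

images = batch["data"]
labels = batch["label"]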
diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py
index 3606c9f9d2c0179f171acc4b3dcba7987242047e..6b37e1a109623290e10d6c79009cc69bb403ab05 100644
--- a/src/ptbench/scripts/train.py
+++ b/src/ptbench/scripts/train.py
@@ -260,104 +260,34 @@ def train(
     procedure in case it stops abruptly.
     """
 
-    import multiprocessing
-    import sys
-
     import torch.cuda
     import torch.nn
 
-    from torch.nn import BCEWithLogitsLoss
-    from torch.utils.data import DataLoader
-
-    from ..configs.datasets import get_positive_weights
+    from ..data.dataset import normalize_data, reweight_BCEWithLogitsLoss
     from ..engine.trainer import run
 
     seed_everything(seed)
 
-    # PyTorch dataloader
-    multiproc_kwargs = dict()
-    if parallel < 0:
-        multiproc_kwargs["num_workers"] = 0
-    else:
-        multiproc_kwargs["num_workers"] = (
-            parallel or multiprocessing.cpu_count()
-        )
-
-    if multiproc_kwargs["num_workers"] > 0 and sys.platform == "darwin":
-        multiproc_kwargs[
-            "multiprocessing_context"
-        ] = multiprocessing.get_context("spawn")
+    checkpoint_file = get_checkpoint(output_folder, resume_from)
 
-    datamodule.train_batch_size = batch_size
-    datamodule.batch_chunk_count = batch_chunk_count
-    datamodule.multiproc_kwargs = multiproc_kwargs
+    datamodule = datamodule(
+        batch_size=batch_size,
+        batch_chunk_count=batch_chunk_count,
+        drop_incomplete_batch=drop_incomplete_batch,
+        cache_samples=cache_samples,
+        parallel=parallel,
+    )
 
-    # Manually calling these as we need to access some values to reweight the criterion
     datamodule.prepare_data()
     datamodule.setup(stage="fit")
 
-    train_dataset = datamodule.train_dataset
-    validation_dataset = datamodule.validation_dataset
-
-    # Redefine a weighted criterion if possible
-    if isinstance(criterion, torch.nn.BCEWithLogitsLoss):
-        positive_weights = get_positive_weights(train_dataset)
-        model.hparams.criterion = BCEWithLogitsLoss(pos_weight=positive_weights)
-    else:
-        logger.warning("Weighted criterion not supported")
-
-    if validation_dataset is not None:
-        # Redefine a weighted valid criterion if possible
-        if (
-            isinstance(criterion_valid, torch.nn.BCEWithLogitsLoss)
-            or criterion_valid is None
-        ):
-            positive_weights = get_positive_weights(validation_dataset)
-            model.hparams.criterion_valid = BCEWithLogitsLoss(
-                pos_weight=positive_weights
-            )
-        else:
-            logger.warning("Weighted valid criterion not supported")
-
-    # Create z-normalization model layer if needed
-    if normalization == "imagenet":
-        model.normalizer.set_mean_std(
-            [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
-        )
-        logger.info("Z-normalization with ImageNet mean and std")
-    elif normalization == "current":
-        # Compute mean/std of current train subset
-        temp_dl = DataLoader(
-            dataset=train_dataset, batch_size=len(train_dataset)
-        )
-
-        data = next(iter(temp_dl))
-        mean = data[1].mean(dim=[0, 2, 3])
-        std = data[1].std(dim=[0, 2, 3])
-
-        model.normalizer.set_mean_std(mean, std)
-
-        # Format mean and std for logging
-        mean = str(
-            [
-                round(x, 3)
-                for x in ((mean * 10**3).round() / (10**3)).tolist()
-            ]
-        )
-        std = str(
-            [
-                round(x, 3)
-                for x in ((std * 10**3).round() / (10**3)).tolist()
-            ]
-        )
-        logger.info(f"Z-normalization with mean {mean} and std {std}")
+    reweight_BCEWithLogitsLoss(datamodule, model, criterion, criterion_valid)
+    normalize_data(normalization, model, datamodule)
 
     arguments = {}
     arguments["max_epoch"] = epochs
     arguments["epoch"] = 0
 
-    checkpoint_file = get_checkpoint(output_folder, resume_from)
-
     # We only load the checkpoint to get some information about its state. The actual loading of the model is done in trainer.fit()
     if checkpoint_file is not None:
         checkpoint = torch.load(checkpoint_file)