diff --git a/pyproject.toml b/pyproject.toml
index b2a52e5b215be2743b5267a6ed51c194c7546146..f22d566a950755597aac41eb69394dec6795c3cc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -117,7 +117,7 @@ montgomery_rs_f7 = "ptbench.configs.datasets.montgomery_RS.fold_7"
 montgomery_rs_f8 = "ptbench.configs.datasets.montgomery_RS.fold_8"
 montgomery_rs_f9 = "ptbench.configs.datasets.montgomery_RS.fold_9"
 # shenzhen dataset (and cross-validation folds)
-shenzhen = "ptbench.configs.datasets.shenzhen.default"
+shenzhen = "ptbench.data.shenzhen.default"
 shenzhen_rgb = "ptbench.configs.datasets.shenzhen.rgb"
 shenzhen_f0 = "ptbench.configs.datasets.shenzhen.fold_0"
 shenzhen_f1 = "ptbench.configs.datasets.shenzhen.fold_1"
diff --git a/src/ptbench/configs/datasets/shenzhen/default.py b/src/ptbench/configs/datasets/shenzhen/default.py
deleted file mode 100644
index 635c9ed9b7ed23a255ba9f5e5f79828de17b7f17..0000000000000000000000000000000000000000
--- a/src/ptbench/configs/datasets/shenzhen/default.py
+++ /dev/null
@@ -1,15 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-"""Shenzhen dataset for TB detection (default protocol)
-
-* Split reference: first 64% of TB and healthy CXR for "train" 16% for
-* "validation", 20% for "test"
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.shenzhen` for dataset details
-"""
-
-from . import _maker
-
-dataset = _maker("default")
diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/base_datamodule.py
similarity index 63%
rename from src/ptbench/data/datamodule.py
rename to src/ptbench/data/base_datamodule.py
index d3c03d76578b624f4d38a06e255501f91e2fab49..fb0970f0b2c1ebb1cfcae291abd626428f556bf6 100644
--- a/src/ptbench/data/datamodule.py
+++ b/src/ptbench/data/base_datamodule.py
@@ -13,10 +13,9 @@ from ptbench.configs.datasets import get_samples_weights
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
 
-class DataModule(pl.LightningDataModule):
+class BaseDataModule(pl.LightningDataModule):
     def __init__(
         self,
-        dataset,
         train_batch_size=1,
         predict_batch_size=1,
         drop_incomplete_batch=False,
@@ -24,8 +23,6 @@ class DataModule(pl.LightningDataModule):
     ):
         super().__init__()
 
-        self.dataset = dataset
-
         self.train_batch_size = train_batch_size
         self.predict_batch_size = predict_batch_size
 
@@ -37,37 +34,9 @@ class DataModule(pl.LightningDataModule):
         self.multiproc_kwargs = multiproc_kwargs
 
     def setup(self, stage: str):
-        if stage == "fit":
-            if "__train__" in self.dataset:
-                logger.info("Found (dedicated) '__train__' set for training")
-                self.train_dataset = self.dataset["__train__"]
-            else:
-                self.train_dataset = self.dataset["train"]
-
-            if "__valid__" in self.dataset:
-                logger.info("Found (dedicated) '__valid__' set for validation")
-                self.validation_dataset = self.dataset["__valid__"]
-
-            if "__extra_valid__" in self.dataset:
-                if not isinstance(self.dataset["__extra_valid__"], list):
-                    raise RuntimeError(
-                        f"If present, dataset['__extra_valid__'] must be a list, "
-                        f"but you passed a {type(self.dataset['__extra_valid__'])}, "
-                        f"which is invalid."
-                    )
-                logger.info(
-                    f"Found {len(self.dataset['__extra_valid__'])} extra validation "
-                    f"set(s) to be tracked during training"
-                )
-                logger.info(
-                    "Extra validation sets are NOT used for model checkpointing!"
-                )
-                self.extra_validation_datasets = self.dataset["__extra_valid__"]
-            else:
-                self.extra_validation_datasets = None
-
-        if stage == "predict":
-            self.predict_dataset = self.dataset
+        # To be implemented by subclasses (see, e.g.,
+        # :py:class:`ptbench.data.shenzhen.default.DefaultModule`).  The
+        # implementation must define ``self.train_dataset``,
+        # ``self.validation_dataset``, ``self.extra_validation_datasets``
+        # and ``self.predict_dataset``.
+        raise NotImplementedError
 
     def train_dataloader(self):
         train_samples_weights = get_samples_weights(self.train_dataset)
diff --git a/src/ptbench/data/dataset.py b/src/ptbench/data/dataset.py
index 995f227c2e44fea1bc77898d5e009a505108932e..1c425562633f346c7eac2a3af9f667520e785520 100644
--- a/src/ptbench/data/dataset.py
+++ b/src/ptbench/data/dataset.py
@@ -7,6 +7,14 @@ import json
 import logging
 import os
 import pathlib
+import random
+
+import torch
+
+from torchvision.transforms import RandomRotation
+
+RANDOM_ROTATION = [RandomRotation(15)]
+"""Shared data augmentation based on random rotation only."""
 
 logger = logging.getLogger(__name__)
 
@@ -313,3 +321,295 @@ class CSVDataset:
             )
             for n, k in enumerate(samples)
         ]
+
+
+def make_subset(samples, transforms=[], prefixes=[], suffixes=[]):
+    """Creates a new data set, applying transforms.
+
+    .. note::
+
+       This is a convenience function for our own dataset definitions inside
+       this module, guaranteeing homogeneity between dataset definitions
+       provided in this package.  It assumes certain strategies for data
+       augmentation that may not be translatable to other applications.
+
+
+    Parameters
+    ----------
+
+    samples : list
+        List of delayed samples
+
+    transforms : list
+        A list of transforms that need to be applied to all samples in the set
+
+    prefixes : list
+        A list of data augmentation operations that need to be applied
+        **before** the transforms above
+
+    suffixes : list
+        A list of data augmentation operations that need to be applied
+        **after** the transforms above
+
+
+    Returns
+    -------
+
+    subset : :py:class:`ptbench.data.utils.SampleListDataset`
+        A pre-formatted dataset that can be fed to one of our engines
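+
+    Example
+    -------
+
+    A minimal sketch (``samples`` is a hypothetical list of delayed samples,
+    e.g. one of the lists returned by
+    :py:meth:`ptbench.data.dataset.JSONDataset.subsets`):
+
+    .. code-block:: python
+
+       from torchvision import transforms
+
+       subset = make_subset(
+           samples,
+           transforms=[transforms.Resize(512)],
+           suffixes=RANDOM_ROTATION,  # data augmentation, applied last
+       )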
+    """
+    from .utils import SampleListDataset as wrapper
+
+    return wrapper(samples, prefixes + transforms + suffixes)
+
+
+def make_dataset(
+    subsets_groups, transforms=[], t_transforms=[], post_transforms=[]
+):
+    """Creates a new configuration dataset from a list of dictionaries and
+    transforms.
+
+    This function takes as input a list of dictionaries, such as those
+    returned by :py:meth:`ptbench.data.dataset.JSONDataset.subsets`, each
+    mapping protocol names (such as ``train``, ``dev`` and ``test``) to
+    :py:class:`ptbench.data.sample.DelayedSample` lists, together with a set
+    of transforms.  It returns a dictionary mapping the same names to
+    :py:class:`ptbench.data.utils.SampleListDataset` instances, applying our
+    standard data augmentation whenever a ``train`` set exists.
+
+    For example, if the aggregated subsets comprise two sets named ``train``
+    and ``test``, this function will yield a dictionary with the following
+    entries:
+
+    * ``__train__``: Wraps the ``train`` subset, includes data augmentation
+      (note: datasets with names starting with ``_`` (underscore) are excluded
+      from prediction and evaluation by default, as they contain data
+      augmentation transformations.)
+    * ``train``: Wraps the ``train`` subset, **without** data augmentation
+    * ``test``: Wraps the ``test`` subset, **without** data augmentation
+
+    .. note::
+
+       This is a convenience function for our own dataset definitions inside
+       this module, guaranteeing homogeneity between dataset definitions
+       provided in this package.  It assumes certain strategies for data
+       augmentation that may not be translatable to other applications.
+
+
+    Parameters
+    ----------
+
+    subsets_groups : list
+        A list of dictionaries containing the delayed sample lists
+        for a number of named lists. The subsets will be aggregated in one
+        final subset. If one of the keys is ``train``, our standard dataset
+        augmentation transforms are appended to the definition of that subset.
+        All other subsets remain un-augmented.
+
+    transforms : list
+        A list of transforms that need to be applied to all samples in the set
+
+    t_transforms : list
+        A list of transforms that need to be applied to the train samples
+
+    post_transforms : list
+        A list of transforms that need to be applied to all samples in the set
+        after all the other transforms
+
+
+    Returns
+    -------
+
+    dataset : dict
+        A pre-formatted dataset that can be fed to one of our engines. It maps
+        string names to :py:class:`ptbench.data.utils.SampleListDataset`'s.
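+
+    Example
+    -------
+
+    A minimal sketch, assuming ``json_dataset`` is a
+    :py:class:`ptbench.data.dataset.JSONDataset` exposing a ``default``
+    protocol:
+
+    .. code-block:: python
+
+       from torchvision import transforms
+
+       dataset = make_dataset(
+           [json_dataset.subsets("default")],
+           transforms=[transforms.Resize(512)],  # applied to every subset
+           t_transforms=RANDOM_ROTATION,  # applied to "__train__" only
+       )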
+    """
+
+    retval = {}
+
+    if len(subsets_groups) == 1:
+        subsets = subsets_groups[0]
+    else:
+        # If multiple subsets groups: aggregation
+        aggregated_subsets = {}
+        for subsets in subsets_groups:
+            for key in subsets.keys():
+                if key in aggregated_subsets:
+                    aggregated_subsets[key] += subsets[key]
+                    # Shuffle if data comes from multiple datasets
+                    random.shuffle(aggregated_subsets[key])
+                else:
+                    aggregated_subsets[key] = subsets[key]
+        subsets = aggregated_subsets
+
+    # Append post_transforms to the train-only transforms; build a new list
+    # so that the (mutable) default argument is never modified in place
+    t_transforms = t_transforms + post_transforms
+
+    for key in subsets.keys():
+        retval[key] = make_subset(
+            subsets[key], transforms=transforms, suffixes=post_transforms
+        )
+        if key == "train":
+            retval["__train__"] = make_subset(
+                subsets[key], transforms=transforms, suffixes=t_transforms
+            )
+        if key == "validation":
+            # also use it for validation during training
+            retval["__valid__"] = retval[key]
+
+    if (
+        ("__train__" in retval)
+        and ("train" in retval)
+        and ("__valid__" not in retval)
+    ):
+        # if the dataset does not have a validation set, we use the unaugmented
+        # training set as validation set
+        retval["__valid__"] = retval["train"]
+
+    return retval
+
+
+def get_samples_weights(dataset):
+    """Compute the weights of all the samples of the dataset to balance it
+    using the sampler of the dataloader.
+
+    This function takes as input a :py:class:`torch.utils.data.dataset.Dataset`
+    and computes the weights to balance each class in the dataset and the
+    datasets themselves if we have a ConcatDataset.
+
+
+    Parameters
+    ----------
+
+    dataset : torch.utils.data.dataset.Dataset
+        An instance of :py:class:`torch.utils.data.Dataset`;
+        :py:class:`torch.utils.data.ConcatDataset` is also supported
+
+
+    Returns
+    -------
+
+    samples_weights : :py:class:`torch.Tensor`
+        the weights for all the samples in the dataset given as input
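+
+    Example
+    -------
+
+    A sketch of the intended use, feeding the weights of a hypothetical
+    ``train_dataset`` to a :py:class:`torch.utils.data.WeightedRandomSampler`:
+
+    .. code-block:: python
+
+       weights = get_samples_weights(train_dataset)
+       sampler = torch.utils.data.WeightedRandomSampler(
+           weights, len(weights), replacement=True
+       )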
+    """
+    samples_weights = []
+
+    if isinstance(dataset, torch.utils.data.ConcatDataset):
+        for ds in dataset.datasets:
+            # Weighting only for binary labels
+            if isinstance(ds._samples[0].label, int):
+                # Groundtruth
+                targets = []
+                for s in ds._samples:
+                    targets.append(s.label)
+                targets = torch.tensor(targets)
+
+                # Count number of samples per class
+                class_sample_count = torch.tensor(
+                    [
+                        (targets == t).sum()
+                        for t in torch.unique(targets, sorted=True)
+                    ]
+                )
+
+                weight = 1.0 / class_sample_count.float()
+
+                samples_weights.append(
+                    torch.tensor([weight[t] for t in targets])
+                )
+
+            else:
+                # We only weight to sample equally from each dataset
+                samples_weights.append(torch.full((len(ds),), 1.0 / len(ds)))
+
+        # Concatenate sample weights from all the datasets
+        samples_weights = torch.cat(samples_weights)
+
+    else:
+        # Weighting only for binary labels
+        if isinstance(dataset._samples[0].label, int):
+            # Groundtruth
+            targets = []
+            for s in dataset._samples:
+                targets.append(s.label)
+            targets = torch.tensor(targets)
+
+            # Count number of samples per class
+            class_sample_count = torch.tensor(
+                [
+                    (targets == t).sum()
+                    for t in torch.unique(targets, sorted=True)
+                ]
+            )
+
+            weight = 1.0 / class_sample_count.float()
+
+            samples_weights = torch.tensor([weight[t] for t in targets])
+
+        else:
+            # Equal weights for non-binary labels
+            samples_weights = torch.ones(len(dataset._samples))
+
+    return samples_weights
+
+
+def get_positive_weights(dataset):
+    """Compute the positive weights of each class of the dataset to balance the
+    BCEWithLogitsLoss criterion.
+
+    This function takes as input a :py:class:`torch.utils.data.dataset.Dataset`
+    and computes the positive weights of each class to use them to have
+    a balanced loss.
+
+
+    Parameters
+    ----------
+
+    dataset : torch.utils.data.dataset.Dataset
+        An instance of :py:class:`torch.utils.data.Dataset`;
+        :py:class:`torch.utils.data.ConcatDataset` is also supported
+
+
+    Returns
+    -------
+
+    positive_weights : :py:class:`torch.Tensor`
+        the positive weight of each class in the dataset given as input
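+
+    Example
+    -------
+
+    A sketch of balancing :py:class:`torch.nn.BCEWithLogitsLoss` for a
+    hypothetical ``train_dataset``:
+
+    .. code-block:: python
+
+       criterion = torch.nn.BCEWithLogitsLoss(
+           pos_weight=get_positive_weights(train_dataset)
+       )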
+    """
+    targets = []
+
+    if isinstance(dataset, torch.utils.data.ConcatDataset):
+        for ds in dataset.datasets:
+            for s in ds._samples:
+                targets.append(s.label)
+
+    else:
+        for s in dataset._samples:
+            targets.append(s.label)
+
+    targets = torch.tensor(targets)
+
+    # Binary labels
+    if targets.ndim == 1:
+        class_sample_count = [
+            float((targets == t).sum().item())
+            for t in torch.unique(targets, sorted=True)
+        ]
+
+        # Divide negatives by positives
+        positive_weights = torch.tensor(
+            [class_sample_count[0] / class_sample_count[1]]
+        ).reshape(-1)
+
+    # Multiclass labels
+    else:
+        class_sample_count = torch.sum(targets, dim=0)
+        negative_class_sample_count = (
+            torch.full((targets.size()[1],), float(targets.size()[0]))
+            - class_sample_count
+        )
+
+        positive_weights = negative_class_sample_count / (
+            class_sample_count + negative_class_sample_count
+        )
+
+    return positive_weights
diff --git a/src/ptbench/data/shenzhen/__init__.py b/src/ptbench/data/shenzhen/__init__.py
index c1488160f553aed8e3e0fab4c882d46d8cffc799..60a004a3afc76239ffae144556feaaa888850ab7 100644
--- a/src/ptbench/data/shenzhen/__init__.py
+++ b/src/ptbench/data/shenzhen/__init__.py
@@ -22,11 +22,19 @@ the daily routine using Philips DR Digital Diagnose systems.
 
 import importlib.resources
 import os
+import random
+
+import torch
+
+from clapper.logging import setup
 
 from ...utils.rc import load_rc
 from ..dataset import JSONDataset
 from ..loader import load_pil_baw, make_delayed
 
+logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+
+
 _protocols = [
     importlib.resources.files(__name__).joinpath("default.json.bz2"),
     importlib.resources.files(__name__).joinpath("fold_0.json.bz2"),
@@ -57,9 +65,367 @@ def _loader(context, sample):
     return make_delayed(sample, _raw_data_loader)
 
 
-dataset = JSONDataset(
-    protocols=_protocols,
-    fieldnames=("data", "label"),
-    loader=_loader,
+json_dataset = JSONDataset(
+    protocols=_protocols, fieldnames=("data", "label"), loader=_loader
 )
 """Shenzhen dataset object."""
+
+
+def _maker(protocol, resize_size=512, cc_size=512, RGB=False):
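+    """Builds the Shenzhen dataset for the given ``protocol``.
+
+    Images have their black borders removed, are resized to ``resize_size``
+    and center-cropped to ``cc_size``; when ``RGB`` is set, they are
+    converted to 3-channel tensors.  Training samples are additionally
+    augmented with an elastic deformation.
+
+    A minimal usage sketch (the return value is the dictionary produced by
+    :py:func:`make_dataset`):
+
+    .. code-block:: python
+
+       dataset = _maker("default")
+       train = dataset["__train__"]  # augmented training set
+    """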
+    from torchvision import transforms
+
+    from ..transforms import ElasticDeformation, RemoveBlackBorders
+
+    post_transforms = []
+    if RGB:
+        post_transforms = [
+            transforms.Lambda(lambda x: x.convert("RGB")),
+            transforms.ToTensor(),
+        ]
+
+    return make_dataset(
+        [json_dataset.subsets(protocol)],
+        [
+            RemoveBlackBorders(),
+            transforms.Resize(resize_size),
+            transforms.CenterCrop(cc_size),
+        ],
+        [ElasticDeformation(p=0.8)],
+        post_transforms,
+    )
+
+
+def make_subset(samples, transforms=[], prefixes=[], suffixes=[]):
+    """Creates a new data set, applying transforms.
+
+    .. note::
+
+       This is a convenience function for our own dataset definitions inside
+       this module, guaranteeing homogeneity between dataset definitions
+       provided in this package.  It assumes certain strategies for data
+       augmentation that may not be translatable to other applications.
+
+
+    Parameters
+    ----------
+
+    samples : list
+        List of delayed samples
+
+    transforms : list
+        A list of transforms that need to be applied to all samples in the set
+
+    prefixes : list
+        A list of data augmentation operations that need to be applied
+        **before** the transforms above
+
+    suffixes : list
+        A list of data augmentation operations that need to be applied
+        **after** the transforms above
+
+
+    Returns
+    -------
+
+    subset : :py:class:`ptbench.data.utils.SampleListDataset`
+        A pre-formatted dataset that can be fed to one of our engines
+    """
+    from ...data.utils import SampleListDataset as wrapper
+
+    return wrapper(samples, prefixes + transforms + suffixes)
+
+
+def make_dataset(
+    subsets_groups, transforms=[], t_transforms=[], post_transforms=[]
+):
+    """Creates a new configuration dataset from a list of dictionaries and
+    transforms.
+
+    This function takes as input a list of dictionaries, such as those
+    returned by :py:meth:`ptbench.data.dataset.JSONDataset.subsets`, each
+    mapping protocol names (such as ``train``, ``dev`` and ``test``) to
+    :py:class:`ptbench.data.sample.DelayedSample` lists, together with a set
+    of transforms.  It returns a dictionary mapping the same names to
+    :py:class:`ptbench.data.utils.SampleListDataset` instances, applying our
+    standard data augmentation whenever a ``train`` set exists.
+
+    For example, if the aggregated subsets comprise two sets named ``train``
+    and ``test``, this function will yield a dictionary with the following
+    entries:
+
+    * ``__train__``: Wraps the ``train`` subset, includes data augmentation
+      (note: datasets with names starting with ``_`` (underscore) are excluded
+      from prediction and evaluation by default, as they contain data
+      augmentation transformations.)
+    * ``train``: Wraps the ``train`` subset, **without** data augmentation
+    * ``test``: Wraps the ``test`` subset, **without** data augmentation
+
+    .. note::
+
+       This is a convenience function for our own dataset definitions inside
+       this module, guaranteeing homogeneity between dataset definitions
+       provided in this package.  It assumes certain strategies for data
+       augmentation that may not be translatable to other applications.
+
+
+    Parameters
+    ----------
+
+    subsets_groups : list
+        A list of dictionaries containing the delayed sample lists
+        for a number of named lists. The subsets will be aggregated in one
+        final subset. If one of the keys is ``train``, our standard dataset
+        augmentation transforms are appended to the definition of that subset.
+        All other subsets remain un-augmented.
+
+    transforms : list
+        A list of transforms that needs to be applied to all samples in the set
+
+    t_transforms : list
+        A list of transforms that needs to be applied to the train samples
+
+    post_transforms : list
+        A list of transforms that needs to be applied to all samples in the set
+        after all the other transforms
+
+
+    Returns
+    -------
+
+    dataset : dict
+        A pre-formatted dataset that can be fed to one of our engines. It maps
+        string names to :py:class:`ptbench.data.utils.SampleListDataset`'s.
+    """
+
+    retval = {}
+
+    if len(subsets_groups) == 1:
+        subsets = subsets_groups[0]
+    else:
+        # If multiple subsets groups: aggregation
+        aggregated_subsets = {}
+        for subsets in subsets_groups:
+            for key in subsets.keys():
+                if key in aggregated_subsets:
+                    aggregated_subsets[key] += subsets[key]
+                    # Shuffle if data comes from multiple datasets
+                    random.shuffle(aggregated_subsets[key])
+                else:
+                    aggregated_subsets[key] = subsets[key]
+        subsets = aggregated_subsets
+
+    # Append post_transforms to the train-only transforms; build a new list
+    # so that the (mutable) default argument is never modified in place
+    t_transforms = t_transforms + post_transforms
+
+    for key in subsets.keys():
+        retval[key] = make_subset(
+            subsets[key], transforms=transforms, suffixes=post_transforms
+        )
+        if key == "train":
+            retval["__train__"] = make_subset(
+                subsets[key], transforms=transforms, suffixes=t_transforms
+            )
+        if key == "validation":
+            # also use it for validation during training
+            retval["__valid__"] = retval[key]
+
+    if (
+        ("__train__" in retval)
+        and ("train" in retval)
+        and ("__valid__" not in retval)
+    ):
+        # if the dataset does not have a validation set, we use the unaugmented
+        # training set as validation set
+        retval["__valid__"] = retval["train"]
+
+    return retval
+
+
+def get_samples_weights(dataset):
+    """Compute the weights of all the samples of the dataset to balance it
+    using the sampler of the dataloader.
+
+    This function takes as input a :py:class:`torch.utils.data.dataset.Dataset`
+    and computes the weights to balance each class in the dataset and the
+    datasets themselves if we have a ConcatDataset.
+
+
+    Parameters
+    ----------
+
+    dataset : torch.utils.data.dataset.Dataset
+        An instance of :py:class:`torch.utils.data.Dataset`;
+        :py:class:`torch.utils.data.ConcatDataset` is also supported
+
+
+    Returns
+    -------
+
+    samples_weights : :py:class:`torch.Tensor`
+        the weights for all the samples in the dataset given as input
+    """
+    samples_weights = []
+
+    if isinstance(dataset, torch.utils.data.ConcatDataset):
+        for ds in dataset.datasets:
+            # Weighting only for binary labels
+            if isinstance(ds._samples[0].label, int):
+                # Groundtruth
+                targets = []
+                for s in ds._samples:
+                    targets.append(s.label)
+                targets = torch.tensor(targets)
+
+                # Count number of samples per class
+                class_sample_count = torch.tensor(
+                    [
+                        (targets == t).sum()
+                        for t in torch.unique(targets, sorted=True)
+                    ]
+                )
+
+                weight = 1.0 / class_sample_count.float()
+
+                samples_weights.append(
+                    torch.tensor([weight[t] for t in targets])
+                )
+
+            else:
+                # We only weight to sample equally from each dataset
+                samples_weights.append(torch.full((len(ds),), 1.0 / len(ds)))
+
+        # Concatenate sample weights from all the datasets
+        samples_weights = torch.cat(samples_weights)
+
+    else:
+        # Weighting only for binary labels
+        if isinstance(dataset._samples[0].label, int):
+            # Groundtruth
+            targets = []
+            for s in dataset._samples:
+                targets.append(s.label)
+            targets = torch.tensor(targets)
+
+            # Count number of samples per class
+            class_sample_count = torch.tensor(
+                [
+                    (targets == t).sum()
+                    for t in torch.unique(targets, sorted=True)
+                ]
+            )
+
+            weight = 1.0 / class_sample_count.float()
+
+            samples_weights = torch.tensor([weight[t] for t in targets])
+
+        else:
+            # Equal weights for non-binary labels
+            samples_weights = torch.ones(len(dataset._samples))
+
+    return samples_weights
+
+
+def get_positive_weights(dataset):
+    """Compute the positive weights of each class of the dataset to balance the
+    BCEWithLogitsLoss criterion.
+
+    This function takes as input a :py:class:`torch.utils.data.dataset.Dataset`
+    and computes the positive weights of each class to use them to have
+    a balanced loss.
+
+
+    Parameters
+    ----------
+
+    dataset : torch.utils.data.dataset.Dataset
+        An instance of :py:class:`torch.utils.data.Dataset`;
+        :py:class:`torch.utils.data.ConcatDataset` is also supported
+
+
+    Returns
+    -------
+
+    positive_weights : :py:class:`torch.Tensor`
+        the positive weight of each class in the dataset given as input
+    """
+    targets = []
+
+    if isinstance(dataset, torch.utils.data.ConcatDataset):
+        for ds in dataset.datasets:
+            for s in ds._samples:
+                targets.append(s.label)
+
+    else:
+        for s in dataset._samples:
+            targets.append(s.label)
+
+    targets = torch.tensor(targets)
+
+    # Binary labels
+    if targets.ndim == 1:
+        class_sample_count = [
+            float((targets == t).sum().item())
+            for t in torch.unique(targets, sorted=True)
+        ]
+
+        # Divide negatives by positives
+        positive_weights = torch.tensor(
+            [class_sample_count[0] / class_sample_count[1]]
+        ).reshape(-1)
+
+    # Multiclass labels
+    else:
+        class_sample_count = torch.sum(targets, dim=0)
+        negative_class_sample_count = (
+            torch.full((targets.size()[1],), float(targets.size()[0]))
+            - class_sample_count
+        )
+
+        positive_weights = negative_class_sample_count / (
+            class_sample_count + negative_class_sample_count
+        )
+
+    return positive_weights
+
+
+def return_subsets(dataset):
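+    """Splits a dataset dictionary, as returned by :py:func:`make_dataset`,
+    into the four attributes a
+    :py:class:`ptbench.data.base_datamodule.BaseDataModule` subclass must
+    define: the training set, the validation set, the (optional) extra
+    validation sets and the prediction set.
+
+    A minimal usage sketch:
+
+    .. code-block:: python
+
+       train, valid, extra_valid, predict = return_subsets(_maker("default"))
+    """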
+    train_dataset = None
+    validation_dataset = None
+    extra_validation_datasets = None
+    predict_dataset = None
+
+    if "__train__" in dataset:
+        logger.info("Found (dedicated) '__train__' set for training")
+        train_dataset = dataset["__train__"]
+    else:
+        train_dataset = dataset["train"]
+
+    if "__valid__" in dataset:
+        logger.info("Found (dedicated) '__valid__' set for validation")
+        validation_dataset = dataset["__valid__"]
+
+    if "__extra_valid__" in dataset:
+        if not isinstance(dataset["__extra_valid__"], list):
+            raise RuntimeError(
+                f"If present, dataset['__extra_valid__'] must be a list, "
+                f"but you passed a {type(dataset['__extra_valid__'])}, "
+                f"which is invalid."
+            )
+        logger.info(
+            f"Found {len(dataset['__extra_valid__'])} extra validation "
+            f"set(s) to be tracked during training"
+        )
+        logger.info(
+            "Extra validation sets are NOT used for model checkpointing!"
+        )
+        extra_validation_datasets = dataset["__extra_valid__"]
+    else:
+        extra_validation_datasets = None
+
+    predict_dataset = dataset
+
+    return (
+        train_dataset,
+        validation_dataset,
+        extra_validation_datasets,
+        predict_dataset,
+    )
diff --git a/src/ptbench/data/shenzhen/default.py b/src/ptbench/data/shenzhen/default.py
new file mode 100644
index 0000000000000000000000000000000000000000..8347cd56350e6de2901ca3793c5db1c00b9af55a
--- /dev/null
+++ b/src/ptbench/data/shenzhen/default.py
@@ -0,0 +1,38 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+from clapper.logging import setup
+
+from ..base_datamodule import BaseDataModule
+from . import _maker, return_subsets
+
+logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+
+
+class DefaultModule(BaseDataModule):
+    def __init__(
+        self,
+        train_batch_size=1,
+        predict_batch_size=1,
+        drop_incomplete_batch=False,
+        multiproc_kwargs=None,
+    ):
+        super().__init__(
+            train_batch_size=train_batch_size,
+            predict_batch_size=predict_batch_size,
+            drop_incomplete_batch=drop_incomplete_batch,
+            multiproc_kwargs=multiproc_kwargs,
+        )
+
+    def setup(self, stage: str):
+        self.dataset = _maker("default")
+        (
+            self.train_dataset,
+            self.validation_dataset,
+            self.extra_validation_datasets,
+            self.predict_dataset,
+        ) = return_subsets(self.dataset)
+
+
+datamodule = DefaultModule
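+
+
+# A minimal stand-alone usage sketch (the ``ptbench train`` command performs
+# the equivalent steps through its ``--datamodule`` option):
+#
+#    dm = DefaultModule(train_batch_size=8)
+#    dm.setup(stage="fit")
+#    train_loader = dm.train_dataloader()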
diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py
index fe5a2b851806b108ad63e48edd9f4f0809536ebc..083d85d9566ca07c5dce3bf22a1d5265b5971dbf 100644
--- a/src/ptbench/scripts/train.py
+++ b/src/ptbench/scripts/train.py
@@ -8,7 +8,6 @@ from clapper.click import ConfigCommand, ResourceOption, verbosity_option
 from clapper.logging import setup
 from lightning.pytorch import seed_everything
 
-from ..data.datamodule import DataModule
 from ..utils.checkpointer import get_checkpoint
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
@@ -45,7 +44,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     cls=ResourceOption,
 )
 @click.option(
-    "--dataset",
+    "--datamodule",
     "-d",
     help="A dictionary mapping string keys to "
     "torch.utils.data.dataset.Dataset instances implementing datasets "
@@ -233,7 +232,7 @@ def train(
     drop_incomplete_batch,
     criterion,
     criterion_valid,
-    dataset,
+    datamodule,
     checkpoint_period,
     accelerator,
     seed,
@@ -290,9 +289,9 @@ def train(
     else:
         batch_chunk_size = batch_size // batch_chunk_count
 
-    datamodule = DataModule(
-        dataset,
+    datamodule = datamodule(
         train_batch_size=batch_chunk_size,
+        drop_incomplete_batch=drop_incomplete_batch,
         multiproc_kwargs=multiproc_kwargs,
     )
     # Manually calling these as we need to access some values to reweight the criterion