diff --git a/src/ptbench/configs/datasets/shenzhen/default.py b/src/ptbench/configs/datasets/shenzhen/default.py
index 635c9ed9b7ed23a255ba9f5e5f79828de17b7f17..b2d3ab1d48a9bd0d3a66e4a66c50fcc2300be2c9 100644
--- a/src/ptbench/configs/datasets/shenzhen/default.py
+++ b/src/ptbench/configs/datasets/shenzhen/default.py
@@ -10,6 +10,8 @@
 * See :py:mod:`ptbench.data.shenzhen` for dataset details
 """
 
+from ....data.shenzhen.datamodule import ShenzhenDataModule
 from . import _maker
 
 dataset = _maker("default")
+datamodule = ShenzhenDataModule
diff --git a/src/ptbench/data/shenzhen/datamodule.py b/src/ptbench/data/shenzhen/datamodule.py
new file mode 100644
index 0000000000000000000000000000000000000000..60de2efb1fa1d79be5a01054578c3718904f1d2b
--- /dev/null
+++ b/src/ptbench/data/shenzhen/datamodule.py
@@ -0,0 +1,140 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import lightning as pl
+import torch
+
+from clapper.logging import setup
+from torch.utils.data import DataLoader, WeightedRandomSampler
+
+from ...configs.datasets import get_samples_weights
+
+logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+
+
+class ShenzhenDataModule(pl.LightningDataModule):
+    def __init__(
+        self,
+        dataset,
+        train_batch_size=1,
+        predict_batch_size=1,
+        drop_incomplete_batch=False,
+        multiproc_kwargs=None,
+    ):
+        super().__init__()
+
+        self.dataset = dataset
+
+        self.train_batch_size = train_batch_size
+        self.predict_batch_size = predict_batch_size
+
+        self.drop_incomplete_batch = drop_incomplete_batch
+        self.pin_memory = (
+            torch.cuda.is_available()
+        )  # should only be true if GPU available and using it
+
+        self.multiproc_kwargs = multiproc_kwargs or {}  # tolerate the default of None
+
+    def setup(self, stage: str):
+        if stage == "fit":
+            if "__train__" in self.dataset:
+                logger.info("Found (dedicated) '__train__' set for training")
+                self.train_dataset = self.dataset["__train__"]
+            else:
+                self.train_dataset = self.dataset["train"]
+
+            if "__valid__" in self.dataset:
+                logger.info("Found (dedicated) '__valid__' set for validation")
+                self.validation_dataset = self.dataset["__valid__"]
+
+            if "__extra_valid__" in self.dataset:
+                if not isinstance(self.dataset["__extra_valid__"], list):
+                    raise RuntimeError(
+                        f"If present, dataset['__extra_valid__'] must be a list, "
+                        f"but you passed a {type(self.dataset['__extra_valid__'])}, "
+                        f"which is invalid."
+                    )
+                logger.info(
+                    f"Found {len(self.dataset['__extra_valid__'])} extra validation "
+                    f"set(s) to be tracked during training"
+                )
+                logger.info(
+                    "Extra validation sets are NOT used for model checkpointing!"
+                )
+                self.extra_validation_datasets = self.dataset["__extra_valid__"]
+            else:
+                self.extra_validation_datasets = None
+
+        if stage == "predict":
+            self.predict_dataset = []
+
+            for split_key in self.dataset.keys():
+                if split_key.startswith("_"):
+                    logger.info(
+                        f"Skipping dataset '{split_key}' (not to be evaluated)"
+                    )
+                    continue
+
+                # every split not marked as internal ("_"-prefixed) is predicted
+                self.predict_dataset.append(self.dataset[split_key])
+
+    def train_dataloader(self):
+        train_samples_weights = get_samples_weights(self.train_dataset)
+
+        train_sampler = WeightedRandomSampler(
+            train_samples_weights, len(train_samples_weights), replacement=True
+        )
+
+        return DataLoader(
+            self.train_dataset,
+            batch_size=self.train_batch_size,
+            drop_last=self.drop_incomplete_batch,
+            pin_memory=self.pin_memory,
+            sampler=train_sampler,
+            **self.multiproc_kwargs,
+        )
+
+    def val_dataloader(self):
+        loaders_dict = {}
+
+        val_loader = DataLoader(
+            dataset=self.validation_dataset,
+            batch_size=self.train_batch_size,
+            shuffle=False,
+            drop_last=False,
+            pin_memory=self.pin_memory,
+            **self.multiproc_kwargs,
+        )
+
+        loaders_dict["validation_loader"] = val_loader
+
+        if self.extra_validation_datasets is not None:
+            for set_idx, extra_set in enumerate(self.extra_validation_datasets):
+                extra_val_loader = DataLoader(
+                    dataset=extra_set,
+                    batch_size=self.train_batch_size,
+                    shuffle=False,
+                    drop_last=False,
+                    pin_memory=self.pin_memory,
+                    **self.multiproc_kwargs,
+                )
+
+                loaders_dict[
+                    f"extra_validation_loader{set_idx}"
+                ] = extra_val_loader
+
+        return loaders_dict
+
+    def predict_dataloader(self):
+        loaders_dict = {}
+
+        for set_idx, pred_set in enumerate(self.predict_dataset):
+            loaders_dict[set_idx] = DataLoader(
+                dataset=pred_set,
+                batch_size=self.predict_batch_size,
+                shuffle=False,
+                pin_memory=self.pin_memory,
+            )
+
+        return loaders_dict
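For reference, a minimal usage sketch of the new datamodule (illustrative, not part of the patch); `splits` stands in for the split dictionary the dataset configs export as `dataset`, and the batch size and worker count are arbitrary placeholders:

    # Minimal usage sketch of ShenzhenDataModule (illustrative only).
    from ptbench.data.shenzhen.datamodule import ShenzhenDataModule

    dm = ShenzhenDataModule(
        splits,  # placeholder: the split dictionary from a dataset config
        train_batch_size=8,
        multiproc_kwargs={"num_workers": 4},
    )
    dm.setup(stage="fit")
    train_loader = dm.train_dataloader()  # weighted random sampling over the train split
    val_loaders = dm.val_dataloader()  # dict with "validation_loader" (+ extra sets, if any)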
diff --git a/src/ptbench/engine/callbacks.py b/src/ptbench/engine/callbacks.py
index b266ae6221cf9a925ff941f1c99bdfdd044fa23f..962a761cce0b2a82bd6f41a60c92c49333f59da8 100644
--- a/src/ptbench/engine/callbacks.py
+++ b/src/ptbench/engine/callbacks.py
@@ -1,6 +1,8 @@
 import csv
 import time
 
+from collections import defaultdict
+
 import numpy
 
 from lightning.pytorch import Callback
@@ -17,6 +19,7 @@ class LoggingCallback(Callback):
         super().__init__()
         self.training_loss = []
         self.validation_loss = []
+        self.extra_validation_loss = defaultdict(list)
         self.start_training_time = 0
         self.start_epoch_time = 0
 
@@ -37,6 +40,13 @@ class LoggingCallback(Callback):
     ):
-        self.validation_loss.append(outputs["validation_loss"].item())
+        if "validation_loss" in outputs:
+            self.validation_loss.append(outputs["validation_loss"].item())
 
+        # batches coming from the extra validation dataloaders report their
+        # loss under keys of the form "extra_validation_loss_<dataloader_idx>"
+        for extra_validation_loss_key in outputs.keys():
+            if extra_validation_loss_key == "validation_loss":
+                continue
+            self.extra_validation_loss[extra_validation_loss_key].append(
+                outputs[extra_validation_loss_key].item()
+            )
+
     def on_validation_epoch_end(self, trainer, pl_module):
         self.resource_monitor.trigger_summary()
 
@@ -52,6 +62,15 @@ class LoggingCallback(Callback):
         self.log("learning_rate", pl_module.hparams["optimizer_configs"]["lr"])
         self.log("validation_loss", numpy.average(self.validation_loss))
 
+        if len(self.extra_validation_loss) > 0:
+            for (
+                extra_valid_loss_key,
+                extra_valid_loss_values,
+            ) in self.extra_validation_loss.items():
+                self.log(
+                    extra_valid_loss_key, numpy.average(extra_valid_loss_values)
+                )
+
         queue_retries = 0
         # In case the resource monitor takes longer to fetch data from the queue, we wait
         # Give up after self.resource_monitor.interval * self.max_queue_retries if cannot retrieve metrics from queue
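To make the intent of the callback change concrete, here is a small standalone sketch of the aggregation it performs; the loss values are made up, but the key format matches what PASA.validation_step() returns for dataloader_idx > 0:

    # Standalone sketch of the extra-loss aggregation (illustrative values only).
    from collections import defaultdict

    import numpy

    extra_validation_loss = defaultdict(list)
    for outputs in (
        {"extra_validation_loss_1": 0.75},
        {"extra_validation_loss_1": 0.25},
    ):
        for key, value in outputs.items():
            if key != "validation_loss":
                extra_validation_loss[key].append(value)

    for key, values in extra_validation_loss.items():
        print(key, numpy.average(values))  # -> extra_validation_loss_1 0.5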
diff --git a/src/ptbench/engine/trainer.py b/src/ptbench/engine/trainer.py
index a85a3da566691922323fbeb8a56d3472def48389..2c8bdc55f161ce567ddd7cb6641ac728ce3b7dbd 100644
--- a/src/ptbench/engine/trainer.py
+++ b/src/ptbench/engine/trainer.py
@@ -149,9 +149,7 @@ def create_logfile_fields(valid_loader, extra_valid_loaders, device):
 
 def run(
     model,
-    data_loader,
-    valid_loader,
-    extra_valid_loaders,
+    datamodule,
     checkpoint_period,
     accelerator,
     arguments,
@@ -263,4 +261,4 @@ def run(
             callbacks=[LoggingCallback(resource_monitor), checkpoint_callback],
         )
 
-        _ = trainer.fit(model, data_loader, valid_loader, ckpt_path=checkpoint)
+        _ = trainer.fit(model, datamodule=datamodule, ckpt_path=checkpoint)
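Passing the datamodule to trainer.fit() lets Lightning drive the hooks implemented in ShenzhenDataModule. A simplified sketch of that call sequence, glossing over checkpoint restoration and distributed setup (illustrative, not the real Trainer code):

    # Rough outline of what Trainer.fit() does with a datamodule (simplified).
    def fit_outline(model, datamodule, ckpt_path=None):
        datamodule.prepare_data()        # one-off data preparation
        datamodule.setup(stage="fit")    # builds train/validation datasets
        train_loader = datamodule.train_dataloader()
        val_loaders = datamodule.val_dataloader()
        # ...restore from ckpt_path if given, then run the fit/validation loops...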
diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py
index 125867bda1aa4f6e2317708cc5010d9120518f46..cae11375b66e8cc24287609363474a36065a2016 100644
--- a/src/ptbench/models/pasa.py
+++ b/src/ptbench/models/pasa.py
@@ -141,7 +141,7 @@ class PASA(pl.LightningModule):
 
         return {"loss": training_loss}
 
-    def validation_step(self, batch, batch_idx):
+    def validation_step(self, batch, batch_idx, dataloader_idx=0):
         images = batch[1]
         labels = batch[2]
 
@@ -159,9 +159,12 @@ class PASA(pl.LightningModule):
         )
         validation_loss = self.hparams.criterion_valid(outputs, labels.double())
 
-        return {"validation_loss": validation_loss}
+        if dataloader_idx == 0:
+            return {"validation_loss": validation_loss}
+        else:
+            return {f"extra_validation_loss_{dataloader_idx}": validation_loss}
 
-    def predict_step(self, batch, batch_idx, grad_cams=False):
+    def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
         names = batch[0]
         images = batch[1]
 
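With several validation dataloaders, Lightning calls validation_step() once per batch of each loader and passes the loader's position as dataloader_idx, so the returned key encodes which set the loss belongs to. A sketch of that routing, assuming one "__valid__" set plus extra sets, and assuming Lightning enumerates the loaders in the order val_dataloader() inserts them:

    # Sketch of the dataloader_idx -> loss-key routing implemented above.
    def route_validation_loss(validation_loss, dataloader_idx=0):
        if dataloader_idx == 0:  # the "__valid__" set used for checkpointing
            return {"validation_loss": validation_loss}
        # extra sets start at index 1, hence the naming offset with respect to
        # the "extra_validation_loader0" key built by val_dataloader()
        return {f"extra_validation_loss_{dataloader_idx}": validation_loss}

    assert route_validation_loss(0.5) == {"validation_loss": 0.5}
    assert route_validation_loss(0.7, 1) == {"extra_validation_loss_1": 0.7}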
diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py
index 12c5a287f5682340ecb1275f4439ebd4056c657b..eb6910cdce997df94155e602591508b99ec512dd 100644
--- a/src/ptbench/scripts/train.py
+++ b/src/ptbench/scripts/train.py
@@ -13,25 +13,6 @@ from ..utils.checkpointer import get_checkpoint
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
 
-def set_reproducible_cuda():
-    """Turns-off all CUDA optimizations that would affect reproducibility.
-
-    For full reproducibility, also ensure not to use multiple (parallel) data
-    lowers.  That is setup ``num_workers=0``.
-
-    Reference: `PyTorch page for reproducibility
-    <https://pytorch.org/docs/stable/notes/randomness.html>`_.
-    """
-    import torch.backends.cudnn
-
-    # ensure to use only optimization algos for cuda that are known to have
-    # a deterministic effect (not random)
-    torch.backends.cudnn.deterministic = True
-
-    # turns off any optimization tricks
-    torch.backends.cudnn.benchmark = False
-
-
 @click.command(
     entry_point_group="ptbench.config",
     cls=ConfigCommand,
@@ -62,6 +43,12 @@ def set_reproducible_cuda():
     required=True,
     cls=ResourceOption,
 )
+@click.option(
+    "--datamodule",
+    help="A torch.nn.Module instance implementing the network to be trained",
+    required=True,
+    cls=ResourceOption,
+)
 @click.option(
     "--dataset",
     "-d",
@@ -235,7 +222,7 @@ def set_reproducible_cuda():
 )
 @click.option(
     "--resume-from",
-    help="Which checkpoint to resume training from. Can be one of 'None', 'best', 'last', or a path to a  model checkpoint.",
+    help="Which checkpoint to resume training from. Can be one of 'None', 'best', 'last', or a path to a model checkpoint.",
     type=str,
     required=False,
     default=None,
@@ -251,6 +238,7 @@ def train(
     drop_incomplete_batch,
     criterion,
     criterion_valid,
+    datamodule,
     dataset,
     checkpoint_period,
     accelerator,
@@ -277,45 +265,13 @@ def train(
     import torch.nn
 
     from torch.nn import BCEWithLogitsLoss
-    from torch.utils.data import DataLoader, WeightedRandomSampler
+    from torch.utils.data import DataLoader
 
-    from ..configs.datasets import get_positive_weights, get_samples_weights
+    from ..configs.datasets import get_positive_weights
     from ..engine.trainer import run
 
     seed_everything(seed)
 
-    use_dataset = dataset
-    validation_dataset = None
-    extra_validation_datasets = []
-
-    if isinstance(dataset, dict):
-        if "__train__" in dataset:
-            logger.info("Found (dedicated) '__train__' set for training")
-            use_dataset = dataset["__train__"]
-        else:
-            use_dataset = dataset["train"]
-
-        if "__valid__" in dataset:
-            logger.info("Found (dedicated) '__valid__' set for validation")
-            logger.info("Will checkpoint lowest loss model on validation set")
-            validation_dataset = dataset["__valid__"]
-
-        if "__extra_valid__" in dataset:
-            if not isinstance(dataset["__extra_valid__"], list):
-                raise RuntimeError(
-                    f"If present, dataset['__extra_valid__'] must be a list, "
-                    f"but you passed a {type(dataset['__extra_valid__'])}, "
-                    f"which is invalid."
-                )
-            logger.info(
-                f"Found {len(dataset['__extra_valid__'])} extra validation "
-                f"set(s) to be tracked during training"
-            )
-            logger.info(
-                "Extra validation sets are NOT used for model checkpointing!"
-            )
-            extra_validation_datasets = dataset["__extra_valid__"]
-
     # PyTorch dataloader
     multiproc_kwargs = dict()
     if parallel < 0:
@@ -340,31 +296,25 @@ def train(
     else:
         batch_chunk_size = batch_size // batch_chunk_count
 
-    # Create weighted random sampler
-    train_samples_weights = get_samples_weights(use_dataset)
-    train_sampler = WeightedRandomSampler(
-        train_samples_weights, len(train_samples_weights), replacement=True
+    datamodule = datamodule(
+        dataset,
+        train_batch_size=batch_chunk_size,
+        drop_incomplete_batch=drop_incomplete_batch,
+        multiproc_kwargs=multiproc_kwargs,
     )
+    # Manually calling these as we need to access some values to reweight the criterion
+    datamodule.prepare_data()
+    datamodule.setup(stage="fit")
+
+    train_dataset = datamodule.train_dataset
+    validation_dataset = datamodule.validation_dataset
 
     # Redefine a weighted criterion if possible
     if isinstance(criterion, torch.nn.BCEWithLogitsLoss):
-        positive_weights = get_positive_weights(use_dataset)
+        positive_weights = get_positive_weights(train_dataset)
         model.hparams.criterion = BCEWithLogitsLoss(pos_weight=positive_weights)
     else:
         logger.warning("Weighted criterion not supported")
 
-    # PyTorch dataloader
-
-    data_loader = DataLoader(
-        dataset=use_dataset,
-        batch_size=batch_chunk_size,
-        drop_last=drop_incomplete_batch,
-        pin_memory=torch.cuda.is_available(),
-        sampler=train_sampler,
-        **multiproc_kwargs,
-    )
-
-    valid_loader = None
     if validation_dataset is not None:
         # Redefine a weighted valid criterion if possible
         if (
@@ -378,27 +328,6 @@ def train(
         else:
             logger.warning("Weighted valid criterion not supported")
 
-        valid_loader = DataLoader(
-            dataset=validation_dataset,
-            batch_size=batch_chunk_size,
-            shuffle=False,
-            drop_last=False,
-            pin_memory=torch.cuda.is_available(),
-            **multiproc_kwargs,
-        )
-
-    extra_valid_loaders = [
-        DataLoader(
-            dataset=k,
-            batch_size=batch_chunk_size,
-            shuffle=False,
-            drop_last=False,
-            pin_memory=torch.cuda.is_available(),
-            **multiproc_kwargs,
-        )
-        for k in extra_validation_datasets
-    ]
-
     # Create z-normalization model layer if needed
     if normalization == "imagenet":
         model.normalizer.set_mean_std(
@@ -407,7 +336,9 @@ def train(
         logger.info("Z-normalization with ImageNet mean and std")
     elif normalization == "current":
         # Compute mean/std of current train subset
-        temp_dl = DataLoader(dataset=use_dataset, batch_size=len(use_dataset))
+        temp_dl = DataLoader(
+            dataset=train_dataset, batch_size=len(train_dataset)
+        )
 
         data = next(iter(temp_dl))
         mean = data[1].mean(dim=[0, 2, 3])
@@ -446,9 +377,7 @@ def train(
 
     run(
         model=model,
-        data_loader=data_loader,
-        valid_loader=valid_loader,
-        extra_valid_loaders=extra_valid_loaders,
+        datamodule=datamodule,
         checkpoint_period=checkpoint_period,
         accelerator=accelerator,
         arguments=arguments,