diff --git a/src/mednet/libs/classification/engine/predictor.py b/src/mednet/libs/classification/engine/predictor.py
index 1701a23ed8a2b517d9ad5fb233e8c365fb45c0bf..771d918f7ef454656038380d743d2da45e7745cd 100644
--- a/src/mednet/libs/classification/engine/predictor.py
+++ b/src/mednet/libs/classification/engine/predictor.py
@@ -31,7 +31,7 @@ def run(
     | MultiClassPredictionSplit
     | None
 ):
-    """Run inference on input data, outputs csv files with predictions.
+    """Run inference on input data, output predictions.
 
     Parameters
     ----------
diff --git a/src/mednet/libs/classification/scripts/predict.py b/src/mednet/libs/classification/scripts/predict.py
index a6f45a113fda61bc73a54ce4e95ef3f72dfdb772..0a0f4451ccd51b0bcf05d9489b0475e08b2378fd 100644
--- a/src/mednet/libs/classification/scripts/predict.py
+++ b/src/mednet/libs/classification/scripts/predict.py
@@ -46,7 +46,5 @@ def predict(
     """Run inference (generates scores) on all input images, using a pre-trained model."""
 
-    import json
-    import shutil
 
     from mednet.libs.classification.engine.predictor import run
     from mednet.libs.common.engine.device import DeviceManager
@@ -55,24 +53,18 @@
         save_json_data,
         setup_datamodule,
     )
+    from mednet.libs.common.scripts.utils import save_json_with_backup
 
-    predictions_file = output_folder / "predictions.json"
-    predictions_file.parent.mkdir(parents=True, exist_ok=True)
+    predictions_meta_file = output_folder / "predictions.meta.json"
+    predictions_meta_file.parent.mkdir(parents=True, exist_ok=True)
 
     setup_datamodule(datamodule, model, batch_size, parallel)
     model = load_checkpoint(model, weight)
     device_manager = DeviceManager(device)
-    save_json_data(datamodule, model, predictions_file, device_manager)
+    save_json_data(datamodule, model, device_manager, predictions_meta_file)
 
     predictions = run(model, datamodule, device_manager)
 
-    if predictions_file.exists():
-        backup = predictions_file.parent / (predictions_file.name + "~")
-        logger.warning(
-            f"Output predictions file `{str(predictions_file)}` exists - "
-            f"backing it up to `{str(backup)}`...",
-        )
-        shutil.copy(predictions_file, backup)
+    predictions_file = output_folder / "predictions.json"
+    save_json_with_backup(predictions_file, predictions)
 
-    with predictions_file.open("w") as f:
-        json.dump(predictions, f, indent=2)
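
For reference, the inline backup logic removed above is presumably what
``save_json_with_backup`` (from ``mednet.libs.common.scripts.utils``)
centralizes.  A minimal sketch of the expected behavior, reconstructed from
the removed lines rather than from the actual helper implementation:

.. code:: python

   import json
   import logging
   import pathlib
   import shutil
   import typing

   logger = logging.getLogger("mednet")


   def save_json_with_backup(path: pathlib.Path, data: typing.Any) -> None:
       """Write ``data`` as JSON to ``path``, backing up any existing file.

       Sketch only: the real helper lives in
       ``mednet.libs.common.scripts.utils``.
       """
       path.parent.mkdir(parents=True, exist_ok=True)
       if path.exists():
           # mirror the removed inline logic: keep a tilde-suffixed copy
           backup = path.parent / (path.name + "~")
           logger.warning(f"Backing up `{path}` to `{backup}`...")
           shutil.copy(path, backup)
       with path.open("w") as f:
           json.dump(data, f, indent=2)
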
diff --git a/src/mednet/libs/common/scripts/predict.py b/src/mednet/libs/common/scripts/predict.py
index d38c38411257c95a293d491bb178470d4560da64..422b9e56585649a124e1d9c01428ed07178b23b7 100644
--- a/src/mednet/libs/common/scripts/predict.py
+++ b/src/mednet/libs/common/scripts/predict.py
@@ -7,15 +7,15 @@ import pathlib
 import typing
 
 import click
-from clapper.click import ResourceOption
+import mednet.libs.common.data.datamodule
+import mednet.libs.common.models.model
 from clapper.logging import setup
-from mednet.libs.common.models.model import Model
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
 
-def reusable_options(f):
-    """Wrap reusable predict script options (for ``experiment``).
+def reusable_options(f: typing.Callable):
+    """Wrap reusable predict script options for other scripts.
 
     This decorator equips the target function ``f`` with all (reusable)
     ``predict`` script options.
@@ -30,6 +30,7 @@ def reusable_options(f):
     -------
         The decorated version of function ``f``
     """
+    from clapper.click import ResourceOption
 
     @click.option(
         "--output-folder",
@@ -121,21 +122,24 @@ def reusable_options(f):
 
 
 def setup_datamodule(
-    datamodule,
-    model,
-    batch_size,
-    parallel,
+    datamodule: mednet.libs.common.data.datamodule.ConcatDataModule,
+    model: mednet.libs.common.models.model.Model,
+    batch_size: int,
+    parallel: int,
 ) -> None:  # numpydoc ignore=PR01
     """Configure and set up the datamodule."""
+
     datamodule.batch_size = batch_size
     datamodule.parallel = parallel
-    datamodule.model_transforms = model.model_transforms
+    datamodule.model_transforms = list(model.model_transforms)
 
     datamodule.prepare_data()
     datamodule.setup(stage="predict")
 
 
-def load_checkpoint(model: Model, weight: pathlib.Path) -> Model:
+def load_checkpoint(
+    model: mednet.libs.common.models.model.Model, weight: pathlib.Path
+) -> mednet.libs.common.models.model.Model:
     """Load a model checkpoint for prediction.
 
     Parameters
@@ -162,10 +166,10 @@ def load_checkpoint(model: Model, weight: pathlib.Path) -> Model:
 
 
 def save_json_data(
-    datamodule,
-    model,
-    output,
+    datamodule: mednet.libs.common.data.datamodule.ConcatDataModule,
+    model: mednet.libs.common.models.model.Model,
     device_manager,
+    output_file: pathlib.Path,
 ) -> None:  # numpydoc ignore=PR01
     """Save prediction hyperparameters into a .json file."""
 
@@ -187,4 +191,4 @@ def save_json_data(
     )
     json_data.update(model_summary(model))
     json_data = {k.replace("_", "-"): v for k, v in json_data.items()}
-    save_json_with_backup(output.with_suffix(".meta.json"), json_data)
+    save_json_with_backup(output_file, json_data)
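
The helpers above now carry explicit type annotations, and ``save_json_data``
takes the device manager before the output file.  A condensed sketch of how a
downstream script wires them together, mirroring the call sites elsewhere in
this diff (the ``"cpu"`` device string and batch settings are illustrative):

.. code:: python

   import pathlib

   from mednet.libs.common.engine.device import DeviceManager
   from mednet.libs.common.scripts.predict import (
       load_checkpoint,
       save_json_data,
       setup_datamodule,
   )


   def prepare_prediction(model, datamodule, weight, output_folder):
       """Mirror the prediction set-up performed by the CLI scripts."""
       setup_datamodule(datamodule, model, batch_size=1, parallel=-1)
       model = load_checkpoint(model, weight)
       device_manager = DeviceManager("cpu")
       # note the new argument order: device manager, then output file
       save_json_data(
           datamodule,
           model,
           device_manager,
           output_folder / "predictions.meta.json",
       )
       return model, device_manager
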
diff --git a/src/mednet/libs/segmentation/engine/evaluator.py b/src/mednet/libs/segmentation/engine/evaluator.py
index 82ec4bbba6eb56a14fa55551be8a9081e9794e6c..46b2d9ff21b0fdf9807d171bbc632b98503d0f84 100644
--- a/src/mednet/libs/segmentation/engine/evaluator.py
+++ b/src/mednet/libs/segmentation/engine/evaluator.py
@@ -370,7 +370,7 @@ def load_count(
     data = numpy.zeros((len(thresholds), 4), dtype=numpy.uint64)
     for sample in tqdm(predictions, desc="sample"):
         with h5py.File(prediction_path / sample[1], "r") as f:
-            pred = numpy.array(f.get("img"))  # float32
+            pred = numpy.array(f.get("prediction"))  # float32
             gt = numpy.array(f.get("target"))  # boolean
             mask = numpy.array(f.get("mask"))  # boolean
         data += numpy.array(
@@ -411,7 +411,7 @@ def load_predictions(
 
     # peak prediction size and number of samples
     with h5py.File(prediction_path / predictions[0][1], "r") as f:
-        elements = numpy.array(f.get("img").shape).prod()
+        elements = numpy.array(f.get("prediction").shape).prod()
     size = len(predictions) * elements
     logger.info(
         f"Data loading will require ({elements} x {len(predictions)} x 5 =) "
@@ -424,7 +424,7 @@ def load_predictions(
     for i, sample in enumerate(tqdm(predictions, desc="sample")):
         with h5py.File(prediction_path / sample[1], "r") as f:
             mask = numpy.array(f.get("mask"))  # boolean
-            pred = numpy.array(f.get("img"))  # float32
+            pred = numpy.array(f.get("prediction"))  # float32
             pred *= mask.astype(numpy.float32)
             gt = numpy.array(f.get("target"))  # boolean
             gt &= mask
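
These hunks only rename the HDF5 dataset from ``img`` to ``prediction``,
matching the layout written by the new segmentation predictor further below
(datasets ``image``, ``prediction``, ``target`` and ``mask``).  A short
sketch of reading one such file, with an illustrative path:

.. code:: python

   import h5py
   import numpy

   with h5py.File("predictions/sample-0.hdf5", "r") as f:
       pred = numpy.array(f.get("prediction"))  # float32 probability map
       gt = numpy.array(f.get("target"))        # boolean ground truth
       mask = numpy.array(f.get("mask"))        # boolean region of interest

   # restrict predictions and targets to the region of interest, as the
   # evaluator does before counting
   pred *= mask.astype(numpy.float32)
   gt &= mask
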
diff --git a/src/mednet/libs/segmentation/engine/predictor.py b/src/mednet/libs/segmentation/engine/predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..8371c9a4dc754d875d306d500d45f418913b2f5a
--- /dev/null
+++ b/src/mednet/libs/segmentation/engine/predictor.py
@@ -0,0 +1,200 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import logging
+import pathlib
+import typing
+
+import h5py
+import lightning.pytorch
+import lightning.pytorch.callbacks
+import torch.utils.data
+import tqdm
+from mednet.libs.common.engine.device import DeviceManager
+
+logger = logging.getLogger("mednet")
+
+
+class _HDF5Writer(lightning.pytorch.callbacks.BasePredictionWriter):
+    """Write HDF5 files for each sample processed by our model.
+
+    Objects of this class can also keep track of samples written to disk and
+    return a summary list.
+
+    Parameters
+    ----------
+    output_folder
+        Base directory where predictions will be written.
+
+    write_interval
+        When this callback will be active.
+    """
+
+    def __init__(
+        self,
+        output_folder: pathlib.Path,
+        write_interval: typing.Literal["batch", "epoch", "batch_and_epoch"] = "batch",
+    ):
+        super().__init__(write_interval=write_interval)
+        self.output_folder = output_folder
+        self._written: list[list[str]] = []
+
+    def write_on_batch_end(
+        self,
+        trainer: lightning.pytorch.Trainer,
+        pl_module: lightning.pytorch.LightningModule,
+        prediction: typing.Any,
+        batch_indices: typing.Sequence[int] | None,
+        batch: typing.Any,
+        batch_idx: int,
+        dataloader_idx: int,
+    ) -> None:
+        """Write batch predictions to disk.
+
+        Parameters
+        ----------
+        trainer
+            The trainer being used.
+        pl_module
+            The pytorch module.
+        prediction
+            The actual predictions to record.
+        batch_indices
+            The relative positions of the samples within the epoch.
+        batch
+            The current batch.
+        batch_idx
+            Index of the current batch.
+        dataloader_idx
+            Index of the current dataloader.
+        """
+        for k, p in enumerate(prediction):
+            stem = pathlib.Path(p[0]).with_suffix(".hdf5")
+            output_path = self.output_folder / stem
+            tqdm.tqdm.write(f"`{p[0]}` -> `{str(output_path)}`")
+            output_path.parent.mkdir(parents=True, exist_ok=True)
+            with h5py.File(output_path, "w") as f:
+                f.create_dataset(
+                    "image",
+                    data=batch[0][k].numpy(),
+                    compression="gzip",
+                    compression_opts=9,
+                )
+                f.create_dataset(
+                    "prediction",
+                    data=p[3].numpy().squeeze(0),
+                    compression="gzip",
+                    compression_opts=9,
+                )
+                f.create_dataset(
+                    "target",
+                    data=(batch[1]["target"][k].squeeze(0).numpy() > 0.5),
+                    compression="gzip",
+                    compression_opts=9,
+                )
+                f.create_dataset(
+                    "mask",
+                    data=(batch[1]["mask"][k].squeeze(0).numpy() > 0.5),
+                    compression="gzip",
+                    compression_opts=9,
+                )
+            self._written.append([p[0], str(stem)])
+
+    def written(self) -> list[list[str]]:
+        """Summary of written objects.
+
+        Also resets the internal state.
+
+        Returns
+        -------
+            A list containing a summary of all samples written.
+        """
+        retval = self._written
+        self._written = []
+        return retval
+
+
+def run(
+    model: lightning.pytorch.LightningModule,
+    datamodule: lightning.pytorch.LightningDataModule,
+    device_manager: DeviceManager,
+    output_folder: pathlib.Path,
+) -> dict[str, list[list[str]]] | list[list[list[str]]] | list[list[str]] | None:
+    """Run inference on input data, output predictions.
+
+    Parameters
+    ----------
+    model
+        Neural network model (e.g. lwnet).
+    datamodule
+        The lightning DataModule to use for prediction.
+    device_manager
+        An internal device representation, to be used for prediction.  This
+        representation can be converted into a pytorch device or a lightning
+        accelerator setup.
+    output_folder
+        Folder where to store HDF5 representations of probability maps.
+
+    Returns
+    -------
+        A JSON-able representation of the sample data stored at
+        ``output_folder``: for every split (dataloader), a list of entries in
+        the form ``[sample-name, hdf5-path]``.  A single dataloader from
+        ``predict_dataloader()`` yields a single such list; a list of
+        dataloaders yields a list of such lists; a dictionary yields a
+        dictionary mapping split names to such lists.
+
+    Raises
+    ------
+    TypeError
+        If the DataModule's ``predict_dataloader()`` method does not return any
+        of the types described above.
+    """
+
+    from lightning.pytorch.loggers.logger import DummyLogger
+
+    writer = _HDF5Writer(output_folder)
+
+    accelerator, devices = device_manager.lightning_accelerator()
+    trainer = lightning.pytorch.Trainer(
+        accelerator=accelerator,
+        devices=devices,
+        logger=DummyLogger(),
+        callbacks=[writer],
+    )
+
+    dataloaders = datamodule.predict_dataloader()
+
+    if isinstance(dataloaders, torch.utils.data.DataLoader):
+        logger.info("Running prediction on a single dataloader...")
+        trainer.predict(model, dataloaders, return_predictions=False)
+        return writer.written()
+
+    if isinstance(dataloaders, list):
+        retval_list = []
+        for k, dataloader in enumerate(dataloaders):
+            logger.info(f"Running prediction on split `{k}`...")
+            trainer.predict(model, dataloader, return_predictions=False)
+            retval_list.append(writer.written())
+        return retval_list
+
+    if isinstance(dataloaders, dict):
+        retval_dict = {}
+        for name, dataloader in dataloaders.items():
+            logger.info(f"Running prediction on `{name}` split...")
+            trainer.predict(model, dataloader, return_predictions=False)
+            retval_dict[name] = writer.written()
+        return retval_dict
+
+    if dataloaders is None:
+        logger.warning("Datamodule did not return any prediction dataloaders!")
+        return None
+
+    # if you get to this point, then the user is returning something that is
+    # not supported - complain!
+    raise TypeError(
+        f"Datamodule returned strangely typed prediction "
+        f"dataloaders: `{type(dataloaders)}` - Please write code "
+        f"to support this use-case.",
+    )
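
A sketch of how this new ``run()`` is driven (compare the segmentation
``predict`` script below); ``model`` and ``datamodule`` stand for configured
lightning objects such as the ``lwnet`` and ``drive`` configurations, and the
device string is illustrative:

.. code:: python

   import pathlib

   from mednet.libs.common.engine.device import DeviceManager
   from mednet.libs.segmentation.engine.predictor import run


   def predict_to_hdf5(model, datamodule, output_folder: pathlib.Path):
       """Run inference and collect the JSON-able summary from ``run()``."""
       summary = run(model, datamodule, DeviceManager("cpu"), output_folder)
       # with a dictionary of dataloaders, ``summary`` maps split names to
       # lists of ``[sample-name, relative-hdf5-path]`` entries
       return summary
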
diff --git a/src/mednet/libs/segmentation/scripts/evaluate.py b/src/mednet/libs/segmentation/scripts/evaluate.py
index dbe12bb33a83f89f38d0d4e798912215e104d4fe..191cde3884a606c1d379e017aa6d49aefc2acd08 100644
--- a/src/mednet/libs/segmentation/scripts/evaluate.py
+++ b/src/mednet/libs/segmentation/scripts/evaluate.py
@@ -15,7 +15,7 @@ from mednet.libs.segmentation.engine.evaluator import SUPPORTED_METRIC_TYPE
 logger = setup("mednet")
 
 
-def _validate_threshold(threshold: float | str, splits: list[str]):
+def validate_threshold(threshold: float | str, splits: list[str]):
     """Validate the user threshold selection and returns parsed threshold.
 
     Parameters
@@ -190,7 +190,7 @@ def evaluate(
     json_data = {k.replace("_", "-"): v for k, v in json_data.items()}
     save_json_with_backup(evaluation_file.with_suffix(".meta.json"), json_data)
 
-    threshold = _validate_threshold(threshold, predict_data)
+    threshold = validate_threshold(threshold, predict_data)
     threshold_list = numpy.arange(
         0.0, (1.0 + 1 / steps), 1 / steps, dtype=numpy.float64
     )
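
With the leading underscore dropped, ``validate_threshold`` becomes part of
the script's reusable surface.  A hedged usage sketch (the split names are
illustrative; the string-threshold semantics are an assumption based on the
``threshold: float | str`` signature and its use in ``evaluate``):

.. code:: python

   from mednet.libs.segmentation.scripts.evaluate import validate_threshold

   # a numeric threshold is validated and returned as a float; a string is
   # presumably checked against the available split names for a posteriori
   # threshold selection
   threshold = validate_threshold(0.5, ["train", "test"])
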
diff --git a/src/mednet/libs/segmentation/scripts/predict.py b/src/mednet/libs/segmentation/scripts/predict.py
index 9c5c3d44e1b35c569c5f512df11e5619139f7243..84c50d47f3e109124bfcadef29a15557ad230aa6 100644
--- a/src/mednet/libs/segmentation/scripts/predict.py
+++ b/src/mednet/libs/segmentation/scripts/predict.py
@@ -2,86 +2,40 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-import json
-import pathlib
-
+import clapper.click
+import clapper.logging
 import click
-from clapper.click import ResourceOption, verbosity_option
-from clapper.logging import setup
-from mednet.libs.common.scripts.click import ConfigCommand
-from mednet.libs.common.scripts.predict import reusable_options
-from PIL import Image
-
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-def _save_hdf5(
-    img: Image,
-    target: Image,
-    mask: Image,
-    hdf5_path: pathlib.Path,
-) -> None:
-    """Save prediction image, target and mask in an hdf5 file.
-
-    Parameters
-    ----------
-    img
-        Monochrome Image with prediction maps.
-    target
-        Target corresponding to the prediction.
-    mask
-        Mask corresponding to the prediction.
-    hdf5_path
-        File in which to save the data.
-    """
-
-    import h5py
-    from tqdm import tqdm
-
-    tqdm.write(f"Saving {hdf5_path}...")
-    hdf5_path.parent.mkdir(parents=True, exist_ok=True)
-    with h5py.File(hdf5_path, "w") as f:
-        f.create_dataset(
-            "img",
-            data=img.squeeze(0),
-            compression="gzip",
-            compression_opts=9,
-        )
-        f.create_dataset(
-            "target",
-            data=target.squeeze(0),
-            compression="gzip",
-            compression_opts=9,
-        )
-        f.create_dataset(
-            "mask",
-            data=mask.squeeze(0),
-            compression="gzip",
-            compression_opts=9,
-        )
+import mednet.libs.common.scripts.click
+import mednet.libs.common.scripts.predict
+
+logger = clapper.logging.setup(
+    __name__.split(".")[0], format="%(levelname)s: %(message)s"
+)
 
 
 @click.command(
     entry_point_group="mednet.libs.segmentation.config",
-    cls=ConfigCommand,
+    cls=mednet.libs.common.scripts.click.ConfigCommand,
     epilog="""Examples:
 
 1. Run prediction on an existing DataModule configuration:
 
    .. code:: sh
 
-      mednet segmentation predict -vv lwnet drive --weight=path/to/model.ckpt --output=path/to/predictions.json
+      mednet segmentation predict -vv lwnet drive --weight=path/to/model.ckpt --output-folder=path/to/predictions
 
 2. Enable multi-processing data loading with 6 processes:
 
    .. code:: sh
 
-      mednet segmentation predict -vv lwnet drive --parallel=6 --weight=path/to/model.ckpt --output=path/to/predictions.json
+      mednet segmentation predict -vv lwnet drive --parallel=6 --weight=path/to/model.ckpt --output-folder=path/to/predictions
 
 """,
 )
-@reusable_options
-@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
+@mednet.libs.common.scripts.predict.reusable_options
+@clapper.click.verbosity_option(
+    logger=logger, cls=clapper.click.ResourceOption, expose_value=False
+)
 def predict(
     output_folder,
     model,
@@ -94,44 +48,22 @@ def predict(
 ) -> None:  # numpydoc ignore=PR01
     """Run inference (generates scores) on all input images, using a pre-trained model."""
 
-    import typing
-
-    from mednet.libs.classification.engine.predictor import run
     from mednet.libs.common.engine.device import DeviceManager
     from mednet.libs.common.scripts.predict import (
         load_checkpoint,
         save_json_data,
         setup_datamodule,
     )
+    from mednet.libs.common.scripts.utils import save_json_with_backup
+    from mednet.libs.segmentation.engine.predictor import run
 
-    predictions_file = output_folder / "predictions.json"
+    predictions_meta_file = output_folder / "predictions.meta.json"
 
     setup_datamodule(datamodule, model, batch_size, parallel)
     model = load_checkpoint(model, weight)
     device_manager = DeviceManager(device)
-    save_json_data(datamodule, model, predictions_file, device_manager)
-
-    predictions = run(model, datamodule, device_manager)
-
-    # Save image data (sample, target, mask) into an hdf5 file
-    json_predictions = {}
-    assert isinstance(
-        predictions, typing.Mapping
-    ), "predictions must be a dictionary or this program will not work!"
-    for split_name, split in predictions.items():
-        pred_paths = []
-        for sample in split:
-            hdf5_path = pathlib.Path(f"{sample[0]}").with_suffix(".hdf5")
-            _save_hdf5(
-                sample[3].numpy(),  # float32
-                sample[1].numpy() > 0.5,  # boolean
-                sample[2].numpy() > 0.5,  # boolean
-                output_folder / hdf5_path,
-            )
-            pred_paths.append([str(sample[0]), str(hdf5_path)])
-        json_predictions[split_name] = pred_paths
-
-    # Save path to hdf5 files into predictions.json
-    with predictions_file.open("w") as f:
-        json.dump(json_predictions, f, indent=2)
-    logger.info(f"Predictions saved to `{str(predictions_file)}`")
+    save_json_data(datamodule, model, device_manager, predictions_meta_file)
+
+    json_predictions = run(model, datamodule, device_manager, output_folder)
+
+    save_json_with_backup(output_folder / "predictions.json", json_predictions)
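
Given the return contract of the segmentation ``run()`` above,
``predictions.json`` plausibly contains one list of
``[sample-name, hdf5-path]`` entries per split.  A small sketch of consuming
it (folder and sample names are illustrative):

.. code:: python

   import json
   import pathlib

   output_folder = pathlib.Path("path/to/predictions")  # --output-folder
   with (output_folder / "predictions.json").open() as f:
       predictions = json.load(f)

   # plausible structure for a datamodule with named splits:
   # {"train": [["images/21.png", "images/21.hdf5"], ...],
   #  "test":  [["images/01.png", "images/01.hdf5"], ...]}
   for split, samples in predictions.items():
       print(split, len(samples))
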