diff --git a/src/ptbench/engine/callbacks.py b/src/ptbench/engine/callbacks.py
index 2da897a7753c474e2be19edff907e13597b012b5..8669718ea8efc5085c090af1db8b82fd89e16e4f 100644
--- a/src/ptbench/engine/callbacks.py
+++ b/src/ptbench/engine/callbacks.py
@@ -2,10 +2,7 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-import csv
 import logging
-import os
-import pathlib
 import time
 import typing
 
@@ -380,49 +377,3 @@ class LoggingCallback(lightning.pytorch.Callback):
                     {k: self._to_log[k], "step": float(trainer.current_epoch)}
                 )
             self._to_log = {}
-
-
-class PredictionsWriter(lightning.pytorch.callbacks.BasePredictionWriter):
-    """Lightning callback to write predictions to a file."""
-
-    def __init__(
-        self,
-        output_dir: str | pathlib.Path,
-        logfile_fields: typing.Sequence[str],
-        write_interval: typing.Literal["batch", "epoch", "batch_and_epoch"],
-    ):
-        super().__init__(write_interval)
-        self.output_dir = output_dir
-        self.logfile_fields = logfile_fields
-
-    def write_on_epoch_end(
-        self,
-        trainer: lightning.pytorch.Trainer,
-        pl_module: lightning.pytorch.LightningModule,
-        predictions: typing.Sequence[typing.Any],
-        batch_indices: typing.Sequence[typing.Any] | None,
-    ) -> None:
-        dataloader_names = list(trainer.datamodule.predict_dataloader().keys())
-
-        for dataloader_idx, dataloader_name in enumerate(dataloader_names):
-            logfile = os.path.join(
-                self.output_dir,
-                f"{dataloader_name}.csv",
-            )
-            os.makedirs(os.path.dirname(logfile), exist_ok=True)
-
-            logger.info(f"Saving predictions in {logfile}.")
-
-            with open(logfile, "w") as l_f:
-                logwriter = csv.DictWriter(l_f, fieldnames=self.logfile_fields)
-                logwriter.writeheader()
-
-                for prediction in predictions[dataloader_idx]:
-                    logwriter.writerow(
-                        {
-                            "filename": prediction[0],
-                            "likelihood": prediction[1].numpy(),
-                            "ground_truth": prediction[2].numpy(),
-                        }
-                    )
-                l_f.flush()
diff --git a/src/ptbench/engine/predictor.py b/src/ptbench/engine/predictor.py
index 6bb6e275d304406182449c77a68a8a8e62406719..dd515789a0218333b3404fa97c3b8a3964de6039 100644
--- a/src/ptbench/engine/predictor.py
+++ b/src/ptbench/engine/predictor.py
@@ -3,13 +3,10 @@
 # SPDX-License-Identifier: GPL-3.0-or-later
 
 import logging
-import os
+import pathlib
 
 import lightning.pytorch
 
-from lightning.pytorch import Trainer
-
-from .callbacks import PredictionsWriter
 from .device import DeviceManager
 
 logger = logging.getLogger(__name__)
@@ -19,55 +16,64 @@ def run(
     model: lightning.pytorch.LightningModule,
     datamodule: lightning.pytorch.LightningDataModule,
     device_manager: DeviceManager,
-    output_folder: str,
-):
+    output_folder: pathlib.Path,
+) -> dict[str, list] | list | list[list] | None:
     """Runs inference on input data, outputs csv files with predictions.
 
     Parameters
     ---------
-    model : :py:class:`torch.nn.Module`
+    model
         Neural network model (e.g. pasa).
-
     datamodule
         The lightning datamodule to use for training **and** validation
-
     device_manager
         An internal device representation, to be used for training and
         validation.  This representation can be converted into a pytorch device
         or a torch lightning accelerator setup.
-
     output_folder : str
-        Directory in which the results will be saved.
+        Directory in which the logs will be saved.
 
-    grad_cams : bool
-        If we export grad cams for every prediction (must be used along
-        a batch size of 1 with the DensenetRS model).
 
     Returns
     -------
-
-    all_predictions : list
-        All the predictions associated with filename and ground truth.
+    predictions
+        A dictionary containing the predictions for each of the input samples
+        per dataloader.  Keys correspond to the original split names defined at
+        the loader.  If the datamodule's ``predict_dataloader()`` method does
+        not return a dictionary, then its output is directly passed to the
+        trainer ``predict()`` method.
     """
 
-    logger.info(f"Output folder: {output_folder}")
-    os.makedirs(output_folder, exist_ok=True)
+    from .loggers.custom_tensorboard_logger import CustomTensorboardLogger
 
-    logfile_fields = ("filename", "likelihood", "ground_truth")
+    log_dir = "logs"
+    tensorboard_logger = CustomTensorboardLogger(
+        output_folder,
+        log_dir,
+    )
+    logger.info(
+        f"Monitor prediction with `tensorboard serve "
+        f"--logdir={output_folder}/{log_dir}/`. "
+        f"Then, open a browser on the printed address."
+    )
 
     accelerator, devices = device_manager.lightning_accelerator()
-    trainer = Trainer(
+    trainer = lightning.pytorch.Trainer(
         accelerator=accelerator,
         devices=devices,
-        callbacks=[
-            PredictionsWriter(
-                output_dir=output_folder,
-                logfile_fields=logfile_fields,
-                write_interval="epoch",
-            ),
-        ],
+        logger=tensorboard_logger,
     )
 
-    all_predictions = trainer.predict(model, datamodule)
-
-    return all_predictions
+    dataloaders = datamodule.predict_dataloader()
+    if isinstance(dataloaders, dict):
+        retval = {}
+        for name, dataloader in dataloaders.items():
+            logger.info(f"Running prediction on `{name}` split...")
+            predictions = trainer.predict(model, dataloader)
+            retval[name] = [
+                sample for batch in predictions for sample in batch  # type: ignore
+            ]
+        return retval
+
+    # just pass all the loaders to the trainer, let it handle
+    return trainer.predict(model, datamodule)
diff --git a/src/ptbench/engine/trainer.py b/src/ptbench/engine/trainer.py
index 825ac8cd578f993e2bc874c06ad0cbdeeee0eefd..86f282523dbcfc5edc05fb6ae516e2b05ecea17e 100644
--- a/src/ptbench/engine/trainer.py
+++ b/src/ptbench/engine/trainer.py
@@ -164,7 +164,7 @@ def run(
         log_dir,
     )
     logger.info(
-        f"Monitor experiment with `tensorboard serve "
+        f"Monitor training with `tensorboard serve "
         f"--logdir={output_folder}/{log_dir}/`. "
         f"Then, open a browser on the printed address."
     )
diff --git a/src/ptbench/models/alexnet.py b/src/ptbench/models/alexnet.py
index 9096f6081366ce8983a181caa2d2500f118ef82a..a664c4c6021768319e82fc0a7ff2a1e66cb8a161 100644
--- a/src/ptbench/models/alexnet.py
+++ b/src/ptbench/models/alexnet.py
@@ -14,6 +14,7 @@ import torchvision.models as models
 import torchvision.transforms
 
 from ..data.typing import TransformSequence
+from .separate import separate
 from .transforms import RGB
 from .typing import Checkpoint
 
@@ -205,18 +206,9 @@ class Alexnet(pl.LightningModule):
         return self._validation_loss(outputs, labels.float())
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
-        images = batch[0]
-        labels = batch[1]["label"]
-        names = batch[1]["name"]
-
-        outputs = self(images)
+        outputs = self(batch[0])
         probabilities = torch.sigmoid(outputs)
-
-        return (
-            names[0],
-            torch.flatten(probabilities),
-            torch.flatten(labels),
-        )
+        return separate((probabilities, batch[1]))
 
     def configure_optimizers(self):
         return self._optimizer_type(
diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py
index 97ebaf78e51e69fc8b732be1e10b69cae2d09231..c416bb0298bd3782e864ce1b0bc26696d3dc68ea 100644
--- a/src/ptbench/models/densenet.py
+++ b/src/ptbench/models/densenet.py
@@ -14,6 +14,7 @@ import torchvision.models as models
 import torchvision.transforms
 
 from ..data.typing import TransformSequence
+from .separate import separate
 from .transforms import RGB
 from .typing import Checkpoint
 
@@ -199,18 +200,9 @@ class Densenet(pl.LightningModule):
         return self._validation_loss(outputs, labels.float())
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
-        images = batch[0]
-        labels = batch[1]["label"]
-        names = batch[1]["name"]
-
-        outputs = self(images)
+        outputs = self(batch[0])
         probabilities = torch.sigmoid(outputs)
-
-        return (
-            names[0],
-            torch.flatten(probabilities),
-            torch.flatten(labels),
-        )
+        return separate((probabilities, batch[1]))
 
     def configure_optimizers(self):
         return self._optimizer_type(
diff --git a/src/ptbench/models/logistic_regression.py b/src/ptbench/models/logistic_regression.py
index d0f952bbf9230206c0f66a0c342b9f3cf52b83c0..4e0338f4e0b4ceb6faeef11bedba73a9b2b0204b 100644
--- a/src/ptbench/models/logistic_regression.py
+++ b/src/ptbench/models/logistic_regression.py
@@ -8,6 +8,8 @@ import lightning.pytorch as pl
 import torch
 import torch.nn as nn
 
+from .separate import separate
+
 
 class LogisticRegression(pl.LightningModule):
     """Logistic regression classifier with a single output.
@@ -106,18 +108,9 @@ class LogisticRegression(pl.LightningModule):
             return {f"extra_validation_loss_{dataloader_idx}": validation_loss}
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
-        names = batch[0]
-        input = batch[1]
-
-        output = self(input)
-        probabilities = torch.sigmoid(output)
-
-        # necessary check for HED architecture that uses several outputs
-        # for loss calculation instead of just the last concatfuse block
-        if isinstance(output, list):
-            output = output[-1]
-
-        return names[0], torch.flatten(probabilities), torch.flatten(batch[2])
+        outputs = self(batch[0])
+        probabilities = torch.sigmoid(outputs)
+        return separate((probabilities, batch[1]))
 
     def configure_optimizers(self):
         return self._optimizer_type(
diff --git a/src/ptbench/models/mlp.py b/src/ptbench/models/mlp.py
index 54c7cb69196b109773ad65140ce680d8f16d2bcb..102b384985c1814288e17265e3265015f500aacd 100644
--- a/src/ptbench/models/mlp.py
+++ b/src/ptbench/models/mlp.py
@@ -7,6 +7,8 @@ import typing
 import lightning.pytorch as pl
 import torch
 
+from .separate import separate
+
 
 class MultiLayerPerceptron(pl.LightningModule):
     """MLP with a variable number of inputs and hidden neurons (single layer).
@@ -111,18 +113,9 @@ class MultiLayerPerceptron(pl.LightningModule):
             return {f"extra_validation_loss_{dataloader_idx}": validation_loss}
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
-        names = batch[0]
-        input = batch[1]
-
-        output = self(input)
-        probabilities = torch.sigmoid(output)
-
-        # necessary check for HED architecture that uses several outputs
-        # for loss calculation instead of just the last concatfuse block
-        if isinstance(output, list):
-            output = output[-1]
-
-        return names[0], torch.flatten(probabilities), torch.flatten(batch[2])
+        outputs = self(batch[0])
+        probabilities = torch.sigmoid(outputs)
+        return separate((probabilities, batch[1]))
 
     def configure_optimizers(self):
         return self._optimizer_type(
diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py
index 5d6e20b4dd25c8ba6d57e3c681cdc2514af616ee..c89a6d328a229dd7839956e2105e42541c1df0ea 100644
--- a/src/ptbench/models/pasa.py
+++ b/src/ptbench/models/pasa.py
@@ -14,6 +14,7 @@ import torch.utils.data
 import torchvision.transforms
 
 from ..data.typing import TransformSequence
+from .separate import separate
 from .transforms import Grayscale
 from .typing import Checkpoint
 
@@ -265,18 +266,9 @@ class Pasa(pl.LightningModule):
         return self._validation_loss(outputs, labels.float())
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
-        images = batch[0]
-        labels = batch[1]["label"]
-        names = batch[1]["name"]
-
-        outputs = self(images)
+        outputs = self(batch[0])
         probabilities = torch.sigmoid(outputs)
-
-        return (
-            names[0],
-            torch.flatten(probabilities),
-            torch.flatten(labels),
-        )
+        return separate((probabilities, batch[1]))
 
     def configure_optimizers(self):
         return self._optimizer_type(
diff --git a/src/ptbench/models/separate.py b/src/ptbench/models/separate.py
new file mode 100644
index 0000000000000000000000000000000000000000..522d5b94d2313a34c23cb6322c1d986dc3631217
--- /dev/null
+++ b/src/ptbench/models/separate.py
@@ -0,0 +1,29 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+"""Contains the inverse :py:func:`torch.utils.data.default_collate`."""
+
+import torch
+
+from ..data.typing import Sample
+
+
+def separate(batch: Sample) -> list[Sample]:
+    """Separates a collated batch reconstituting its samples.
+
+    This function implements the inverse of
+    :py:func:`torch.utils.data.default_collate`, and can separate, into
+    samples, batches of data with different attributes.  It follows the inverse
+    path of that function, and implements the following separation algorithms:
+
+    * :class:`torch.Tensor` -> :class:`torch.Tensor` (with a removed outer
+      dimension, via :py:func:`torch.flatten`)
+    * ``typing.Mapping[K, V[]]`` -> ``[dict[K, V_1], dict[K, V_2], ...]``
+    """
+
+    # as of now, this is really simple - to be made more complex upon need.
+    metadata = [
+        {key: value[i] for key, value in batch[1].items()}
+        for i in range(len(batch[0]))
+    ]
+    return list(zip(torch.flatten(batch[0]), metadata))
diff --git a/src/ptbench/scripts/predict.py b/src/ptbench/scripts/predict.py
index b5ed123c2da9df3b1c485b999ac5eb9bfa04c9ee..551422fde176a9095ec3ac7dd3f69c6e174161aa 100644
--- a/src/ptbench/scripts/predict.py
+++ b/src/ptbench/scripts/predict.py
@@ -2,6 +2,8 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
+import pathlib
+
 import click
 
 from clapper.click import ConfigCommand, ResourceOption, verbosity_option
@@ -15,49 +17,61 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     cls=ConfigCommand,
     epilog="""Examples:
 
-\b
-    1. Runs prediction on an existing dataset configuration:
+    1. Runs prediction on an existing datamodule configuration:
 
        .. code:: sh
 
-          ptbench predict -vv pasa montgomery --weight=path/to/model_final.pth --output-folder=path/to/predictions
+          \b
+          ptbench predict -vv pasa montgomery --weight=path/to/model.ckpt --output=path/to/predictions.json
+
+    2. Enables multi-processing data loading with 6 processes:
+
+       .. code:: sh
+
+          \b
+          ptbench predict -vv pasa montgomery --parallel=6 --weight=path/to/model.ckpt --output=path/to/predictions.json
 
 """,
 )
 @click.option(
-    "--output-folder",
+    "--output",
     "-o",
-    help="Path where to store the predictions (created if does not exist)",
+    help="""Path where to store the JSON predictions for all samples in the
+    input datamodule (leading directories are created if they do not not
+    exist).""",
     required=True,
     default="results",
     cls=ResourceOption,
-    type=click.Path(),
+    type=click.Path(
+        file_okay=True, dir_okay=False, writable=True, path_type=pathlib.Path
+    ),
 )
 @click.option(
     "--model",
     "-m",
-    help="A torch.nn.Module instance implementing the network to be evaluated",
+    help="""A lightining module instance implementing the network architecture
+    (not the weights, necessarily) to be used for prediction.""",
     required=True,
     cls=ResourceOption,
 )
 @click.option(
     "--datamodule",
     "-d",
-    help="A torch.utils.data.dataset.Dataset instance implementing a dataset "
-    "to be used for running prediction, possibly including all pre-processing "
-    "pipelines required or, optionally, a dictionary mapping string keys to "
-    "torch.utils.data.dataset.Dataset instances.  All keys that do not start "
-    "with an underscore (_) will be processed.",
+    help="""A lighting data module that will be asked for prediction data
+    loaders. Typically, this includes all configured splits in a datamodule,
+    however this is not a requirement.  A datamodule that returns a single
+    dataloader for prediction (wrapped in a dictionary) is acceptable.""",
     required=True,
     cls=ResourceOption,
 )
 @click.option(
     "--batch-size",
     "-b",
-    help="Number of samples in every batch (this parameter affects memory requirements for the network)",
+    help="""Number of samples in every batch (this parameter affects memory
+    requirements for the network).""",
     required=True,
     show_default=True,
-    default=1,
+    default=10,
     type=click.IntRange(min=1),
     cls=ResourceOption,
 )
@@ -73,54 +87,76 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 @click.option(
     "--weight",
     "-w",
-    help="Path or URL to pretrained model file (.ckpt extension)",
+    help="""Path or URL to pretrained model file (`.ckpt` extension),
+    corresponding to the architecture set with `--model`.""",
+    required=True,
+    cls=ResourceOption,
+    type=click.Path(exists=True, file_okay=True, dir_okay=False, readable=True),
+)
+@click.option(
+    "--parallel",
+    "-P",
+    help="""Use multiprocessing for data loading: if set to -1 (default),
+    disables multiprocessing data loading.  Set to 0 to enable as many data
+    loading instances as processing cores as available in the system.  Set to
+    >= 1 to enable that many multiprocessing instances for data loading.""",
+    type=click.IntRange(min=-1),
+    show_default=True,
     required=True,
+    default=-1,
     cls=ResourceOption,
 )
 @verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
 def predict(
-    output_folder,
+    output,
     model,
     datamodule,
     batch_size,
     device,
     weight,
+    parallel,
     **_,
 ) -> None:
     """Predicts Tuberculosis presence (probabilities) on input images."""
 
-    import os
-
-    import numpy as np
-
-    from matplotlib.backends.backend_pdf import PdfPages
+    import json
+    import shutil
 
     from ..engine.device import DeviceManager
     from ..engine.predictor import run
-    from ..utils.plot import relevance_analysis_plot
 
     datamodule.set_chunk_size(batch_size, 1)
+    datamodule.parallel = parallel
     datamodule.model_transforms = model.model_transforms
 
     datamodule.prepare_data()
     datamodule.setup(stage="predict")
 
-    logger.info(f"Loading checkpoint from {weight}")
+    logger.info(f"Loading checkpoint from `{weight}`...")
     model = model.load_from_checkpoint(weight, strict=False)
 
-    # Logistic regressor weights
-    if model.name == "logistic_regression":
-        logger.info("Logistic regression identified: saving model weights")
-        for param in model.parameters():
-            model_weights = np.array(param.data.reshape(-1))
-            break
-        filepath = os.path.join(output_folder, "LogReg_Weights.pdf")
-        logger.info(f"Creating and saving weights plot at {filepath}...")
-        os.makedirs(os.path.dirname(filepath), exist_ok=True)
-        pdf = PdfPages(filepath)
-        pdf.savefig(
-            relevance_analysis_plot(model_weights, title="LogReg model weights")
-        )
-        pdf.close()
+    predictions = run(model, datamodule, DeviceManager(device), output.parent)
 
-    _ = run(model, datamodule, DeviceManager(device), output_folder)
+    output.parent.mkdir(parents=True, exist_ok=True)
+    if output.exists():
+        backup = output.parent / (output.name + "~")
+        logger.warning(
+            f"Output predictions file `{str(output)}` exists - "
+            f"backing it up to `{str(backup)}`..."
+        )
+        shutil.copy(output, backup)
+
+    with output.open("w") as f:
+        flat_predictions: dict[str, list[list]] = {}
+        # creates a flat representation of predictions that is similar to our
+        # own JSON split files
+        for split_name, split_values in predictions.items():  # type: ignore
+            flat_predictions.setdefault(
+                split_name,
+                [
+                    [v[1]["name"], v[1]["label"].item(), v[0].item()]
+                    for v in split_values
+                ],
+            )
+        json.dump(flat_predictions, f, indent=2)
+    logger.info(f"Predictions saved to `{str(output)}`")