diff --git a/src/mednet/libs/classification/engine/predictor.py b/src/mednet/libs/classification/engine/predictor.py
index 771d918f7ef454656038380d743d2da45e7745cd..d0298a3ee95652cc90ffeb2db846d51c8b46fedd 100644
--- a/src/mednet/libs/classification/engine/predictor.py
+++ b/src/mednet/libs/classification/engine/predictor.py
@@ -3,6 +3,7 @@
 # SPDX-License-Identifier: GPL-3.0-or-later
 
 import logging
+import typing
 
 import lightning.pytorch
 import torch.utils.data
@@ -18,6 +19,79 @@ from ..models.typing import (
 logger = logging.getLogger("mednet")
 
 
+class _JSONMetadataCollector(lightning.pytorch.callbacks.BasePredictionWriter):
+    """Collects further sample metadata to store with predictions.
+
+    This object collects further sample metadata we typically keep with
+    predictions.
+
+    Parameters
+    ----------
+    write_interval
+        When will this callback be active.
+    """
+
+    def __init__(
+        self,
+        write_interval: typing.Literal["batch", "epoch", "batch_and_epoch"] = "batch",
+    ):
+        super().__init__(write_interval=write_interval)
+        self._data: list[BinaryPrediction] | list[MultiClassPrediction] = []
+
+    def write_on_batch_end(
+        self,
+        trainer: lightning.pytorch.Trainer,
+        pl_module: lightning.pytorch.LightningModule,
+        prediction: typing.Any,
+        batch_indices: typing.Sequence[int] | None,
+        batch: typing.Any,
+        batch_idx: int,
+        dataloader_idx: int,
+    ) -> None:
+        """Write batch predictions to disk.
+
+        Parameters
+        ----------
+        trainer
+            The trainer being used.
+        pl_module
+            The pytorch module.
+        prediction
+            The actual predictions to record.
+        batch_indices
+            The relative position of samples on the epoch.
+        batch
+            The current batch.
+        batch_idx
+            Index of the batch overall.
+        dataloader_idx
+            Index of the dataloader overall.
+        """
+        for k, sample_pred in enumerate(prediction):
+            sample_name: str = batch[1]["name"][k]
+            target_shape = batch[1]["target"][k].shape
+            self._data.append(
+                (
+                    sample_name,
+                    batch[1]["target"][k].cpu().numpy().tolist(),
+                    sample_pred.cpu().numpy().reshape(target_shape).tolist(),
+                )
+            )
+
+    def reset(self) -> list[BinaryPrediction] | list[MultiClassPrediction]:
+        """Summary of written objects.
+
+        Also resets the internal state.
+
+        Returns
+        -------
+            A list containing a summary of all samples written.
+        """
+        retval = self._data
+        self._data = []
+        return retval
+
+
 def run(
     model: lightning.pytorch.LightningModule,
     datamodule: lightning.pytorch.LightningDataModule,
@@ -77,35 +151,38 @@ def run(
 
     from lightning.pytorch.loggers.logger import DummyLogger
 
+    collector = _JSONMetadataCollector()
+
     accelerator, devices = device_manager.lightning_accelerator()
     trainer = lightning.pytorch.Trainer(
         accelerator=accelerator,
         devices=devices,
         logger=DummyLogger(),
+        callbacks=[collector],
     )
 
-    def _flatten(p: list[list]):
-        return [sample for batch in p for sample in batch]
-
     dataloaders = datamodule.predict_dataloader()
 
     if isinstance(dataloaders, torch.utils.data.DataLoader):
         logger.info("Running prediction on a single dataloader...")
-        return _flatten(trainer.predict(model, dataloaders))  # type: ignore
+        trainer.predict(model, dataloaders, return_predictions=False)
+        return collector.reset()
 
     if isinstance(dataloaders, list):
         retval_list = []
         for k, dataloader in enumerate(dataloaders):
             logger.info(f"Running prediction on split `{k}`...")
-            retval_list.append(_flatten(trainer.predict(model, dataloader)))  # type: ignore
-        return retval_list
+            trainer.predict(model, dataloader, return_predictions=False)
+            retval_list.append(collector.reset())
+        return retval_list  # type: ignore
 
     if isinstance(dataloaders, dict):
         retval_dict = {}
         for name, dataloader in dataloaders.items():
             logger.info(f"Running prediction on `{name}` split...")
-            retval_dict[name] = _flatten(trainer.predict(model, dataloader))  # type: ignore
-        return retval_dict
+            trainer.predict(model, dataloader, return_predictions=False)
+            retval_dict[name] = collector.reset()
+        return retval_dict  # type: ignore
 
     if dataloaders is None:
         logger.warning("Datamodule did not return any prediction dataloaders!")
diff --git a/src/mednet/libs/classification/models/alexnet.py b/src/mednet/libs/classification/models/alexnet.py
index eab264185b16343936e0daf2d6f1935190ca15f3..7e20c5943a25a2852f7c38e892af04f91262f977 100644
--- a/src/mednet/libs/classification/models/alexnet.py
+++ b/src/mednet/libs/classification/models/alexnet.py
@@ -13,8 +13,6 @@ import torchvision.models as models
 from mednet.libs.common.data.typing import TransformSequence
 from mednet.libs.common.models.model import Model
 
-from .separate import separate
-
 logger = logging.getLogger("mednet")
 
 
@@ -117,13 +115,7 @@ class Alexnet(Model):
             )
             self.normalizer = make_imagenet_normalizer()
         else:
-            from .normalizer import make_z_normalizer
-
-            logger.info(
-                f"Uninitialised {self.name} model - "
-                f"computing z-norm factors from train dataloader.",
-            )
-            self.normalizer = make_z_normalizer(dataloader)
+            super().set_normalizer(dataloader)
 
     def training_step(self, batch, _):
         images = batch[0]
@@ -160,5 +152,4 @@ class Alexnet(Model):
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
         outputs = self(batch[0])
-        probabilities = torch.sigmoid(outputs)
-        return separate((probabilities, batch[1]))
+        return torch.sigmoid(outputs)
diff --git a/src/mednet/libs/classification/models/densenet.py b/src/mednet/libs/classification/models/densenet.py
index 428ba28187328822c3eb241468ea09249d1786ff..5f41f1aebfbc8fb4b2315996a33dd39f4301d3c6 100644
--- a/src/mednet/libs/classification/models/densenet.py
+++ b/src/mednet/libs/classification/models/densenet.py
@@ -13,8 +13,6 @@ import torchvision.models as models
 from mednet.libs.common.data.typing import TransformSequence
 from mednet.libs.common.models.model import Model
 
-from .separate import separate
-
 logger = logging.getLogger("mednet")
 
 
@@ -120,13 +118,7 @@ class Densenet(Model):
             )
             self.normalizer = make_imagenet_normalizer()
         else:
-            from .normalizer import make_z_normalizer
-
-            logger.info(
-                f"Uninitialised {self.name} model - "
-                f"computing z-norm factors from train dataloader.",
-            )
-            self.normalizer = make_z_normalizer(dataloader)
+            super().set_normalizer(dataloader)
 
     def training_step(self, batch, _):
         images = batch[0]
@@ -158,5 +150,4 @@ class Densenet(Model):
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
         outputs = self(batch[0])
-        probabilities = torch.sigmoid(outputs)
-        return separate((probabilities, batch[1]))
+        return torch.sigmoid(outputs)
diff --git a/src/mednet/libs/classification/models/logistic_regression.py b/src/mednet/libs/classification/models/logistic_regression.py
index f203e35221671f6abb8a9db87f12d3d82f87201f..f9fe1847be16de65dc3dfbd88223b89db6eeec97 100644
--- a/src/mednet/libs/classification/models/logistic_regression.py
+++ b/src/mednet/libs/classification/models/logistic_regression.py
@@ -8,8 +8,6 @@ import lightning.pytorch as pl
 import torch
 import torch.nn as nn
 
-from .separate import separate
-
 
 class LogisticRegression(pl.LightningModule):
     """Logistic regression classifier with a single output.
@@ -62,7 +60,7 @@ class LogisticRegression(pl.LightningModule):
         self.linear = nn.Linear(input_size, 1)
 
     def forward(self, x):
-        return self.linear(x)
+        return self.linear(self.normalizer(x))
 
     def training_step(self, batch, batch_idx):
         _input = batch[1]
@@ -105,8 +103,7 @@ class LogisticRegression(pl.LightningModule):
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
         outputs = self(batch[0])
-        probabilities = torch.sigmoid(outputs)
-        return separate((probabilities, batch[1]))
+        return torch.sigmoid(outputs)
 
     def configure_optimizers(self):
         return self._optimizer_type(
diff --git a/src/mednet/libs/classification/models/mlp.py b/src/mednet/libs/classification/models/mlp.py
index e8e4b2904264d8fc9485ad28adbf4b90213f51eb..bd928410a3133141a0048de8537df4bf65cd8e49 100644
--- a/src/mednet/libs/classification/models/mlp.py
+++ b/src/mednet/libs/classification/models/mlp.py
@@ -7,8 +7,6 @@ import typing
 import lightning.pytorch as pl
 import torch
 
-from .separate import separate
-
 
 class MultiLayerPerceptron(pl.LightningModule):
     """MLP with a variable number of inputs and hidden neurons (single layer).
@@ -66,7 +64,7 @@ class MultiLayerPerceptron(pl.LightningModule):
         self.fc2 = torch.nn.Linear(hidden_size, 1)
 
     def forward(self, x):
-        return self.fc2(self.relu(self.fc1(x)))
+        return self.fc2(self.relu(self.fc1(self.normalizer(x))))
 
     def training_step(self, batch, batch_idx):
         _input = batch[1]
@@ -109,8 +107,7 @@ class MultiLayerPerceptron(pl.LightningModule):
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
         outputs = self(batch[0])
-        probabilities = torch.sigmoid(outputs)
-        return separate((probabilities, batch[1]))
+        return torch.sigmoid(outputs)
 
     def configure_optimizers(self):
         return self._optimizer_type(
diff --git a/src/mednet/libs/classification/models/pasa.py b/src/mednet/libs/classification/models/pasa.py
index 478c9397c30d506cc59c08107dcad5a8dd4e1ffb..37415a7690a128728d5cd280a47ebef194114330 100644
--- a/src/mednet/libs/classification/models/pasa.py
+++ b/src/mednet/libs/classification/models/pasa.py
@@ -13,8 +13,6 @@ import torch.utils.data
 from mednet.libs.common.data.typing import TransformSequence
 from mednet.libs.common.models.model import Model
 
-from .separate import separate
-
 logger = logging.getLogger("mednet")
 
 
@@ -223,5 +221,4 @@ class Pasa(Model):
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
         outputs = self(batch[0])
-        probabilities = torch.sigmoid(outputs)
-        return separate((probabilities, batch[1]))
+        return torch.sigmoid(outputs)
diff --git a/src/mednet/libs/classification/models/separate.py b/src/mednet/libs/classification/models/separate.py
deleted file mode 100644
index 9b575e8ee1a39258b2437e1c968b303aff7b028b..0000000000000000000000000000000000000000
--- a/src/mednet/libs/classification/models/separate.py
+++ /dev/null
@@ -1,61 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-"""Contains the inverse :py:func:`torch.utils.data.default_collate`."""
-
-import typing
-
-import torch
-from mednet.libs.common.data.typing import Sample
-
-from .typing import BinaryPrediction, MultiClassPrediction
-
-
-def _as_predictions(
-    samples: typing.Iterable[Sample],
-) -> list[BinaryPrediction | MultiClassPrediction]:
-    """Take a list of separated batch predictions and transforms it into a list
-    of formal predictions.
-
-    Parameters
-    ----------
-    samples
-        A sequence of samples as returned by :py:func:`separate`.
-
-    Returns
-    -------
-    list[BinaryPrediction | MultiClassPrediction]
-        A list of typed predictions that can be saved to disk.
-    """
-
-    return [(v[1]["name"], v[1]["target"].item(), v[0].item()) for v in samples]
-
-
-def separate(batch: Sample) -> list[BinaryPrediction | MultiClassPrediction]:
-    """Separate a collated batch, reconstituting its samples.
-
-    This function implements the inverse of
-    :py:func:`torch.utils.data.default_collate`, and can separate, into
-    samples, batches of data with different attributes.  It follows the inverse
-    path of that function, and implements the following separation algorithms:
-
-    * :class:`torch.Tensor` -> :class:`torch.Tensor` (with a removed outer
-      dimension, via :py:func:`torch.flatten`)
-    * ``typing.Mapping[K, V[]]`` -> ``[dict[K, V_1], dict[K, V_2], ...]``
-
-    Parameters
-    ----------
-    batch
-        A batch, as output by torch model forwarding.
-
-    Returns
-    -------
-        A list of predictions that contains the predictions and associated metadata
-        for each processed sample.
-    """
-
-    # as of now, this is really simple - to be made more complex upon need.
-    metadata = [
-        {key: value[i] for key, value in batch[1].items()} for i in range(len(batch[0]))
-    ]
-    return _as_predictions(zip(torch.flatten(batch[0]), metadata))
diff --git a/src/mednet/libs/segmentation/engine/predictor.py b/src/mednet/libs/segmentation/engine/predictor.py
index 8371c9a4dc754d875d306d500d45f418913b2f5a..0e5c290f1ed28e975dc6a3bd702353039ba0ae34 100644
--- a/src/mednet/libs/segmentation/engine/predictor.py
+++ b/src/mednet/libs/segmentation/engine/predictor.py
@@ -38,7 +38,7 @@ class _HDF5Writer(lightning.pytorch.callbacks.BasePredictionWriter):
     ):
         super().__init__(write_interval=write_interval)
         self.output_folder = output_folder
-        self._written: list[list[str]] = []
+        self._data: list[tuple[str, str]] = []
 
     def write_on_batch_end(
         self,
@@ -69,39 +69,40 @@ class _HDF5Writer(lightning.pytorch.callbacks.BasePredictionWriter):
         dataloader_idx
             Index of the dataloader overall.
         """
-        for k, p in enumerate(prediction):
-            stem = pathlib.Path(p[0]).with_suffix(".hdf5")
+        for k, sample_pred in enumerate(prediction):
+            sample_name: str = batch[1]["name"][k]
+            stem = pathlib.Path(sample_name).with_suffix(".hdf5")
             output_path = self.output_folder / stem
-            tqdm.tqdm.write(f"`{p[0]}` -> `{str(output_path)}`")
+            tqdm.tqdm.write(f"`{sample_name}` -> `{str(output_path)}`")
             output_path.parent.mkdir(parents=True, exist_ok=True)
             with h5py.File(output_path, "w") as f:
                 f.create_dataset(
                     "image",
-                    data=batch[0][k].numpy(),
+                    data=batch[0][k].cpu().numpy(),
                     compression="gzip",
                     compression_opts=9,
                 )
                 f.create_dataset(
                     "prediction",
-                    data=p[3].numpy().squeeze(0),
+                    data=sample_pred.cpu().numpy().squeeze(0),
                     compression="gzip",
                     compression_opts=9,
                 )
                 f.create_dataset(
                     "target",
-                    data=(batch[1]["target"][k].squeeze(0).numpy() > 0.5),
+                    data=(batch[1]["target"][k].squeeze(0).cpu().numpy() > 0.5),
                     compression="gzip",
                     compression_opts=9,
                 )
                 f.create_dataset(
                     "mask",
-                    data=(batch[1]["mask"][k].squeeze(0).numpy() > 0.5),
+                    data=(batch[1]["mask"][k].squeeze(0).cpu().numpy() > 0.5),
                     compression="gzip",
                     compression_opts=9,
                 )
-            self._written.append([p[0], str(stem)])
+            self._data.append((sample_name, str(stem)))
 
-    def written(self) -> list[list[str]]:
+    def reset(self) -> list[tuple[str, str]]:
         """Summary of written objects.
 
         Also resets the internal state.
@@ -110,8 +111,8 @@ class _HDF5Writer(lightning.pytorch.callbacks.BasePredictionWriter):
         -------
             A list containing a summary of all samples written.
         """
-        retval = self._written
-        self._written = []
+        retval = self._data
+        self._data = []
         return retval
 
 
@@ -120,7 +121,12 @@ def run(
     datamodule: lightning.pytorch.LightningDataModule,
     device_manager: DeviceManager,
     output_folder: pathlib.Path,
-) -> dict[str, list[list[str]]] | list[list[list[str]]] | list[list[str]] | None:
+) -> (
+    dict[str, list[tuple[str, str]]]
+    | list[list[tuple[str, str]]]
+    | list[tuple[str, str]]
+    | None
+):
     """Run inference on input data, output predictions.
 
     Parameters
@@ -154,14 +160,14 @@ def run(
 
     from lightning.pytorch.loggers.logger import DummyLogger
 
-    writer = _HDF5Writer(output_folder)
+    collector = _HDF5Writer(output_folder)
 
     accelerator, devices = device_manager.lightning_accelerator()
     trainer = lightning.pytorch.Trainer(
         accelerator=accelerator,
         devices=devices,
         logger=DummyLogger(),
-        callbacks=[writer],
+        callbacks=[collector],
     )
 
     dataloaders = datamodule.predict_dataloader()
@@ -169,14 +175,14 @@ def run(
     if isinstance(dataloaders, torch.utils.data.DataLoader):
         logger.info("Running prediction on a single dataloader...")
         trainer.predict(model, dataloaders, return_predictions=False)
-        return writer.written()
+        return collector.reset()
 
     if isinstance(dataloaders, list):
         retval_list = []
         for k, dataloader in enumerate(dataloaders):
             logger.info(f"Running prediction on split `{k}`...")
             trainer.predict(model, dataloader, return_predictions=False)
-            retval_list.append(writer.written())
+            retval_list.append(collector.reset())
         return retval_list
 
     if isinstance(dataloaders, dict):
@@ -184,7 +190,7 @@ def run(
         for name, dataloader in dataloaders.items():
             logger.info(f"Running prediction on `{name}` split...")
             trainer.predict(model, dataloader, return_predictions=False)
-            retval_dict[name] = writer.written()
+            retval_dict[name] = collector.reset()
         return retval_dict
 
     if dataloaders is None:
diff --git a/src/mednet/libs/segmentation/models/driu.py b/src/mednet/libs/segmentation/models/driu.py
index cd77fb28feff97df2f608b63660810be664434b5..6dbdef7880bbdf6c937a96c19739ec36af8817d4 100644
--- a/src/mednet/libs/segmentation/models/driu.py
+++ b/src/mednet/libs/segmentation/models/driu.py
@@ -14,7 +14,6 @@ from mednet.libs.common.models.model import Model
 from .backbones.vgg import vgg16_for_segmentation
 from .losses import SoftJaccardBCELogitsLoss
 from .make_layers import UpsampleCropBlock, conv_with_kaiming_uniform
-from .separate import separate
 
 logger = logging.getLogger("mednet")
 
@@ -133,8 +132,7 @@ class DRIU(Model):
         self.head = DRIUHead([64, 128, 256, 512])
 
     def forward(self, x):
-        if self.normalizer is not None:
-            x = self.normalizer(x)
+        x = self.normalizer(x)
         x = self.backbone(x)
         return self.head(x)
 
@@ -160,7 +158,7 @@ class DRIU(Model):
             )
             self.normalizer = make_imagenet_normalizer()
         else:
-            self.normalizer = None
+            super().set_normalizer(dataloader)
 
     def training_step(self, batch, batch_idx):
         images = batch[0]
@@ -180,8 +178,7 @@ class DRIU(Model):
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
         output = self(batch[0])[1]
-        probabilities = torch.sigmoid(output)
-        return separate((probabilities, batch[1]))
+        return torch.sigmoid(output)
 
     def configure_optimizers(self):
         return self._optimizer_type(self.parameters(), **self._optimizer_arguments)
diff --git a/src/mednet/libs/segmentation/models/driu_bn.py b/src/mednet/libs/segmentation/models/driu_bn.py
index 4c19c267c92ffc153f53e85ca4c5a3690ac8e684..07ddc62508f79d0717edca85171404ff980dcd56 100644
--- a/src/mednet/libs/segmentation/models/driu_bn.py
+++ b/src/mednet/libs/segmentation/models/driu_bn.py
@@ -14,7 +14,6 @@ from mednet.libs.common.models.model import Model
 from .backbones.vgg import vgg16_for_segmentation
 from .losses import SoftJaccardBCELogitsLoss
 from .make_layers import UpsampleCropBlock, conv_with_kaiming_uniform
-from .separate import separate
 
 logger = logging.getLogger("mednet")
 
@@ -136,8 +135,7 @@ class DRIUBN(Model):
         self.head = DRIUBNHead([64, 128, 256, 512])
 
     def forward(self, x):
-        if self.normalizer is not None:
-            x = self.normalizer(x)
+        x = self.normalizer(x)
         x = self.backbone(x)
         return self.head(x)
 
@@ -163,7 +161,7 @@ class DRIUBN(Model):
             )
             self.normalizer = make_imagenet_normalizer()
         else:
-            self.normalizer = None
+            super().set_normalizer(dataloader)
 
     def training_step(self, batch, batch_idx):
         images = batch[0]
@@ -183,8 +181,7 @@ class DRIUBN(Model):
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
         output = self(batch[0])[1]
-        probabilities = torch.sigmoid(output)
-        return separate((probabilities, batch[1]))
+        return torch.sigmoid(output)
 
     def configure_optimizers(self):
         return self._optimizer_type(self.parameters(), **self._optimizer_arguments)
diff --git a/src/mednet/libs/segmentation/models/driu_od.py b/src/mednet/libs/segmentation/models/driu_od.py
index 308f87e19f551119c23765a22142f4671bbbb6e5..1e1802af8f973ac9792a5bb7ab99114629dd929f 100644
--- a/src/mednet/libs/segmentation/models/driu_od.py
+++ b/src/mednet/libs/segmentation/models/driu_od.py
@@ -15,7 +15,6 @@ from .backbones.vgg import vgg16_for_segmentation
 from .driu import ConcatFuseBlock
 from .losses import SoftJaccardBCELogitsLoss
 from .make_layers import UpsampleCropBlock
-from .separate import separate
 
 logger = logging.getLogger("mednet")
 
@@ -118,8 +117,7 @@ class DRIUOD(Model):
         self.head = DRIUODHead([128, 256, 512, 512])
 
     def forward(self, x):
-        if self.normalizer is not None:
-            x = self.normalizer(x)
+        x = self.normalizer(x)
         x = self.backbone(x)
         return self.head(x)
 
@@ -145,7 +143,7 @@ class DRIUOD(Model):
             )
             self.normalizer = make_imagenet_normalizer()
         else:
-            self.normalizer = None
+            super().set_normalizer(dataloader)
 
     def training_step(self, batch, batch_idx):
         images = batch[0]
@@ -165,8 +163,7 @@ class DRIUOD(Model):
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
         output = self(batch[0])[1]
-        probabilities = torch.sigmoid(output)
-        return separate((probabilities, batch[1]))
+        return torch.sigmoid(output)
 
     def configure_optimizers(self):
         return self._optimizer_type(self.parameters(), **self._optimizer_arguments)
diff --git a/src/mednet/libs/segmentation/models/driu_pix.py b/src/mednet/libs/segmentation/models/driu_pix.py
index a4cbbc5b41c0a8aea287be949a2391091500abac..29942643e4c6e9f44d2121534ec15d6bca707c43 100644
--- a/src/mednet/libs/segmentation/models/driu_pix.py
+++ b/src/mednet/libs/segmentation/models/driu_pix.py
@@ -15,7 +15,6 @@ from .backbones.vgg import vgg16_for_segmentation
 from .driu import ConcatFuseBlock
 from .losses import SoftJaccardBCELogitsLoss
 from .make_layers import UpsampleCropBlock
-from .separate import separate
 
 logger = logging.getLogger("mednet")
 
@@ -122,8 +121,7 @@ class DRIUPix(Model):
         self.head = DRIUPIXHead([64, 128, 256, 512])
 
     def forward(self, x):
-        if self.normalizer is not None:
-            x = self.normalizer(x)
+        x = self.normalizer(x)
         x = self.backbone(x)
         return self.head(x)
 
@@ -149,7 +147,7 @@ class DRIUPix(Model):
             )
             self.normalizer = make_imagenet_normalizer()
         else:
-            self.normalizer = None
+            super().set_normalizer(dataloader)
 
     def training_step(self, batch, batch_idx):
         images = batch[0]
@@ -169,8 +167,7 @@ class DRIUPix(Model):
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
         output = self(batch[0])[1]
-        probabilities = torch.sigmoid(output)
-        return separate((probabilities, batch[1]))
+        return torch.sigmoid(output)
 
     def configure_optimizers(self):
         return self._optimizer_type(self.parameters(), **self._optimizer_arguments)
diff --git a/src/mednet/libs/segmentation/models/hed.py b/src/mednet/libs/segmentation/models/hed.py
index 779c48cfce7f77ea88c2e5c19d58f36c8fd4c457..e3bcd094814ff4981ddead877ae079322b3546d4 100644
--- a/src/mednet/libs/segmentation/models/hed.py
+++ b/src/mednet/libs/segmentation/models/hed.py
@@ -13,7 +13,6 @@ from mednet.libs.common.models.model import Model
 from .backbones.vgg import vgg16_for_segmentation
 from .losses import MultiSoftJaccardBCELogitsLoss
 from .make_layers import UpsampleCropBlock, conv_with_kaiming_uniform
-from .separate import separate
 
 logger = logging.getLogger("mednet")
 
@@ -137,8 +136,7 @@ class HED(Model):
         self.head = HEDHead([64, 128, 256, 512, 512])
 
     def forward(self, x):
-        if self.normalizer is not None:
-            x = self.normalizer(x)
+        x = self.normalizer(x)
         x = self.backbone(x)
         return self.head(x)
 
@@ -164,7 +162,7 @@ class HED(Model):
             )
             self.normalizer = make_imagenet_normalizer()
         else:
-            self.normalizer = None
+            super().set_normalizer(dataloader)
 
     def training_step(self, batch, batch_idx):
         images = batch[0]
@@ -184,8 +182,7 @@ class HED(Model):
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
         output = self(batch[0])[1]
-        probabilities = torch.sigmoid(output)
-        return separate((probabilities, batch[1]))
+        return torch.sigmoid(output)
 
     def configure_optimizers(self):
         return self._optimizer_type(self.parameters(), **self._optimizer_arguments)
diff --git a/src/mednet/libs/segmentation/models/lwnet.py b/src/mednet/libs/segmentation/models/lwnet.py
index 7a21ad58c026c66bec17105f89a6ff5665d2b57b..90507e4dcd513e9f0a8d12a368f08a010c9ff77c 100644
--- a/src/mednet/libs/segmentation/models/lwnet.py
+++ b/src/mednet/libs/segmentation/models/lwnet.py
@@ -23,8 +23,6 @@ from mednet.libs.common.data.typing import TransformSequence
 from mednet.libs.common.models.model import Model
 from mednet.libs.segmentation.models.losses import MultiWeightedBCELogitsLoss
 
-from .separate import separate
-
 
 def _conv1x1(in_planes, out_planes, stride=1):
     return torch.nn.Conv2d(
@@ -341,8 +339,9 @@ class LittleWNet(Model):
         )
 
     def forward(self, x):
-        x1 = self.unet1(x)
-        x2 = self.unet2(torch.cat([x, x1], dim=1))
+        xn = self.normalizer(x)
+        x1 = self.unet1(xn)
+        x2 = self.unet2(torch.cat([xn, x1], dim=1))
 
         return x1, x2
 
@@ -364,8 +363,7 @@ class LittleWNet(Model):
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
         output = self(batch[0])[1]
-        probabilities = torch.sigmoid(output)
-        return separate((probabilities, batch[1]))
+        return torch.sigmoid(output)
 
     def configure_optimizers(self):
         return self._optimizer_type(self.parameters(), **self._optimizer_arguments)
diff --git a/src/mednet/libs/segmentation/models/m2unet.py b/src/mednet/libs/segmentation/models/m2unet.py
index 24587582641680c419eaddb1a8f5b49aad248f94..d934881ce9082694effad5ed4e0c4a431d51e654 100644
--- a/src/mednet/libs/segmentation/models/m2unet.py
+++ b/src/mednet/libs/segmentation/models/m2unet.py
@@ -13,7 +13,6 @@ from mednet.libs.segmentation.models.losses import SoftJaccardBCELogitsLoss
 from torchvision.models.mobilenetv2 import InvertedResidual
 
 from .backbones.mobilenetv2 import mobilenet_v2_for_segmentation
-from .separate import separate
 
 logger = logging.getLogger("mednet")
 
@@ -185,8 +184,7 @@ class M2UNET(Model):
         self.head = M2UNetHead(in_channels_list=[16, 24, 32, 96])
 
     def forward(self, x):
-        if self.normalizer is not None:
-            x = self.normalizer(x)
+        x = self.normalizer(x)
         x = self.backbone(x)
         return self.head(x)
 
@@ -212,7 +210,7 @@ class M2UNET(Model):
             )
             self.normalizer = make_imagenet_normalizer()
         else:
-            self.normalizer = None
+            super().set_normalizer(dataloader)
 
     def training_step(self, batch, batch_idx):
         images = batch[0]
@@ -232,8 +230,7 @@ class M2UNET(Model):
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
         output = self(batch[0])[1]
-        probabilities = torch.sigmoid(output)
-        return separate((probabilities, batch[1]))
+        return torch.sigmoid(output)
 
     def configure_optimizers(self):
         return self._optimizer_type(self.parameters(), **self._optimizer_arguments)
diff --git a/src/mednet/libs/segmentation/models/separate.py b/src/mednet/libs/segmentation/models/separate.py
deleted file mode 100644
index 4f2628f8b41d4b1fea238b8cbda441200cfe109e..0000000000000000000000000000000000000000
--- a/src/mednet/libs/segmentation/models/separate.py
+++ /dev/null
@@ -1,59 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-"""Contains the inverse :py:func:`torch.utils.data.default_collate`."""
-
-import typing
-
-from mednet.libs.common.data.typing import Sample
-
-from .typing import SegmentationPrediction
-
-
-def _as_predictions(
-    samples: typing.Iterable[Sample],
-) -> list[SegmentationPrediction]:
-    """Take a list of separated batch predictions and transforms it into a list
-    of formal predictions.
-
-    Parameters
-    ----------
-    samples
-        A sequence of samples as returned by :py:func:`separate`.
-
-    Returns
-    -------
-        A list of typed predictions that can be saved to disk.
-    """
-    return [(v[1]["name"], v[1]["target"], v[1]["mask"], v[0]) for v in samples]
-
-
-def separate(batch: Sample) -> list[SegmentationPrediction]:
-    """Separate a collated batch, reconstituting its samples.
-
-    This function implements the inverse of
-    :py:func:`torch.utils.data.default_collate`, and can separate, into
-    samples, batches of data with different attributes.  It follows the inverse
-    path of that function, and implements the following separation algorithms:
-
-    * :class:`torch.Tensor` -> :class:`torch.Tensor` (with a removed outer
-      dimension, via :py:func:`torch.flatten`)
-    * ``typing.Mapping[K, V[]]`` -> ``[dict[K, V_1], dict[K, V_2], ...]``
-
-    Parameters
-    ----------
-    batch
-        A batch, as output by torch model forwarding.
-
-    Returns
-    -------
-        A list of predictions that contains the predictions and associated metadata
-        for each processed sample.
-    """
-
-    # as of now, this is really simple - to be made more complex upon need.
-    metadata = [
-        {key: value[i] for key, value in batch[1].items()} for i in range(len(batch[0]))
-    ]
-
-    return _as_predictions(zip(batch[0], metadata))
diff --git a/src/mednet/libs/segmentation/models/typing.py b/src/mednet/libs/segmentation/models/typing.py
deleted file mode 100644
index 11ec1457ff1ece9fa26d779190c529c0bd78c724..0000000000000000000000000000000000000000
--- a/src/mednet/libs/segmentation/models/typing.py
+++ /dev/null
@@ -1,15 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-"""Defines most common types used in code."""
-
-import pathlib
-import typing
-
-Checkpoint: typing.TypeAlias = typing.MutableMapping[str, typing.Any]
-"""Definition of a lightning checkpoint."""
-
-SegmentationPrediction: typing.TypeAlias = tuple[
-    pathlib.Path, pathlib.Path, pathlib.Path, pathlib.Path
-]
-"""The sample name, the target, mask, and the prediction."""
diff --git a/src/mednet/libs/segmentation/models/unet.py b/src/mednet/libs/segmentation/models/unet.py
index 0943d283cdac1f2b72f306a116de3d3fe579ce5b..04317572c856704091c5b5022a843049c69e53fe 100644
--- a/src/mednet/libs/segmentation/models/unet.py
+++ b/src/mednet/libs/segmentation/models/unet.py
@@ -13,7 +13,6 @@ from mednet.libs.common.models.model import Model
 from .backbones.vgg import vgg16_for_segmentation
 from .losses import SoftJaccardBCELogitsLoss
 from .make_layers import UnetBlock, conv_with_kaiming_uniform
-from .separate import separate
 
 logger = logging.getLogger("mednet")
 
@@ -126,8 +125,7 @@ class Unet(Model):
         self.head = UNetHead([64, 128, 256, 512, 512], pixel_shuffle=False)
 
     def forward(self, x):
-        if self.normalizer is not None:
-            x = self.normalizer(x)
+        x = self.normalizer(x)
         x = self.backbone(x)
         return self.head(x)
 
@@ -153,7 +151,7 @@ class Unet(Model):
             )
             self.normalizer = make_imagenet_normalizer()
         else:
-            self.normalizer = None
+            super().set_normalizer(dataloader)
 
     def training_step(self, batch, batch_idx):
         images = batch[0]
@@ -173,8 +171,7 @@ class Unet(Model):
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
         output = self(batch[0])[1]
-        probabilities = torch.sigmoid(output)
-        return separate((probabilities, batch[1]))
+        return torch.sigmoid(output)
 
     def configure_optimizers(self):
         return self._optimizer_type(self.parameters(), **self._optimizer_arguments)
diff --git a/src/mednet/libs/segmentation/scripts/utils.py b/src/mednet/libs/segmentation/scripts/utils.py
deleted file mode 100644
index 890d1c5437d80c71a9a6956fa9fdaf7e259f4259..0000000000000000000000000000000000000000
--- a/src/mednet/libs/segmentation/scripts/utils.py
+++ /dev/null
@@ -1,192 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-"""Utilities for command-line scripts."""
-
-import json
-import logging
-import pathlib
-import re
-import shutil
-
-import lightning.pytorch
-import lightning.pytorch.callbacks
-import torch.nn
-from mednet.libs.common.engine.device import SupportedPytorchDevice
-
-logger = logging.getLogger("mednet")
-
-
-def model_summary(
-    model: torch.nn.Module,
-) -> dict[str, int | list[tuple[str, str, int]]]:
-    """Save a little summary of the model in a txt file.
-
-    Parameters
-    ----------
-    model
-        Instance of the model for which to save the summary.
-
-    Returns
-    -------
-    tuple[lightning.pytorch.callbacks.ModelSummary, int]
-        A tuple with the model summary in a text format and number of parameters of the model.
-    """
-
-    s = lightning.pytorch.utilities.model_summary.ModelSummary(  # type: ignore
-        model,
-    )
-
-    return dict(
-        model_summary=list(zip(s.layer_names, s.layer_types, s.param_nums)),
-        model_size=s.total_parameters,
-    )
-
-
-def device_properties(
-    device_type: SupportedPytorchDevice,
-) -> dict[str, int | float | str]:
-    """Generate information concerning hardware properties.
-
-    Parameters
-    ----------
-    device_type
-        The type of compute device we are using.
-
-    Returns
-    -------
-        Static properties of the current machine.
-    """
-
-    from mednet.libs.common.utils.resources import (
-        cpu_constants,
-        cuda_constants,
-        mps_constants,
-    )
-
-    retval: dict[str, int | float | str] = {}
-    retval.update(cpu_constants())
-
-    match device_type:
-        case "cpu":
-            pass
-        case "cuda":
-            results = cuda_constants()
-            if results is not None:
-                retval.update(results)
-        case "mps":
-            results = mps_constants()
-            if results is not None:
-                retval.update(results)
-        case _:
-            pass
-
-    return retval
-
-
-def execution_metadata() -> dict[str, int | float | str | dict[str, str]]:
-    """Produce metadata concerning the running script, in the form of a
-    dictionary.
-
-    This function returns potentially useful metadata concerning program
-    execution.  It contains a certain number of preset variables.
-
-    Returns
-    -------
-        A dictionary that contains the following fields:
-
-        * ``package-name``: current package name (e.g. ``mednet``)
-        * ``package-version``: current package version (e.g. ``1.0.0b0``)
-        * ``datetime``: date and time in ISO8601 format (e.g. ``2024-02-23T18:38:09+01:00``)
-        * ``user``: username (e.g. ``johndoe``)
-        * ``conda-env``: if set, the name of the current conda environment
-        * ``path``: current path when executing the command
-        * ``command-line``: the command-line that is being run
-        * ``hostname``: machine hostname (e.g. ``localhost``)
-        * ``platform``: machine platform (e.g. ``darwin``)
-    """
-
-    import importlib.metadata
-    import importlib.util
-    import os
-    import sys
-
-    args = []
-    for k in sys.argv:
-        if " " in k:
-            args.append(f"'{k}'")
-        else:
-            args.append(k)
-
-    # current date time, in ISO8610 format
-    datetime = __import__("datetime").datetime.now().astimezone().isoformat()
-
-    # collects dependence information
-    package_name = __package__.split(".")[0]
-    requires = importlib.metadata.requires(package_name) or []
-    dependence_names = [re.split(r"(\=|~|!|>|<|;|\s)+", k)[0] for k in requires]
-    dependencies = {
-        k: importlib.metadata.version(k)  # version number as str
-        for k in dependence_names
-        if importlib.util.find_spec(k) is not None  # if is installed
-    }
-
-    # checks if the current version corresponds to a dirty (uncommitted) change
-    # set, issues a warning to the user
-    current_version = importlib.metadata.version(package_name)
-    try:
-        import versioningit
-
-        actual_version = versioningit.get_version(".", config={})
-        if current_version != actual_version:
-            logger.warning(
-                f"Version mismatch between current version set "
-                f"({current_version}) and actual version returned by "
-                f"versioningit ({actual_version}).  This typically happens "
-                f"when you commit changes locally and do not re-install the "
-                f"package. Run `pip install -e .` or equivalent to fix this.",
-            )
-    except Exception as e:
-        # not in a git repo?
-        logger.debug(f"Error {e}")
-        pass
-
-    return {
-        "datetime": datetime,
-        "package-name": __package__.split(".")[0],
-        "package-version": current_version,
-        "dependencies": dependencies,
-        "user": __import__("getpass").getuser(),
-        "conda-env": os.environ.get("CONDA_DEFAULT_ENV", ""),
-        "path": os.path.realpath(os.curdir),
-        "command-line": " ".join(args),
-        "hostname": __import__("platform").node(),
-        "platform": sys.platform,
-    }
-
-
-def save_json_with_backup(path: pathlib.Path, data: dict | list) -> None:
-    """Save a dictionary into a JSON file with path checking and backup.
-
-    This function will save a dictionary into a JSON file.  It will check to
-    the existence of the directory leading to the file and create it if
-    necessary.  If the file already exists on the destination folder, it is
-    backed-up before a new file is created with the new contents.
-
-    Parameters
-    ----------
-    path
-        The full path where to save the JSON data.
-    data
-        The data to save on the JSON file.
-    """
-
-    logger.info(f"Writing run metadata at `{path}`...")
-
-    path.parent.mkdir(parents=True, exist_ok=True)
-    if path.exists():
-        backup = path.parent / (path.name + "~")
-        shutil.copy(path, backup)
-
-    with path.open("w") as f:
-        json.dump(data, f, indent=2)
diff --git a/src/mednet/libs/segmentation/scripts/view.py b/src/mednet/libs/segmentation/scripts/view.py
index e0cb9706d3e8d1e646a2f329210efd3fd8a98a3e..79fcb468abde40e8253bc1209663f62524630889 100644
--- a/src/mednet/libs/segmentation/scripts/view.py
+++ b/src/mednet/libs/segmentation/scripts/view.py
@@ -24,7 +24,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     epilog="""Examples:
 
 \b
-  1. Runs evaluation on an existing dataset configuration:
+  1. Runs view on an existing dataset configuration:
 
      .. code:: sh
 
@@ -146,8 +146,8 @@ def view(
     )
     from mednet.libs.segmentation.engine.viewer import view
 
-    evaluation_filename = "evaluation.json"
-    evaluation_file = output_folder / evaluation_filename
+    view_filename = "view.json"
+    view_file = output_folder / view_filename
 
     with predictions.open("r") as f:
         predict_data = json.load(f)
@@ -164,7 +164,7 @@ def view(
         ),
     )
     json_data = {k.replace("_", "-"): v for k, v in json_data.items()}
-    save_json_with_backup(evaluation_file.with_suffix(".meta.json"), json_data)
+    save_json_with_backup(view_file.with_suffix(".meta.json"), json_data)
 
     threshold = validate_threshold(threshold, predict_data)
     threshold_list = numpy.arange(
@@ -203,5 +203,6 @@ def view(
                 alpha=alpha,
             )
             dest = (output_folder / sample[1]).with_suffix(".png")
+            dest.parent.mkdir(parents=True, exist_ok=True)
             tqdm.tqdm.write(f"{sample[1]} -> {dest}")
             image.save(dest)