diff --git a/src/mednet/libs/common/models/normalizer.py b/src/mednet/libs/classification/models/normalizer.py
similarity index 100%
rename from src/mednet/libs/common/models/normalizer.py
rename to src/mednet/libs/classification/models/normalizer.py
diff --git a/src/mednet/libs/classification/models/pasa.py b/src/mednet/libs/classification/models/pasa.py
index 0b5c24da02b73f1328651612e216cb731b90bbf7..313e1b54eaddc2d455c20d45f5ff6fb0d183ca1a 100644
--- a/src/mednet/libs/classification/models/pasa.py
+++ b/src/mednet/libs/classification/models/pasa.py
@@ -192,6 +192,23 @@ class Pasa(Model):
 
         # x = F.log_softmax(x, dim=1) # 0 is batch size
 
+    def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
+        """Initialize the input normalizer for the current model.
+
+        Parameters
+        ----------
+        dataloader
+            A torch Dataloader from which to compute the mean and std.
+        """
+
+        from .normalizer import make_z_normalizer
+
+        logger.info(
+            f"Uninitialised {self.name} model - "
+            f"computing z-norm factors from train dataloader.",
+        )
+        self.normalizer = make_z_normalizer(dataloader)
+
     def training_step(self, batch, _):
         images = batch[0]
         labels = batch[1]["target"]
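
With `set_normalizer` now implemented on the concrete models rather than in the common base class, callers are expected to invoke it on the model itself before fitting. A minimal sketch, assuming a Lightning-style datamodule exposing `train_dataloader()` (constructor arguments and datamodule names are not taken from this diff):

    # hedged sketch: estimate z-norm statistics from the training split only
    model = Pasa()
    datamodule.setup("fit")                              # assumed Lightning-style setup
    model.set_normalizer(datamodule.train_dataloader())  # computes per-channel mean/std
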
diff --git a/src/mednet/libs/common/models/model.py b/src/mednet/libs/common/models/model.py
index 4dc98b1da5e049bf541f751256428c912e8b8a32..cc9d5d84dee0a571d7809e4dfa1832f429c68de3 100644
--- a/src/mednet/libs/common/models/model.py
+++ b/src/mednet/libs/common/models/model.py
@@ -130,21 +130,7 @@ class Model(pl.LightningModule):
         self.normalizer = checkpoint["normalizer"]
 
     def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
-        """Initialize the input normalizer for the current model.
-
-        Parameters
-        ----------
-        dataloader
-            A torch Dataloader from which to compute the mean and std.
-        """
-
-        from .normalizer import make_z_normalizer
-
-        logger.info(
-            f"Uninitialised {self.name} model - "
-            f"computing z-norm factors from train dataloader.",
-        )
-        self.normalizer = make_z_normalizer(dataloader)
+        raise NotImplementedError
 
     def training_step(self, batch, _):
         raise NotImplementedError
diff --git a/src/mednet/libs/segmentation/config/data/drive/datamodule.py b/src/mednet/libs/segmentation/config/data/drive/datamodule.py
index 06014c81efc338b24b3a30c06e614d5b36281f1c..176293431877ed2dcf6e124c90faf4de0f7944a6 100644
--- a/src/mednet/libs/segmentation/config/data/drive/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/drive/datamodule.py
@@ -57,11 +57,11 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
             PIL.Image.open(self.datadir / sample[2]).convert(mode="1", dither=None)
         )
 
-        tensor = tv_tensors.Image(crop_image_to_mask(image, mask))
-        target = tv_tensors.Image(crop_image_to_mask(target, mask))
+        image = tv_tensors.Image(crop_image_to_mask(image, mask))
+        target = tv_tensors.Mask(crop_image_to_mask(target, mask))
         mask = tv_tensors.Mask(crop_image_to_mask(mask, mask))
 
-        return tensor, dict(target=target, mask=mask, name=sample[0])  # type: ignore[arg-type]
+        return dict(image=image, target=target, mask=mask), dict(name=sample[0])  # type: ignore[arg-type]
 
 
 class DataModule(CachingDataModule):
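
The loader now returns the image, target and mask together in a single dictionary, with the metadata kept in a separate one. A sketch of how downstream code would unpack a sample under this layout (variable names are illustrative, not from this diff):

    data, meta = loader.sample(sample)  # sample: a (name, ...) tuple from the split
    image = data["image"]     # tv_tensors.Image, cropped to the mask
    target = data["target"]   # tv_tensors.Mask with the ground-truth annotation
    mask = data["mask"]       # tv_tensors.Mask delimiting the valid region
    name = meta["name"]       # sample identifier (sample[0])
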
diff --git a/src/mednet/libs/segmentation/models/driu.py b/src/mednet/libs/segmentation/models/driu.py
index eac0afc03965519ac1af65127ac5565f31151e2a..f3a2e0d2d38866f670f051e6b1ea685ab2c3fd1f 100644
--- a/src/mednet/libs/segmentation/models/driu.py
+++ b/src/mednet/libs/segmentation/models/driu.py
@@ -149,7 +149,7 @@ class DRIU(Model):
             Will not be used if the model is pretrained.
         """
         if self.pretrained:
-            from mednet.libs.common.models.normalizer import make_imagenet_normalizer
+            from .normalizer import make_imagenet_normalizer
 
             logger.warning(
                 f"ImageNet pre-trained {self.name} model - NOT "
diff --git a/src/mednet/libs/segmentation/models/driu_bn.py b/src/mednet/libs/segmentation/models/driu_bn.py
index 7804ed06747d22cd947f4a1ba46d657d9e36c64d..3bb93ba9ee9221938a65af1c2d6fd62a1104d55d 100644
--- a/src/mednet/libs/segmentation/models/driu_bn.py
+++ b/src/mednet/libs/segmentation/models/driu_bn.py
@@ -152,7 +152,7 @@ class DRIUBN(Model):
             Will not be used if the model is pretrained.
         """
         if self.pretrained:
-            from mednet.libs.common.models.normalizer import make_imagenet_normalizer
+            from .normalizer import make_imagenet_normalizer
 
             logger.warning(
                 f"ImageNet pre-trained {self.name} model - NOT "
diff --git a/src/mednet/libs/segmentation/models/driu_od.py b/src/mednet/libs/segmentation/models/driu_od.py
index be7416c7f34e89a6c826f4b87f0dbafe5c4e1a99..98e596236bddf19fcc1789f136f1e162ed76ecf3 100644
--- a/src/mednet/libs/segmentation/models/driu_od.py
+++ b/src/mednet/libs/segmentation/models/driu_od.py
@@ -134,7 +134,7 @@ class DRIUOD(Model):
             Will not be used if the model is pretrained.
         """
         if self.pretrained:
-            from mednet.libs.common.models.normalizer import make_imagenet_normalizer
+            from .normalizer import make_imagenet_normalizer
 
             logger.warning(
                 f"ImageNet pre-trained {self.name} model - NOT "
diff --git a/src/mednet/libs/segmentation/models/driu_pix.py b/src/mednet/libs/segmentation/models/driu_pix.py
index 23218a577a34367b090a51dbd7340366f67dbe5c..6846da5a8f2700ecc4f2370cd485875016c1749b 100644
--- a/src/mednet/libs/segmentation/models/driu_pix.py
+++ b/src/mednet/libs/segmentation/models/driu_pix.py
@@ -138,7 +138,7 @@ class DRIUPix(Model):
             Will not be used if the model is pretrained.
         """
         if self.pretrained:
-            from mednet.libs.common.models.normalizer import make_imagenet_normalizer
+            from .normalizer import make_imagenet_normalizer
 
             logger.warning(
                 f"ImageNet pre-trained {self.name} model - NOT "
diff --git a/src/mednet/libs/segmentation/models/hed.py b/src/mednet/libs/segmentation/models/hed.py
index 8771f55813aab08ccbe853d151f6f28c26ba816b..7e0b770513d0f8d5d94515827160f029cbf3346c 100644
--- a/src/mednet/libs/segmentation/models/hed.py
+++ b/src/mednet/libs/segmentation/models/hed.py
@@ -153,7 +153,7 @@ class HED(Model):
             Will not be used if the model is pretrained.
         """
         if self.pretrained:
-            from mednet.libs.common.models.normalizer import make_imagenet_normalizer
+            from .normalizer import make_imagenet_normalizer
 
             logger.warning(
                 f"ImageNet pre-trained {self.name} model - NOT "
diff --git a/src/mednet/libs/segmentation/models/lwnet.py b/src/mednet/libs/segmentation/models/lwnet.py
index 90507e4dcd513e9f0a8d12a368f08a010c9ff77c..28bdf498fd74705546bf39a57835b378e4144f4e 100644
--- a/src/mednet/libs/segmentation/models/lwnet.py
+++ b/src/mednet/libs/segmentation/models/lwnet.py
@@ -15,6 +15,7 @@ guide segmentation.
 Reference: [GALDRAN-2020]_
 """
 
+import logging
 import typing
 
 import torch
@@ -23,6 +24,8 @@ from mednet.libs.common.data.typing import TransformSequence
 from mednet.libs.common.models.model import Model
 from mednet.libs.segmentation.models.losses import MultiWeightedBCELogitsLoss
 
+logger = logging.getLogger("mednet")
+
 
 def _conv1x1(in_planes, out_planes, stride=1):
     return torch.nn.Conv2d(
@@ -338,6 +341,23 @@ class LittleWNet(Model):
             shortcut=True,
         )
 
+    def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
+        """Initialize the input normalizer for the current model.
+
+        Parameters
+        ----------
+        dataloader
+            A torch Dataloader from which to compute the mean and std.
+        """
+
+        from .normalizer import make_z_normalizer
+
+        logger.info(
+            f"Uninitialised {self.name} model - "
+            f"computing z-norm factors from train dataloader.",
+        )
+        self.normalizer = make_z_normalizer(dataloader)
+
     def forward(self, x):
         xn = self.normalizer(x)
         x1 = self.unet1(xn)
diff --git a/src/mednet/libs/segmentation/models/m2unet.py b/src/mednet/libs/segmentation/models/m2unet.py
index 60f97967ef96b25c87afee5570cc81a8d7cab90c..b371540945d8e8b3c481c84abf7ea6c3309995fa 100644
--- a/src/mednet/libs/segmentation/models/m2unet.py
+++ b/src/mednet/libs/segmentation/models/m2unet.py
@@ -201,7 +201,7 @@ class M2UNET(Model):
             Will not be used if the model is pretrained.
         """
         if self.pretrained:
-            from mednet.libs.common.models.normalizer import make_imagenet_normalizer
+            from .normalizer import make_imagenet_normalizer
 
             logger.warning(
                 f"ImageNet pre-trained {self.name} model - NOT "
diff --git a/src/mednet/libs/segmentation/models/normalizer.py b/src/mednet/libs/segmentation/models/normalizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..df630bd1d61c9b7e10a951a80e8f560248f9a477
--- /dev/null
+++ b/src/mednet/libs/segmentation/models/normalizer.py
@@ -0,0 +1,75 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+"""Functions to compute normalisation factors based on dataloaders."""
+
+import logging
+
+import torch
+import torch.nn
+import torch.utils.data
+import torchvision.transforms
+import tqdm
+
+logger = logging.getLogger("mednet")
+
+
+def make_z_normalizer(
+    dataloader: torch.utils.data.DataLoader,
+) -> torchvision.transforms.Normalize:
+    """Compute mean and standard deviation from a dataloader.
+
+    This function iterates over a dataloader and computes the mean and
+    standard deviation per image channel.  It works for monochromatic as well
+    as color inputs with 2, 3 or more color planes.
+
+    Parameters
+    ----------
+    dataloader
+        A torch Dataloader from which to compute the mean and std.
+
+    Returns
+    -------
+        An initialized normalizer.
+    """
+
+    # Peek the number of channels of batches in the data loader
+    batch = next(iter(dataloader))
+    channels = batch[0]["image"].shape[1]
+
+    # Initialises accumulators
+    mean = torch.zeros(channels, dtype=batch[0]["image"].dtype)
+    var = torch.zeros(channels, dtype=batch[0]["image"].dtype)
+    num_images = 0
+
+    # Evaluates mean and standard deviation
+    for batch in tqdm.tqdm(dataloader, unit="batch"):
+        data = batch[0]["image"]  # samples are dicts; the image tensor is under "image"
+        data = data.view(data.size(0), data.size(1), -1)
+
+        num_images += data.size(0)
+        mean += data.mean(2).sum(0)
+        var += data.var(2).sum(0)
+
+    mean /= num_images
+    var /= num_images
+    std = torch.sqrt(var)
+
+    return torchvision.transforms.Normalize(mean, std)
+
+
+def make_imagenet_normalizer() -> torchvision.transforms.Normalize:
+    """Return the stock ImageNet normalisation weights from torchvision.
+
+    The weights are wrapped in a torch module.  This normalizer only works for
+    **RGB (color) images**.
+
+    Returns
+    -------
+        An initialized normalizer.
+    """
+
+    return torchvision.transforms.Normalize(
+        (0.485, 0.456, 0.406),
+        (0.229, 0.224, 0.225),
+    )
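
A short usage sketch for the new segmentation-side normalizer helpers; the surrounding setup (`model`, `train_dataloader`) is assumed and not part of this diff:

    from mednet.libs.segmentation.models.normalizer import (
        make_imagenet_normalizer,
        make_z_normalizer,
    )

    if model.pretrained:
        # fixed ImageNet statistics; only valid for 3-channel RGB inputs
        normalizer = make_imagenet_normalizer()
    else:
        # per-channel statistics estimated from the training dataloader
        normalizer = make_z_normalizer(train_dataloader)
    model.normalizer = normalizer
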
diff --git a/src/mednet/libs/segmentation/models/unet.py b/src/mednet/libs/segmentation/models/unet.py
index 7a3ee9bbccde07ef897e1cf03ad3e4efeb44cf17..98578d1d437bbe53fbab48f31764a1580d5ad5eb 100644
--- a/src/mednet/libs/segmentation/models/unet.py
+++ b/src/mednet/libs/segmentation/models/unet.py
@@ -142,7 +142,7 @@ class Unet(Model):
             Will not be used if the model is pretrained.
         """
         if self.pretrained:
-            from mednet.libs.common.models.normalizer import make_imagenet_normalizer
+            from .normalizer import make_imagenet_normalizer
 
             logger.warning(
                 f"ImageNet pre-trained {self.name} model - NOT "