diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py
index bc8b3a9ddb192e432a5345c96b3847379e58aa9d..1aa1b26867f8bcdc4400ce1a9bc8948be07a5afe 100644
--- a/src/ptbench/models/densenet.py
+++ b/src/ptbench/models/densenet.py
@@ -7,8 +7,6 @@ import torch
 import torch.nn as nn
 import torchvision.models as models
 
-from .normalizer import TorchVisionNormalizer
-
 
 class Densenet(pl.LightningModule):
     """Densenet module.
@@ -31,7 +29,7 @@ class Densenet(pl.LightningModule):
 
         self.name = "Densenet"
 
-        self.normalizer = TorchVisionNormalizer(nb_channels=nb_channels)
+        self.normalizer = None
 
         # Load pretrained model
         weights = None if not pretrained else models.DenseNet121_Weights.DEFAULT
@@ -55,16 +53,13 @@ class Densenet(pl.LightningModule):
         imagenet weights, during contruction).
         """
         if self.pretrained:
-            from .normalizer import TorchVisionNormalizer
+            from .normalizer import make_imagenet_normalizer
 
-            self.normalizer = TorchVisionNormalizer(
-                torch.Tensor([0.485, 0.456, 0.406]),
-                torch.Tensor([0.229, 0.224, 0.225]),
-            )
+            self.normalizer = make_imagenet_normalizer()
         else:
-            from .normalizer import get_znorm_normalizer
+            from .normalizer import make_z_normalizer
 
-            self.normalizer = get_znorm_normalizer(dataloader)
+            self.normalizer = make_z_normalizer(dataloader)
 
     def training_step(self, batch, batch_idx):
         images = batch[1]
diff --git a/src/ptbench/models/normalizer.py b/src/ptbench/models/normalizer.py
index b9ba7eb3d81de5c61349f5d4b0cbcf9d8fee9d7f..2cc4b956f17ee1e690031f42e891ef726b548e2a 100644
--- a/src/ptbench/models/normalizer.py
+++ b/src/ptbench/models/normalizer.py
@@ -2,52 +2,23 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-"""A network model that prefixes a z-normalization step to any other module."""
+"""A network model that prefixes a subtract/divide step to any other module."""
 
 import torch
 import torch.nn
 import torch.utils.data
+import torchvision.transforms
 
 
-class TorchVisionNormalizer(torch.nn.Module):
-    """A simple normalizer that applies the standard torchvision normalization.
+def make_z_normalizer(
+    dataloader: torch.utils.data.DataLoader,
+) -> torchvision.transforms.Normalize:
+    """Computes mean and standard deviation from a dataloader.
 
-    This module does not learn.
+    This function will input a dataloader, and compute the mean and standard
+    deviation by image channel.  It will work for both monochromatic, and color
+    inputs with 2, 3 or more color planes.
 
-    Parameters
-    ----------
-
-    nb_channels : :py:class:`int`, Optional
-        Number of images channels fed to the model
-    """
-
-    def __init__(self, subtract: torch.Tensor, divide: torch.Tensor):
-        super().__init__()
-        if len(subtract) != len(divide):
-            raise ValueError(
-                "Lengths of 'subtract' and 'divide' tensors should be the same."
-            )
-        if len(subtract) not in (1, 3):
-            raise ValueError(
-                "Length of 'subtract' tensor should be either 1 or 3, depending on the number of color channels."
-            )
-
-        subtract = torch.as_tensor(subtract)[None, :, None, None]
-        divide = torch.as_tensor(divide)[None, :, None, None]
-        self.register_buffer("subtract", subtract)
-        self.register_buffer("divide", divide)
-        self.name = "torchvision-normalizer"
-
-    def forward(self, inputs: torch.Tensor):
-        """inputs shape [batches, planes, height, width]"""
-        return inputs.sub(self.subtract).div(self.divide)
-
-
-def get_znorm_normalizer(
-    dataloader: torch.utils.data.DataLoader,
-) -> TorchVisionNormalizer:
-    """Returns a normalizer with the mean and std computed from a dataloader's
-    unaugmented training set.
 
     Parameters
     ----------
@@ -55,15 +26,22 @@ def get_znorm_normalizer(
     dataloader:
         A torch Dataloader from which to compute the mean and std
 
+
     Returns
     -------
-        An initialized TorchVisionNormalizer
+        An initialized normalizer
     """
 
-    mean = 0.0
-    var = 0.0
+    # Peek the number of channels of batches in the data loader
+    batch = next(iter(dataloader))
+    channels = batch[0].shape[1]
+
+    # Initialises accumulators
+    mean = torch.zeros(channels, dtype=batch[0].dtype)
+    var = torch.zeros(channels, dtype=batch[0].dtype)
     num_images = 0
 
+    # Evaluates mean and standard deviation
     for batch in dataloader:
         data = batch[0]
         data = data.view(data.size(0), data.size(1), -1)
@@ -76,5 +54,21 @@ def get_znorm_normalizer(
     var /= num_images
     std = torch.sqrt(var)
 
-    normalizer = TorchVisionNormalizer(mean, std)
-    return normalizer
+    return torchvision.transforms.Normalize(mean, std)
+
+
+def make_imagenet_normalizer() -> torchvision.transforms.Normalize:
+    """Returns the stock ImageNet normalisation weights from torchvision.
+
+    The weights are wrapped in a torch module.  This normalizer only works for
+    **RGB (color) images**.
+
+
+    Returns
+    -------
+        An initialized normalizer
+    """
+
+    return torchvision.transforms.Normalize(
+        (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
+    )
diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py
index cca494f141ca625caa0d4872802aa266f6e70a0f..61105dbe3877aab1e3ad09af8c3c406a3a9273ec 100644
--- a/src/ptbench/models/pasa.py
+++ b/src/ptbench/models/pasa.py
@@ -79,12 +79,7 @@ class PASA(pl.LightningModule):
         self.dense = nn.Linear(80, 1)  # Fully connected layer
 
     def forward(self, x):
-        if self.normalizer is None:
-            raise TypeError(
-                "The normalizer has not been initialized. Make sure to call set_normalizer() after creation of the model."
-            )
-
-        x = self.normalizer(x)
+        x = self.normalizer(x)  # type: ignore
 
         # First convolution block
         _x = x
@@ -140,11 +135,11 @@ class PASA(pl.LightningModule):
         dataloader:
             A torch Dataloader from which to compute the mean and std
         """
-        from .normalizer import get_znorm_normalizer
+        from .normalizer import make_z_normalizer
 
-        self.normalizer = get_znorm_normalizer(dataloader)
+        self.normalizer = make_z_normalizer(dataloader)
 
-    def training_step(self, batch, batch_idx):
+    def training_step(self, batch, _):
         images = batch[0]
         labels = batch[1]["label"]