diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py
index 6670a3099bd4f855f89eba0001380421a08e176f..633ebd62c476b7674943a4cfb97f79109eae1e15 100644
--- a/src/ptbench/data/datamodule.py
+++ b/src/ptbench/data/datamodule.py
@@ -95,7 +95,11 @@ class _DelayedLoadingDataset(torch.utils.data.Dataset):
     ):
         self.split = split
         self.raw_data_loader = raw_data_loader
-        self.transform = torchvision.transforms.Compose(*transforms)
+        # Cannot unpack empty list
+        if len(transforms) > 0:
+            self.transform = torchvision.transforms.Compose([*transforms])
+        else:
+            self.transform = torchvision.transforms.Compose([])
 
     def __getitem__(self, key: int) -> tuple[torch.Tensor, typing.Mapping]:
         tensor, metadata = self.raw_data_loader(self.split[key])
diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py
index 52bd27f07b16131deea9e3d0aa11c8c1cf294f88..bc8b3a9ddb192e432a5345c96b3847379e58aa9d 100644
--- a/src/ptbench/models/densenet.py
+++ b/src/ptbench/models/densenet.py
@@ -57,7 +57,10 @@ class Densenet(pl.LightningModule):
         if self.pretrained:
             from .normalizer import TorchVisionNormalizer
 
-            self.normalizer = TorchVisionNormalizer(..., ...)
+            self.normalizer = TorchVisionNormalizer(
+                torch.Tensor([0.485, 0.456, 0.406]),
+                torch.Tensor([0.229, 0.224, 0.225]),
+            )
         else:
             from .normalizer import get_znorm_normalizer
 
diff --git a/src/ptbench/models/normalizer.py b/src/ptbench/models/normalizer.py
index 10320ab1a7ef2aa3bd2a7b1b43566a21da865150..b9ba7eb3d81de5c61349f5d4b0cbcf9d8fee9d7f 100644
--- a/src/ptbench/models/normalizer.py
+++ b/src/ptbench/models/normalizer.py
@@ -23,12 +23,17 @@ class TorchVisionNormalizer(torch.nn.Module):
 
     def __init__(self, subtract: torch.Tensor, divide: torch.Tensor):
         super().__init__()
-        assert len(subtract) == len(divide), "TODO"
-        assert len(subtract) in (1, 3), "TODO"
-        self.subtract = subtract
-        self.divided = divide
-        subtract = torch.zeros(len(subtract.shape))[None, :, None, None]
-        divide = torch.ones(len(divide.shape))[None, :, None, None]
+        if len(subtract) != len(divide):
+            raise ValueError(
+                "Lengths of 'subtract' and 'divide' tensors should be the same."
+            )
+        if len(subtract) not in (1, 3):
+            raise ValueError(
+                "Length of 'subtract' tensor should be either 1 or 3, depending on the number of color channels."
+            )
+
+        subtract = torch.as_tensor(subtract)[None, :, None, None]
+        divide = torch.as_tensor(divide)[None, :, None, None]
         self.register_buffer("subtract", subtract)
         self.register_buffer("divide", divide)
         self.name = "torchvision-normalizer"
@@ -41,13 +46,35 @@ class TorchVisionNormalizer(torch.nn.Module):
 def get_znorm_normalizer(
     dataloader: torch.utils.data.DataLoader,
 ) -> TorchVisionNormalizer:
-    # TODO: Fix this function to use unaugmented training set
-    # TODO: This function is only applicable IFF we are not fine-tuning (ie.
-    #       model does not re-use weights from imagenet training!)
-    # TODO: Add type hints
-    # TODO: Add documentation
+    """Returns a normalizer with the mean and std computed from a dataloader's
+    unaugmented training set.
+
+    Parameters
+    ----------
+
+    dataloader:
+        A torch Dataloader from which to compute the mean and std
+
+    Returns
+    -------
+        An initialized TorchVisionNormalizer
+    """
+
+    mean = 0.0
+    var = 0.0
+    num_images = 0
+
+    for batch in dataloader:
+        data = batch[0]
+        data = data.view(data.size(0), data.size(1), -1)
+
+        num_images += data.size(0)
+        mean += data.mean(2).sum(0)
+        var += data.var(2).sum(0)
 
-    # 1 extract mean/std from dataloader
+    mean /= num_images
+    var /= num_images
+    std = torch.sqrt(var)
 
-    # 2 return TorchVisionNormalizer(mean, std)
-    pass
+    normalizer = TorchVisionNormalizer(mean, std)
+    return normalizer
diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py
index d14ea5d4f745a4f237c3cc96337d862013318188..cca494f141ca625caa0d4872802aa266f6e70a0f 100644
--- a/src/ptbench/models/pasa.py
+++ b/src/ptbench/models/pasa.py
@@ -8,8 +8,6 @@ import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.data
 
-from .normalizer import TorchVisionNormalizer
-
 
 class PASA(pl.LightningModule):
     """PASA module.
@@ -30,7 +28,7 @@ class PASA(pl.LightningModule):
 
         self.name = "pasa"
 
-        self.normalizer = TorchVisionNormalizer(nb_channels=1)
+        self.normalizer = None
 
         # First convolution block
         self.fc1 = nn.Conv2d(1, 4, (3, 3), (2, 2), (1, 1))
@@ -81,6 +79,11 @@ class PASA(pl.LightningModule):
         self.dense = nn.Linear(80, 1)  # Fully connected layer
 
     def forward(self, x):
+        if self.normalizer is None:
+            raise TypeError(
+                "The normalizer has not been initialized. Make sure to call set_normalizer() after creation of the model."
+            )
+
         x = self.normalizer(x)
 
         # First convolution block
@@ -129,14 +132,21 @@ class PASA(pl.LightningModule):
         return x
 
     def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
-        """TODO: Write this function documentation"""
+        """Initializes the normalizer for the current model.
+
+        Parameters
+        ----------
+
+        dataloader:
+            A torch Dataloader from which to compute the mean and std
+        """
         from .normalizer import get_znorm_normalizer
 
         self.normalizer = get_znorm_normalizer(dataloader)
 
     def training_step(self, batch, batch_idx):
-        images = batch["data"]
-        labels = batch["label"]
+        images = batch[0]
+        labels = batch[1]["label"]
 
         # Increase label dimension if too low
         # Allows single and multiclass usage
@@ -153,8 +163,8 @@ class PASA(pl.LightningModule):
         return {"loss": training_loss}
 
     def validation_step(self, batch, batch_idx, dataloader_idx=0):
-        images = batch["data"]
-        labels = batch["label"]
+        images = batch[0]
+        labels = batch[1]["label"]
 
         # Increase label dimension if too low
         # Allows single and multiclass usage
@@ -176,9 +186,9 @@ class PASA(pl.LightningModule):
             return {f"extra_validation_loss_{dataloader_idx}": validation_loss}
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
-        names = batch["name"]
-        images = batch["data"]
-        labels = batch["label"]
+        images = batch[0]
+        labels = batch[1]["label"]
+        names = batch[1]["names"]
 
         outputs = self(images)
         probabilities = torch.sigmoid(outputs)
diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py
index 4d2a226b5b0b479b3f84c748d24aebd43d8d6dea..eeb5b8696ca9a72f9828021d37e8420d56fdf1d3 100644
--- a/src/ptbench/scripts/train.py
+++ b/src/ptbench/scripts/train.py
@@ -263,26 +263,21 @@ def train(
     import torch.cuda
     import torch.nn
 
-    from ..data.dataset import normalize_data, reweight_BCEWithLogitsLoss
+    from ..data.dataset import reweight_BCEWithLogitsLoss
     from ..engine.trainer import run
 
     seed_everything(seed)
 
     checkpoint_file = get_checkpoint(output_folder, resume_from)
 
-    datamodule.update_module_properties(
-        batch_size=batch_size,
-        batch_chunk_count=batch_chunk_count,
-        drop_incomplete_batch=drop_incomplete_batch,
-        cache_samples=cache_samples,
-        parallel=parallel,
-    )
+    datamodule.set_chunk_size(batch_size, batch_chunk_count)
+    datamodule.parallel = parallel
 
     datamodule.prepare_data()
     datamodule.setup(stage="fit")
 
     reweight_BCEWithLogitsLoss(datamodule, model, criterion, criterion_valid)
-    normalize_data(normalization, model, datamodule)
+    model.set_normalizer(datamodule.unaugmented_train_dataloader())
 
     arguments = {}
     arguments["max_epoch"] = epochs