diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py
index bc8b3a9ddb192e432a5345c96b3847379e58aa9d..1aa1b26867f8bcdc4400ce1a9bc8948be07a5afe 100644
--- a/src/ptbench/models/densenet.py
+++ b/src/ptbench/models/densenet.py
@@ -7,8 +7,6 @@ import torch
 import torch.nn as nn
 import torchvision.models as models
 
-from .normalizer import TorchVisionNormalizer
-
 
 class Densenet(pl.LightningModule):
     """Densenet module.
@@ -31,7 +29,7 @@ class Densenet(pl.LightningModule):
 
         self.name = "Densenet"
 
-        self.normalizer = TorchVisionNormalizer(nb_channels=nb_channels)
+        self.normalizer = None
 
         # Load pretrained model
         weights = None if not pretrained else models.DenseNet121_Weights.DEFAULT
@@ -55,16 +53,13 @@ class Densenet(pl.LightningModule):
         imagenet weights, during contruction).
         """
         if self.pretrained:
-            from .normalizer import TorchVisionNormalizer
+            from .normalizer import make_imagenet_normalizer
 
-            self.normalizer = TorchVisionNormalizer(
-                torch.Tensor([0.485, 0.456, 0.406]),
-                torch.Tensor([0.229, 0.224, 0.225]),
-            )
+            self.normalizer = make_imagenet_normalizer()
         else:
-            from .normalizer import get_znorm_normalizer
+            from .normalizer import make_z_normalizer
 
-            self.normalizer = get_znorm_normalizer(dataloader)
+            self.normalizer = make_z_normalizer(dataloader)
 
     def training_step(self, batch, batch_idx):
         images = batch[1]
diff --git a/src/ptbench/models/normalizer.py b/src/ptbench/models/normalizer.py
index b9ba7eb3d81de5c61349f5d4b0cbcf9d8fee9d7f..2cc4b956f17ee1e690031f42e891ef726b548e2a 100644
--- a/src/ptbench/models/normalizer.py
+++ b/src/ptbench/models/normalizer.py
@@ -2,52 +2,23 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-"""A network model that prefixes a z-normalization step to any other module."""
+"""A network model that prefixes a subtract/divide step to any other module."""
 
 import torch
 import torch.nn
 import torch.utils.data
+import torchvision.transforms
 
 
-class TorchVisionNormalizer(torch.nn.Module):
-    """A simple normalizer that applies the standard torchvision normalization.
+def make_z_normalizer(
+    dataloader: torch.utils.data.DataLoader,
+) -> torchvision.transforms.Normalize:
+    """Computes mean and standard deviation from a dataloader.
 
-    This module does not learn.
+    This function will input a dataloader, and compute the mean and standard
+    deviation by image channel. It will work for both monochromatic, and color
+    inputs with 2, 3 or more color planes.
 
-    Parameters
-    ----------
-
-    nb_channels : :py:class:`int`, Optional
-        Number of images channels fed to the model
-    """
-
-    def __init__(self, subtract: torch.Tensor, divide: torch.Tensor):
-        super().__init__()
-        if len(subtract) != len(divide):
-            raise ValueError(
-                "Lengths of 'subtract' and 'divide' tensors should be the same."
-            )
-        if len(subtract) not in (1, 3):
-            raise ValueError(
-                "Length of 'subtract' tensor should be either 1 or 3, depending on the number of color channels."
-            )
-
-        subtract = torch.as_tensor(subtract)[None, :, None, None]
-        divide = torch.as_tensor(divide)[None, :, None, None]
-        self.register_buffer("subtract", subtract)
-        self.register_buffer("divide", divide)
-        self.name = "torchvision-normalizer"
-
-    def forward(self, inputs: torch.Tensor):
-        """inputs shape [batches, planes, height, width]"""
-        return inputs.sub(self.subtract).div(self.divide)
-
-
-def get_znorm_normalizer(
-    dataloader: torch.utils.data.DataLoader,
-) -> TorchVisionNormalizer:
-    """Returns a normalizer with the mean and std computed from a dataloader's
-    unaugmented training set.
 
     Parameters
     ----------
@@ -55,15 +26,22 @@ def get_znorm_normalizer(
     dataloader:
         A torch Dataloader from which to compute the mean and std
 
+
     Returns
     -------
 
-    An initialized TorchVisionNormalizer
+    An initialized normalizer
     """
 
-    mean = 0.0
-    var = 0.0
+    # Peek the number of channels of batches in the data loader
+    batch = next(iter(dataloader))
+    channels = batch[0].shape[1]
+
+    # Initialises accumulators
+    mean = torch.zeros(channels, dtype=batch[0].dtype)
+    var = torch.zeros(channels, dtype=batch[0].dtype)
     num_images = 0
 
+    # Evaluates mean and standard deviation
     for batch in dataloader:
         data = batch[0]
         data = data.view(data.size(0), data.size(1), -1)
@@ -76,5 +54,21 @@
     var /= num_images
     std = torch.sqrt(var)
 
-    normalizer = TorchVisionNormalizer(mean, std)
-    return normalizer
+    return torchvision.transforms.Normalize(mean, std)
+
+
+def make_imagenet_normalizer() -> torchvision.transforms.Normalize:
+    """Returns the stock ImageNet normalisation weights from torchvision.
+
+    The weights are wrapped in a torch module. This normalizer only works for
+    **RGB (color) images**.
+
+
+    Returns
+    -------
+    An initialized normalizer
+    """
+
+    return torchvision.transforms.Normalize(
+        (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
+    )
diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py
index cca494f141ca625caa0d4872802aa266f6e70a0f..61105dbe3877aab1e3ad09af8c3c406a3a9273ec 100644
--- a/src/ptbench/models/pasa.py
+++ b/src/ptbench/models/pasa.py
@@ -79,12 +79,7 @@ class PASA(pl.LightningModule):
         self.dense = nn.Linear(80, 1)  # Fully connected layer
 
     def forward(self, x):
-        if self.normalizer is None:
-            raise TypeError(
-                "The normalizer has not been initialized. Make sure to call set_normalizer() after creation of the model."
-            )
-
-        x = self.normalizer(x)
+        x = self.normalizer(x)  # type: ignore
 
         # First convolution block
         _x = x
@@ -140,11 +135,11 @@ class PASA(pl.LightningModule):
         dataloader:
             A torch Dataloader from which to compute the mean and std
         """
-        from .normalizer import get_znorm_normalizer
+        from .normalizer import make_z_normalizer
 
-        self.normalizer = get_znorm_normalizer(dataloader)
+        self.normalizer = make_z_normalizer(dataloader)
 
-    def training_step(self, batch, batch_idx):
+    def training_step(self, batch, _):
         images = batch[0]
         labels = batch[1]["label"]
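
As a quick sanity check after applying this patch, the sketch below exercises the two new helpers. It is not part of the diff: it assumes the patched ptbench package is importable, and the random single-channel tensors wrapped in a TensorDataset are only an illustrative stand-in for an unaugmented training loader whose batches carry the image tensor in batch[0] as [batch, channels, height, width].

# Usage sketch (not part of the patch); assumes the patched ptbench is installed.
import torch
import torch.utils.data

from ptbench.models.normalizer import make_imagenet_normalizer, make_z_normalizer

# Illustrative stand-in for an unaugmented training loader: 16 random
# single-channel 64x64 images plus dummy labels; batch[0] holds the images.
images = torch.rand(16, 1, 64, 64)
labels = torch.zeros(16, dtype=torch.long)
loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(images, labels), batch_size=4
)

# Per-channel statistics estimated from the loader, returned as a stock
# torchvision.transforms.Normalize instance.
z_norm = make_z_normalizer(loader)
print(z_norm.mean, z_norm.std)

# Fixed ImageNet statistics, intended for pretrained RGB backbones.
imagenet_norm = make_imagenet_normalizer()

# Like the removed TorchVisionNormalizer.forward(), Normalize maps a
# [batch, channels, height, width] tensor to its normalized version.
assert z_norm(images).shape == images.shape

The models keep the same entry points: forward() still applies self.normalizer, and set_normalizer() still takes a dataloader; only the stored object changes from the custom TorchVisionNormalizer module to a stock torchvision transform.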