From 9e721cb3d9d38ee998e27cd86a5ba63258a28d2b Mon Sep 17 00:00:00 2001 From: Andre Anjos <andre.dos.anjos@gmail.com> Date: Sat, 1 Jul 2023 19:08:15 +0200 Subject: [PATCH] [models] Remove runtime checks; Use torchvision normaliser instead of our own --- src/ptbench/models/densenet.py | 15 ++---- src/ptbench/models/normalizer.py | 80 +++++++++++++++----------------- src/ptbench/models/pasa.py | 13 ++---- 3 files changed, 46 insertions(+), 62 deletions(-) diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py index bc8b3a9d..1aa1b268 100644 --- a/src/ptbench/models/densenet.py +++ b/src/ptbench/models/densenet.py @@ -7,8 +7,6 @@ import torch import torch.nn as nn import torchvision.models as models -from .normalizer import TorchVisionNormalizer - class Densenet(pl.LightningModule): """Densenet module. @@ -31,7 +29,7 @@ class Densenet(pl.LightningModule): self.name = "Densenet" - self.normalizer = TorchVisionNormalizer(nb_channels=nb_channels) + self.normalizer = None # Load pretrained model weights = None if not pretrained else models.DenseNet121_Weights.DEFAULT @@ -55,16 +53,13 @@ class Densenet(pl.LightningModule): imagenet weights, during contruction). """ if self.pretrained: - from .normalizer import TorchVisionNormalizer + from .normalizer import make_imagenet_normalizer - self.normalizer = TorchVisionNormalizer( - torch.Tensor([0.485, 0.456, 0.406]), - torch.Tensor([0.229, 0.224, 0.225]), - ) + self.normalizer = make_imagenet_normalizer() else: - from .normalizer import get_znorm_normalizer + from .normalizer import make_z_normalizer - self.normalizer = get_znorm_normalizer(dataloader) + self.normalizer = make_z_normalizer(dataloader) def training_step(self, batch, batch_idx): images = batch[1] diff --git a/src/ptbench/models/normalizer.py b/src/ptbench/models/normalizer.py index b9ba7eb3..2cc4b956 100644 --- a/src/ptbench/models/normalizer.py +++ b/src/ptbench/models/normalizer.py @@ -2,52 +2,23 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -"""A network model that prefixes a z-normalization step to any other module.""" +"""A network model that prefixes a subtract/divide step to any other module.""" import torch import torch.nn import torch.utils.data +import torchvision.transforms -class TorchVisionNormalizer(torch.nn.Module): - """A simple normalizer that applies the standard torchvision normalization. +def make_z_normalizer( + dataloader: torch.utils.data.DataLoader, +) -> torchvision.transforms.Normalize: + """Computes mean and standard deviation from a dataloader. - This module does not learn. + This function will input a dataloader, and compute the mean and standard + deviation by image channel. It will work for both monochromatic, and color + inputs with 2, 3 or more color planes. - Parameters - ---------- - - nb_channels : :py:class:`int`, Optional - Number of images channels fed to the model - """ - - def __init__(self, subtract: torch.Tensor, divide: torch.Tensor): - super().__init__() - if len(subtract) != len(divide): - raise ValueError( - "Lengths of 'subtract' and 'divide' tensors should be the same." - ) - if len(subtract) not in (1, 3): - raise ValueError( - "Length of 'subtract' tensor should be either 1 or 3, depending on the number of color channels." - ) - - subtract = torch.as_tensor(subtract)[None, :, None, None] - divide = torch.as_tensor(divide)[None, :, None, None] - self.register_buffer("subtract", subtract) - self.register_buffer("divide", divide) - self.name = "torchvision-normalizer" - - def forward(self, inputs: torch.Tensor): - """inputs shape [batches, planes, height, width]""" - return inputs.sub(self.subtract).div(self.divide) - - -def get_znorm_normalizer( - dataloader: torch.utils.data.DataLoader, -) -> TorchVisionNormalizer: - """Returns a normalizer with the mean and std computed from a dataloader's - unaugmented training set. Parameters ---------- @@ -55,15 +26,22 @@ def get_znorm_normalizer( dataloader: A torch Dataloader from which to compute the mean and std + Returns ------- - An initialized TorchVisionNormalizer + An initialized normalizer """ - mean = 0.0 - var = 0.0 + # Peek the number of channels of batches in the data loader + batch = next(iter(dataloader)) + channels = batch[0].shape[1] + + # Initialises accumulators + mean = torch.zeros(channels, dtype=batch[0].dtype) + var = torch.zeros(channels, dtype=batch[0].dtype) num_images = 0 + # Evaluates mean and standard deviation for batch in dataloader: data = batch[0] data = data.view(data.size(0), data.size(1), -1) @@ -76,5 +54,21 @@ def get_znorm_normalizer( var /= num_images std = torch.sqrt(var) - normalizer = TorchVisionNormalizer(mean, std) - return normalizer + return torchvision.transforms.Normalize(mean, std) + + +def make_imagenet_normalizer() -> torchvision.transforms.Normalize: + """Returns the stock ImageNet normalisation weights from torchvision. + + The weights are wrapped in a torch module. This normalizer only works for + **RGB (color) images**. + + + Returns + ------- + An initialized normalizer + """ + + return torchvision.transforms.Normalize( + (0.485, 0.456, 0.406), (0.229, 0.224, 0.225) + ) diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py index cca494f1..61105dbe 100644 --- a/src/ptbench/models/pasa.py +++ b/src/ptbench/models/pasa.py @@ -79,12 +79,7 @@ class PASA(pl.LightningModule): self.dense = nn.Linear(80, 1) # Fully connected layer def forward(self, x): - if self.normalizer is None: - raise TypeError( - "The normalizer has not been initialized. Make sure to call set_normalizer() after creation of the model." - ) - - x = self.normalizer(x) + x = self.normalizer(x) # type: ignore # First convolution block _x = x @@ -140,11 +135,11 @@ class PASA(pl.LightningModule): dataloader: A torch Dataloader from which to compute the mean and std """ - from .normalizer import get_znorm_normalizer + from .normalizer import make_z_normalizer - self.normalizer = get_znorm_normalizer(dataloader) + self.normalizer = make_z_normalizer(dataloader) - def training_step(self, batch, batch_idx): + def training_step(self, batch, _): images = batch[0] labels = batch[1]["label"] -- GitLab