Skip to content
Snippets Groups Projects
Commit 9e721cb3 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[models] Remove runtime checks; Use torchvision normaliser instead of our own

parent 3da537d0
No related branches found
No related tags found
1 merge request!7Reviewed DataModule design+docs+types
Pipeline #75473 failed
......@@ -7,8 +7,6 @@ import torch
import torch.nn as nn
import torchvision.models as models
from .normalizer import TorchVisionNormalizer
class Densenet(pl.LightningModule):
"""Densenet module.
......@@ -31,7 +29,7 @@ class Densenet(pl.LightningModule):
self.name = "Densenet"
self.normalizer = TorchVisionNormalizer(nb_channels=nb_channels)
self.normalizer = None
# Load pretrained model
weights = None if not pretrained else models.DenseNet121_Weights.DEFAULT
......@@ -55,16 +53,13 @@ class Densenet(pl.LightningModule):
imagenet weights, during contruction).
"""
if self.pretrained:
from .normalizer import TorchVisionNormalizer
from .normalizer import make_imagenet_normalizer
self.normalizer = TorchVisionNormalizer(
torch.Tensor([0.485, 0.456, 0.406]),
torch.Tensor([0.229, 0.224, 0.225]),
)
self.normalizer = make_imagenet_normalizer()
else:
from .normalizer import get_znorm_normalizer
from .normalizer import make_z_normalizer
self.normalizer = get_znorm_normalizer(dataloader)
self.normalizer = make_z_normalizer(dataloader)
def training_step(self, batch, batch_idx):
images = batch[1]
......
......@@ -2,52 +2,23 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""A network model that prefixes a z-normalization step to any other module."""
"""A network model that prefixes a subtract/divide step to any other module."""
import torch
import torch.nn
import torch.utils.data
import torchvision.transforms
class TorchVisionNormalizer(torch.nn.Module):
"""A simple normalizer that applies the standard torchvision normalization.
def make_z_normalizer(
dataloader: torch.utils.data.DataLoader,
) -> torchvision.transforms.Normalize:
"""Computes mean and standard deviation from a dataloader.
This module does not learn.
This function will input a dataloader, and compute the mean and standard
deviation by image channel. It will work for both monochromatic, and color
inputs with 2, 3 or more color planes.
Parameters
----------
nb_channels : :py:class:`int`, Optional
Number of images channels fed to the model
"""
def __init__(self, subtract: torch.Tensor, divide: torch.Tensor):
super().__init__()
if len(subtract) != len(divide):
raise ValueError(
"Lengths of 'subtract' and 'divide' tensors should be the same."
)
if len(subtract) not in (1, 3):
raise ValueError(
"Length of 'subtract' tensor should be either 1 or 3, depending on the number of color channels."
)
subtract = torch.as_tensor(subtract)[None, :, None, None]
divide = torch.as_tensor(divide)[None, :, None, None]
self.register_buffer("subtract", subtract)
self.register_buffer("divide", divide)
self.name = "torchvision-normalizer"
def forward(self, inputs: torch.Tensor):
"""inputs shape [batches, planes, height, width]"""
return inputs.sub(self.subtract).div(self.divide)
def get_znorm_normalizer(
dataloader: torch.utils.data.DataLoader,
) -> TorchVisionNormalizer:
"""Returns a normalizer with the mean and std computed from a dataloader's
unaugmented training set.
Parameters
----------
......@@ -55,15 +26,22 @@ def get_znorm_normalizer(
dataloader:
A torch Dataloader from which to compute the mean and std
Returns
-------
An initialized TorchVisionNormalizer
An initialized normalizer
"""
mean = 0.0
var = 0.0
# Peek the number of channels of batches in the data loader
batch = next(iter(dataloader))
channels = batch[0].shape[1]
# Initialises accumulators
mean = torch.zeros(channels, dtype=batch[0].dtype)
var = torch.zeros(channels, dtype=batch[0].dtype)
num_images = 0
# Evaluates mean and standard deviation
for batch in dataloader:
data = batch[0]
data = data.view(data.size(0), data.size(1), -1)
......@@ -76,5 +54,21 @@ def get_znorm_normalizer(
var /= num_images
std = torch.sqrt(var)
normalizer = TorchVisionNormalizer(mean, std)
return normalizer
return torchvision.transforms.Normalize(mean, std)
def make_imagenet_normalizer() -> torchvision.transforms.Normalize:
"""Returns the stock ImageNet normalisation weights from torchvision.
The weights are wrapped in a torch module. This normalizer only works for
**RGB (color) images**.
Returns
-------
An initialized normalizer
"""
return torchvision.transforms.Normalize(
(0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
)
......@@ -79,12 +79,7 @@ class PASA(pl.LightningModule):
self.dense = nn.Linear(80, 1) # Fully connected layer
def forward(self, x):
if self.normalizer is None:
raise TypeError(
"The normalizer has not been initialized. Make sure to call set_normalizer() after creation of the model."
)
x = self.normalizer(x)
x = self.normalizer(x) # type: ignore
# First convolution block
_x = x
......@@ -140,11 +135,11 @@ class PASA(pl.LightningModule):
dataloader:
A torch Dataloader from which to compute the mean and std
"""
from .normalizer import get_znorm_normalizer
from .normalizer import make_z_normalizer
self.normalizer = get_znorm_normalizer(dataloader)
self.normalizer = make_z_normalizer(dataloader)
def training_step(self, batch, batch_idx):
def training_step(self, batch, _):
images = batch[0]
labels = batch[1]["label"]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment