Commit 9e721cb3 authored by André Anjos

[models] Remove runtime checks; Use torchvision normaliser instead of our own

parent 3da537d0
1 merge request: !7 Reviewed DataModule design+docs+types
Pipeline #75473 failed
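Editor's note: the core of this change is that the hand-rolled TorchVisionNormalizer module is dropped in favour of torchvision.transforms.Normalize, which performs the same per-channel subtract/divide. A minimal standalone check of that equivalence (illustrative only, not part of the commit):

import torch
import torchvision.transforms

# Per-channel statistics (here: the stock ImageNet values used further below)
mean = torch.tensor([0.485, 0.456, 0.406])
std = torch.tensor([0.229, 0.224, 0.225])

x = torch.rand(4, 3, 32, 32)  # a batch of RGB images in [0, 1]

# torchvision's Normalize applies the same subtract/divide the removed
# TorchVisionNormalizer module implemented by hand
normalizer = torchvision.transforms.Normalize(mean, std)
expected = (x - mean[None, :, None, None]) / std[None, :, None, None]

assert torch.allclose(normalizer(x), expected)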
@@ -7,8 +7,6 @@ import torch
 import torch.nn as nn
 import torchvision.models as models
-from .normalizer import TorchVisionNormalizer

 class Densenet(pl.LightningModule):
     """Densenet module.
@@ -31,7 +29,7 @@ class Densenet(pl.LightningModule):
         self.name = "Densenet"
-        self.normalizer = TorchVisionNormalizer(nb_channels=nb_channels)
+        self.normalizer = None

         # Load pretrained model
         weights = None if not pretrained else models.DenseNet121_Weights.DEFAULT
@@ -55,16 +53,13 @@ class Densenet(pl.LightningModule):
         imagenet weights, during construction).
         """
         if self.pretrained:
-            from .normalizer import TorchVisionNormalizer
-            self.normalizer = TorchVisionNormalizer(
-                torch.Tensor([0.485, 0.456, 0.406]),
-                torch.Tensor([0.229, 0.224, 0.225]),
-            )
+            from .normalizer import make_imagenet_normalizer
+            self.normalizer = make_imagenet_normalizer()
         else:
-            from .normalizer import get_znorm_normalizer
-            self.normalizer = get_znorm_normalizer(dataloader)
+            from .normalizer import make_z_normalizer
+            self.normalizer = make_z_normalizer(dataloader)

     def training_step(self, batch, batch_idx):
         images = batch[1]
...
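Editor's note: the branch introduced above picks ImageNet statistics when a pretrained backbone is used and data-driven statistics otherwise. The standalone helper below is a hypothetical sketch that restates that logic using torchvision directly (the concatenation-based estimate is a simplification suitable only for small datasets; the actual code calls make_z_normalizer):

import torch
import torch.utils.data
import torchvision.transforms

def pick_normalizer(
    pretrained: bool, dataloader: torch.utils.data.DataLoader
) -> torchvision.transforms.Normalize:
    """Hypothetical helper mirroring the set_normalizer() branch above."""
    if pretrained:
        # Stock ImageNet statistics, matching the pretrained weights
        return torchvision.transforms.Normalize(
            (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
        )
    # Otherwise, estimate per-channel statistics from the training data
    # (simplified: loads all batches into memory at once)
    data = torch.cat([batch[0] for batch in dataloader])  # [N, C, H, W]
    mean = data.mean(dim=(0, 2, 3))
    std = data.std(dim=(0, 2, 3))
    return torchvision.transforms.Normalize(mean, std)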
@@ -2,52 +2,23 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
-"""A network model that prefixes a z-normalization step to any other module."""
+"""A network model that prefixes a subtract/divide step to any other module."""
 import torch
 import torch.nn
 import torch.utils.data
+import torchvision.transforms

-class TorchVisionNormalizer(torch.nn.Module):
-    """A simple normalizer that applies the standard torchvision normalization.
-
-    This module does not learn.
-
-    Parameters
-    ----------
-    nb_channels : :py:class:`int`, Optional
-        Number of images channels fed to the model
-    """
-
-    def __init__(self, subtract: torch.Tensor, divide: torch.Tensor):
-        super().__init__()
-        if len(subtract) != len(divide):
-            raise ValueError(
-                "Lengths of 'subtract' and 'divide' tensors should be the same."
-            )
-        if len(subtract) not in (1, 3):
-            raise ValueError(
-                "Length of 'subtract' tensor should be either 1 or 3, depending on the number of color channels."
-            )
-        subtract = torch.as_tensor(subtract)[None, :, None, None]
-        divide = torch.as_tensor(divide)[None, :, None, None]
-        self.register_buffer("subtract", subtract)
-        self.register_buffer("divide", divide)
-        self.name = "torchvision-normalizer"
-
-    def forward(self, inputs: torch.Tensor):
-        """inputs shape [batches, planes, height, width]"""
-        return inputs.sub(self.subtract).div(self.divide)
-
-def get_znorm_normalizer(
-    dataloader: torch.utils.data.DataLoader,
-) -> TorchVisionNormalizer:
-    """Returns a normalizer with the mean and std computed from a dataloader's
-    unaugmented training set.
+def make_z_normalizer(
+    dataloader: torch.utils.data.DataLoader,
+) -> torchvision.transforms.Normalize:
+    """Computes mean and standard deviation from a dataloader.
+
+    This function will input a dataloader, and compute the mean and standard
+    deviation by image channel.  It will work for both monochromatic, and color
+    inputs with 2, 3 or more color planes.

     Parameters
     ----------
@@ -55,15 +26,22 @@ def get_znorm_normalizer(
     dataloader:
         A torch Dataloader from which to compute the mean and std

     Returns
     -------
-    An initialized TorchVisionNormalizer
+    An initialized normalizer
     """
-    mean = 0.0
-    var = 0.0
+    # Peek the number of channels of batches in the data loader
+    batch = next(iter(dataloader))
+    channels = batch[0].shape[1]
+
+    # Initialises accumulators
+    mean = torch.zeros(channels, dtype=batch[0].dtype)
+    var = torch.zeros(channels, dtype=batch[0].dtype)
     num_images = 0
+
+    # Evaluates mean and standard deviation
     for batch in dataloader:
         data = batch[0]
         data = data.view(data.size(0), data.size(1), -1)
@@ -76,5 +54,21 @@ def get_znorm_normalizer(
     var /= num_images
     std = torch.sqrt(var)
-    normalizer = TorchVisionNormalizer(mean, std)
-    return normalizer
+    return torchvision.transforms.Normalize(mean, std)
+
+
+def make_imagenet_normalizer() -> torchvision.transforms.Normalize:
+    """Returns the stock ImageNet normalisation weights from torchvision.
+
+    The weights are wrapped in a torch module.  This normalizer only works for
+    **RGB (color) images**.
+
+    Returns
+    -------
+    An initialized normalizer
+    """
+    return torchvision.transforms.Normalize(
+        (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
+    )
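Editor's note: the accumulation loop of make_z_normalizer is only partly visible above (its middle is collapsed between the last two hunks). The self-contained sketch below shows the standard approach it appears to follow, averaging per-image channel means and variances; the loop body is an assumption and may differ in detail from the committed code:

import torch
import torch.utils.data
import torchvision.transforms

def channel_stats_normalizer(
    dataloader: torch.utils.data.DataLoader,
) -> torchvision.transforms.Normalize:
    """Illustrative per-channel mean/std estimation (not the committed code)."""
    # Peek one batch to find out how many channels the images carry
    batch = next(iter(dataloader))
    channels = batch[0].shape[1]

    mean = torch.zeros(channels, dtype=batch[0].dtype)
    var = torch.zeros(channels, dtype=batch[0].dtype)
    num_images = 0

    for batch in dataloader:
        data = batch[0]                                   # [N, C, H, W]
        data = data.view(data.size(0), data.size(1), -1)  # [N, C, H*W]
        num_images += data.size(0)
        mean += data.mean(dim=2).sum(dim=0)  # sum of per-image channel means
        var += data.var(dim=2).sum(dim=0)    # sum of per-image channel variances

    mean /= num_images
    var /= num_images
    std = torch.sqrt(var)
    return torchvision.transforms.Normalize(mean, std)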
@@ -79,12 +79,7 @@ class PASA(pl.LightningModule):
         self.dense = nn.Linear(80, 1)  # Fully connected layer

     def forward(self, x):
-        if self.normalizer is None:
-            raise TypeError(
-                "The normalizer has not been initialized. Make sure to call set_normalizer() after creation of the model."
-            )
-        x = self.normalizer(x)
+        x = self.normalizer(x)  # type: ignore

         # First convolution block
         _x = x
@@ -140,11 +135,11 @@
         dataloader:
             A torch Dataloader from which to compute the mean and std
         """
-        from .normalizer import get_znorm_normalizer
-        self.normalizer = get_znorm_normalizer(dataloader)
+        from .normalizer import make_z_normalizer
+        self.normalizer = make_z_normalizer(dataloader)

-    def training_step(self, batch, batch_idx):
+    def training_step(self, batch, _):
         images = batch[0]
         labels = batch[1]["label"]
...
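Editor's note: the PASA hunk trades the runtime None check for a static-typing escape hatch. The normalizer attribute starts out as None and is only assigned once set_normalizer() has been called, so a type checker flags the call in forward(). The snippet below is a minimal, hypothetical reproduction of that situation, not code from the commit:

from typing import Callable, Optional

import torch

class TinyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Starts unset; callers are expected to assign it before forward()
        self.normalizer: Optional[Callable[[torch.Tensor], torch.Tensor]] = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Without the removed runtime check, we rely on the calling
        # convention and silence the Optional warning for the checker
        return self.normalizer(x)  # type: ignore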