# normalizer.py (2.04 KiB) -- authored by André Anjos
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Functions to compute normalisation factors based on dataloaders."""
import logging
import torch
import torch.nn
import torch.utils.data
import torchvision.transforms
import tqdm
logger = logging.getLogger(__name__)
def make_z_normalizer(
    dataloader: torch.utils.data.DataLoader,
) -> torchvision.transforms.Normalize:
    """Compute per-channel mean and standard deviation from a dataloader.

    This function will input a dataloader, and compute the mean and standard
    deviation by image channel. It will work for both monochromatic, and color
    inputs with 2, 3 or more color planes.

    Every batch produced by ``dataloader`` is indexed as
    ``batch[0]["image"]`` to obtain the image tensor. Its first dimension is
    counted as the number of images and its second dimension is taken as the
    number of channels; all remaining dimensions are flattened together.

    Notes
    -----
    The returned mean is the average of the per-image channel means. The
    returned standard deviation is the square root of the *average per-image
    channel variance* (as computed by ``Tensor.var`` with its default
    settings). This is not the same as the standard deviation of all pixels
    pooled across the dataset, because differences between the means of
    individual images do not contribute to it.

    The dataloader is iterated twice: once to peek at the first batch, and
    once in full to accumulate the statistics.

    Parameters
    ----------
    dataloader
        A torch Dataloader from which to compute the mean and std.

    Returns
    -------
        An initialized normalizer.

    Raises
    ------
    ValueError
        If ``dataloader`` yields no batches.
    """

    # Peek at the first batch to learn the channel count and dtype. A default
    # is given to ``next`` so that an empty loader produces a clear error
    # rather than a bare ``StopIteration``.
    first_batch = next(iter(dataloader), None)
    if first_batch is None:
        raise ValueError(
            "Cannot compute normalisation factors from an empty dataloader"
        )

    first_image = first_batch[0]["image"]
    channels = first_image.shape[1]

    # Initialises accumulators (one entry per channel)
    mean = torch.zeros(channels, dtype=first_image.dtype)
    var = torch.zeros(channels, dtype=first_image.dtype)
    num_images = 0

    # Evaluates mean and standard deviation
    for batch in tqdm.tqdm(dataloader, unit="batch"):
        data = batch[0]["image"]
        # Collapse everything after the channel axis into a single axis.
        # ``reshape`` is used instead of ``view`` so that non-contiguous
        # batches are accepted as well (it copies only when required).
        data = data.reshape(data.size(0), data.size(1), -1)
        num_images += data.size(0)
        # Per-image statistics over the flattened axis, summed over the batch
        mean += data.mean(2).sum(0)
        var += data.var(2).sum(0)

    mean /= num_images
    var /= num_images
    std = torch.sqrt(var)

    return torchvision.transforms.Normalize(mean, std)
def make_imagenet_normalizer() -> torchvision.transforms.Normalize:
    """Build a normalizer preset with the stock ImageNet statistics.

    The per-channel mean and standard deviation are fixed three-element
    constants, wrapped in a torchvision ``Normalize`` transform. Because
    exactly three values are supplied for each, this normalizer only works for
    **RGB (color) images**.

    Returns
    -------
        An initialized normalizer.
    """

    # One value per color plane
    imagenet_mean = (0.485, 0.456, 0.406)
    imagenet_std = (0.229, 0.224, 0.225)

    return torchvision.transforms.Normalize(
        mean=imagenet_mean,
        std=imagenet_std,
    )