# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""A network model that prefixes a subtract/divide step to any other module."""

import torch
import torch.nn
import torch.utils.data
import torchvision.transforms


def make_z_normalizer(
    dataloader: torch.utils.data.DataLoader,
) -> torchvision.transforms.Normalize:
    """Computes mean and standard deviation from a dataloader.

    This function will input a dataloader, and compute the mean and standard
    deviation by image channel.  It will work for both monochromatic, and
    color inputs with 2, 3 or more color planes.

    The first element of every batch is used as the image tensor; its first
    dimension indexes the images and its second dimension the channels.  All
    remaining dimensions are flattened and treated as the pixels of one
    channel.

    The mean is the average, over all images, of each image's per-channel
    mean.  The standard deviation is the square root of the average, over all
    images, of each image's per-channel variance (computed with
    :py:meth:`torch.Tensor.var` defaults).  Differences between the per-image
    means therefore do not contribute to the returned standard deviation.

    Parameters
    ----------
    dataloader:
        A torch Dataloader from which to compute the mean and std

    Returns
    -------
        An initialized normalizer

    Raises
    ------
    ValueError
        If the dataloader does not yield any batch.
    """

    # Peek the number of channels of batches in the data loader
    try:
        first_batch = next(iter(dataloader))
    except StopIteration:
        # A bare StopIteration escaping a regular function is easy to
        # swallow silently (e.g. inside a generator); be explicit instead.
        raise ValueError(
            "Cannot compute normalisation statistics from an empty dataloader"
        ) from None

    channels = first_batch[0].shape[1]
    dtype = first_batch[0].dtype

    # Initialises accumulators (these live on the default device)
    mean = torch.zeros(channels, dtype=dtype)
    var = torch.zeros(channels, dtype=dtype)
    num_images = 0

    # Evaluates mean and standard deviation
    for batch in dataloader:
        data = batch[0]

        # reshape() rather than view(): identical for contiguous inputs, but
        # also accepts non-contiguous batches (e.g. permuted tensors).
        data = data.reshape(data.size(0), data.size(1), -1)

        num_images += data.size(0)

        # Per-image statistics are summed over the batch, then moved to the
        # accumulators' device so batches living elsewhere can be added.
        mean += data.mean(2).sum(0).to(mean.device)
        var += data.var(2).sum(0).to(var.device)

    mean /= num_images
    var /= num_images
    std = torch.sqrt(var)

    return torchvision.transforms.Normalize(mean, std)


def make_imagenet_normalizer() -> torchvision.transforms.Normalize:
    """Returns the stock ImageNet normalisation weights from torchvision.

    The weights are wrapped in a torch module.  This normalizer only works for
    **RGB (color) images**, as it carries exactly three per-channel means and
    three per-channel standard deviations.

    Returns
    -------
        An initialized normalizer
    """

    return torchvision.transforms.Normalize(
        (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
    )