# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""A network model that prefixes a z-normalization step to any other module.

Provides :py:class:`TorchVisionNormalizer`, a non-trainable per-channel
z-normalization module, and :py:func:`get_znorm_normalizer`, which estimates
its statistics from a data loader.
"""

import torch
import torch.nn
import torch.utils.data


class TorchVisionNormalizer(torch.nn.Module):
    """Per-channel z-normalization, like torchvision's ``Normalize``.

    Computes ``(inputs - subtract) / divide`` independently for every channel.
    The statistics are stored as buffers (not parameters), so they follow the
    module across ``.to()`` calls and are saved in its ``state_dict``, but are
    never updated by an optimizer.  This module does not learn.

    Parameters
    ----------
    subtract : :py:class:`torch.Tensor`
        Per-channel values to subtract (typically the channel means).  A 0-d
        or 1-d tensor, or anything accepted by :py:func:`torch.as_tensor`.
    divide : :py:class:`torch.Tensor`
        Per-channel values to divide by (typically the channel standard
        deviations).  Must have as many elements as ``subtract``.

    Raises
    ------
    ValueError
        If ``subtract`` or ``divide`` has more than one dimension, is empty,
        or if their lengths differ.
    """

    def __init__(self, subtract: torch.Tensor, divide: torch.Tensor):
        super().__init__()

        dtype = torch.get_default_dtype()
        # detach/clone so the buffers never alias (or track gradients of) the
        # caller's tensors.
        subtract = torch.as_tensor(subtract, dtype=dtype).detach().clone()
        divide = torch.as_tensor(divide, dtype=dtype).detach().clone()

        if subtract.ndim > 1 or divide.ndim > 1:
            raise ValueError(
                "subtract and divide must be 0-d or 1-d (one value per "
                f"channel), got shapes {tuple(subtract.shape)} and "
                f"{tuple(divide.shape)}"
            )
        subtract = subtract.reshape(-1)
        divide = divide.reshape(-1)
        if subtract.numel() == 0:
            raise ValueError("subtract and divide must not be empty")
        if subtract.numel() != divide.numel():
            raise ValueError(
                "subtract and divide must have the same number of channels, "
                f"got {subtract.numel()} and {divide.numel()}"
            )

        # Shape [1, C, 1, 1] broadcasts over [batch, C, height, width] inputs.
        # The attributes must NOT be assigned before this point:
        # ``register_buffer`` rejects names that already exist as attributes.
        self.register_buffer("subtract", subtract.reshape(1, -1, 1, 1))
        self.register_buffer("divide", divide.reshape(1, -1, 1, 1))

        self.name = "torchvision-normalizer"

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Normalize ``inputs`` channel-wise.

        ``inputs`` shape is [batches, planes, height, width]; ``planes`` must
        match (or broadcast with) the number of stored channels.
        """
        return inputs.sub(self.subtract).div(self.divide)


def _extract_images(batch) -> torch.Tensor:
    """Return the image tensor of a data-loader batch (tensor or sequence)."""
    if isinstance(batch, torch.Tensor):
        return batch
    if (
        isinstance(batch, (list, tuple))
        and len(batch) > 0
        and isinstance(batch[0], torch.Tensor)
    ):
        return batch[0]
    raise TypeError(
        "cannot find images in batch: expected a tensor or a sequence whose "
        f"first element is a tensor, got {type(batch).__name__}"
    )


def get_znorm_normalizer(
    dataloader: torch.utils.data.DataLoader,
) -> TorchVisionNormalizer:
    """Build a :py:class:`TorchVisionNormalizer` from data-loader statistics.

    Makes one full pass over ``dataloader`` and computes, for every channel,
    the mean and the population standard deviation over all pixels of all
    images.  Sums are accumulated in float64 to limit rounding error.

    Each batch is expected to be either an image tensor or a sequence (e.g.
    ``(images, labels, ...)``) whose first element is the image tensor.
    Images must be 4-D: [batch, channels, height, width].

    NOTE(review): statistics should come from the *unaugmented* training set,
    and z-normalization with dataset statistics is only appropriate when the
    model is not fine-tuned from weights trained with other statistics
    (e.g. ImageNet).  Choosing the right loader is the caller's job.
    A constant channel yields a zero standard deviation, which makes the
    resulting normalizer divide by zero.

    Parameters
    ----------
    dataloader : :py:class:`torch.utils.data.DataLoader`
        Loader yielding the images to compute statistics on.

    Returns
    -------
    TorchVisionNormalizer
        Normalizer with ``subtract`` set to the channel means and ``divide``
        set to the channel standard deviations.

    Raises
    ------
    ValueError
        If the loader yields no pixels, if a batch is not 4-D, or if the
        number of channels changes between batches.
    TypeError
        If a batch has an unsupported structure.
    """
    total = None
    total_sq = None
    count = 0  # number of pixels seen per channel

    for batch in dataloader:
        images = _extract_images(batch)
        if images.ndim != 4:
            raise ValueError(
                "expected images of shape [batch, channels, height, width], "
                f"got shape {tuple(images.shape)}"
            )
        images = images.detach().to(device="cpu", dtype=torch.float64)
        channels = images.shape[1]

        if total is None:
            total = torch.zeros(channels, dtype=torch.float64)
            total_sq = torch.zeros(channels, dtype=torch.float64)
        elif channels != total.numel():
            raise ValueError(
                f"inconsistent number of channels: {total.numel()} then "
                f"{channels}"
            )

        total += images.sum(dim=(0, 2, 3))
        total_sq += images.square().sum(dim=(0, 2, 3))
        count += images.numel() // channels if channels else 0

    if total is None or count == 0:
        raise ValueError("cannot compute statistics from an empty dataloader")

    mean = total / count
    # E[x^2] - E[x]^2 can dip slightly below zero through rounding.
    variance = (total_sq / count - mean.square()).clamp_min(0.0)
    std = variance.sqrt()

    return TorchVisionNormalizer(mean, std)