# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

"""A network model that prefixes a z-normalization step to any other module."""

import torch
import torch.nn
import torch.utils.data


class TorchVisionNormalizer(torch.nn.Module):
    """Normalize image batches channel-wise with fixed offsets and scales.

    Computes ``(inputs - subtract) / divide``, where both tensors hold one
    value per channel and are broadcast over the batch, height and width
    dimensions of the input.

    This module does not learn: both tensors are registered as buffers, so
    they are part of the module's ``state_dict`` and move with the module
    across devices, but they are not parameters.

    Parameters
    ----------

    subtract
        1-D tensor with one value per channel (1 or 3 entries) that is
        subtracted from the inputs.
    divide
        1-D tensor with as many entries as ``subtract`` by which the shifted
        inputs are divided.
    """

    def __init__(self, subtract: torch.Tensor, divide: torch.Tensor):
        super().__init__()
        assert (
            subtract.ndim == 1 and divide.ndim == 1
        ), "`subtract` and `divide` must be 1-D tensors (one value per channel)"
        assert len(subtract) == len(divide), (
            f"`subtract` and `divide` must have the same number of channels "
            f"(got {len(subtract)} and {len(divide)})"
        )
        assert len(subtract) in (1, 3), (
            f"only 1- or 3-channel normalization is supported "
            f"(got {len(subtract)} channels)"
        )
        # Do not bind ``self.subtract``/``self.divide`` as plain attributes
        # first: ``register_buffer`` refuses names that already exist on the
        # module.  Buffers are stored as detached copies shaped
        # [1, channels, 1, 1] so they broadcast against
        # [batches, planes, height, width] inputs and do not alias the
        # caller's tensors.
        self.register_buffer(
            "subtract", subtract.detach().clone().reshape(1, -1, 1, 1)
        )
        self.register_buffer(
            "divide", divide.detach().clone().reshape(1, -1, 1, 1)
        )
        self.name = "torchvision-normalizer"

    def forward(self, inputs: torch.Tensor):
        """Return ``(inputs - subtract) / divide``.

        ``inputs`` has shape [batches, planes, height, width].
        """
        return inputs.sub(self.subtract).div(self.divide)


def get_znorm_normalizer(
    dataloader: torch.utils.data.DataLoader,
) -> TorchVisionNormalizer:
    """Placeholder for building a z-normalizer from a data loader.

    This function is not implemented yet: ``dataloader`` is currently
    ignored and ``None`` is returned, despite the declared return type.

    Parameters
    ----------

    dataloader
        Data loader from which the mean/std statistics are meant to be
        extracted (currently unused).
    """
    # Open work items:
    #   * TODO: use the *unaugmented* training set to compute statistics
    #   * TODO: only applicable IFF we are not fine-tuning (i.e. the model
    #     does not re-use weights from imagenet training!)
    #   * TODO: complete type hints and documentation once implemented
    #
    # Planned steps:
    #   1. extract mean/std from ``dataloader``
    #   2. return TorchVisionNormalizer(mean, std)
    return None