# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torchvision.models as models

from .normalizer import TorchVisionNormalizer


class Densenet(pl.LightningModule):
    """Densenet module.

    Note: only usable with a normalized dataset
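
    Examples
    --------

    A minimal construction sketch (the loss and optimizer settings below
    are illustrative assumptions, not project defaults):

    .. code-block:: python

       import torch.nn as nn

       model = Densenet(
           criterion=nn.BCEWithLogitsLoss(),
           criterion_valid=nn.BCEWithLogitsLoss(),
           optimizer="Adam",
           optimizer_configs={"lr": 1e-4},
           pretrained=True,
       )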
    """

    def __init__(
        self,
        criterion,
        criterion_valid,
        optimizer,
        optimizer_configs,
        pretrained=False,
        nb_channels=3,
    ):
        super().__init__()

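        # Store all constructor arguments under ``self.hparams`` (used by
        # ``configure_optimizers`` below) and save them to checkpoints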
        self.save_hyperparameters()

        self.name = "Densenet"

        self.criterion = criterion
        self.criterion_valid = criterion_valid

        self.normalizer = TorchVisionNormalizer(nb_channels=nb_channels)

        # Load the DenseNet-121 backbone, optionally with ImageNet weights
        weights = models.DenseNet121_Weights.DEFAULT if pretrained else None
        self.model_ft = models.densenet121(weights=weights)

        # Replace the classifier: map the 1024 backbone features to a
        # single output logit through an intermediate 256-unit layer
        self.model_ft.classifier = nn.Sequential(
            nn.Linear(1024, 256), nn.Linear(256, 1)
        )

    def forward(self, x):
        """

        Parameters
        ----------

        x : list
            list of tensors.

        Returns
        -------

        tensor : :py:class:`torch.Tensor`

        """
        x = self.normalizer(x)

        x = self.model_ft(x)

        return x

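    # The step methods below expect batches laid out as
    # (names, images, labels), hence the numeric indexing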
    def training_step(self, batch, batch_idx):
        images = batch[1]
        labels = batch[2]

        # Reshape 1D label tensors to (batch_size, 1) so the same code
        # handles both single- and multi-class targets
        if labels.ndim == 1:
            labels = torch.reshape(labels, (labels.shape[0], 1))

        # Forward pass on the network
        outputs = self(images)

        training_loss = self.criterion(outputs, labels.double())

        return {"loss": training_loss}

    def validation_step(self, batch, batch_idx):
        images = batch[1]
        labels = batch[2]

        # Reshape 1D label tensors to (batch_size, 1) so the same code
        # handles both single- and multi-class targets
        if labels.ndim == 1:
            labels = torch.reshape(labels, (labels.shape[0], 1))

        # Forward pass on the network
        outputs = self(images)
        validation_loss = self.criterion_valid(outputs, labels.double())

        return {"validation_loss": validation_loss}

    def predict_step(self, batch, batch_idx, grad_cams=False):
        names = batch[0]
        images = batch[1]

        outputs = self(images)

        # Necessary check for the HED architecture, which returns several
        # outputs for loss calculation instead of just the last concat-fuse
        # block; the selection must happen before the sigmoid is applied,
        # since sigmoid only accepts a single tensor
        if isinstance(outputs, list):
            outputs = outputs[-1]

        probabilities = torch.sigmoid(outputs)

        return names[0], torch.flatten(probabilities), torch.flatten(batch[2])

    def configure_optimizers(self):
        # Dynamically instantiate the optimizer class named in the
        # hyper-parameters, with the stored keyword arguments
        optimizer = getattr(torch.optim, self.hparams.optimizer)(
            self.parameters(), **self.hparams.optimizer_configs
        )

        return optimizer
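

# Minimal training/prediction sketch (assumptions for illustration: ``model``
# constructed as in the class docstring, and dataloaders yielding
# (names, images, labels) batches):
#
#   trainer = pl.Trainer(max_epochs=10)
#   trainer.fit(model, train_loader, valid_loader)
#   predictions = trainer.predict(model, test_loader)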