# Skip to content
# Snippets Groups Projects
# pasa.py 8.01 KiB
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

import lightning.pytorch as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data


class PASA(pl.LightningModule):
    """PASA module.

    Based on paper by [PASA-2019]_.

    Five convolutional stages, each averaging a two-convolution main path
    (conv -> batch-norm -> ReLU, twice) with a parallel 1x1 convolution
    shortcut (conv -> batch-norm -> ReLU). Stages 1-4 are followed by max
    pooling; stage 5 is followed by a global average pool and a single-output
    linear layer, so the network produces one logit per sample.

    Parameters
    ----------

    criterion
        Loss module used by :py:meth:`training_step`. It is moved to the
        model device and called as ``criterion(outputs, labels)``.

    criterion_valid
        Loss module used by :py:meth:`validation_step`, with the same calling
        convention as ``criterion``.

    optimizer
        Name of an optimizer class in :py:mod:`torch.optim` (looked up with
        ``getattr`` in :py:meth:`configure_optimizers`).

    optimizer_configs
        Mapping of keyword arguments forwarded to the optimizer constructor.
    """

    def __init__(
        self,
        criterion,
        criterion_valid,
        optimizer,
        optimizer_configs,
    ):
        super().__init__()

        # Stores the constructor arguments under ``self.hparams``; the
        # training/validation steps and configure_optimizers read them there.
        self.save_hyperparameters()

        self.name = "pasa"

        # Set by set_normalizer(); forward() refuses to run while it is None.
        self.normalizer = None

        # First convolution block (both paths take a single input channel)
        self.fc1 = nn.Conv2d(1, 4, (3, 3), (2, 2), (1, 1))
        self.fc2 = nn.Conv2d(4, 16, (3, 3), (2, 2), (1, 1))
        self.fc3 = nn.Conv2d(1, 16, (1, 1), (4, 4))

        self.batchNorm2d_4 = nn.BatchNorm2d(4)
        self.batchNorm2d_16 = nn.BatchNorm2d(16)
        self.batchNorm2d_16_2 = nn.BatchNorm2d(16)

        # Second convolution block
        self.fc4 = nn.Conv2d(16, 24, (3, 3), (1, 1), (1, 1))
        self.fc5 = nn.Conv2d(24, 32, (3, 3), (1, 1), (1, 1))
        self.fc6 = nn.Conv2d(16, 32, (1, 1), (1, 1))  # Original stride (2, 2)

        self.batchNorm2d_24 = nn.BatchNorm2d(24)
        self.batchNorm2d_32 = nn.BatchNorm2d(32)
        self.batchNorm2d_32_2 = nn.BatchNorm2d(32)

        # Third convolution block
        self.fc7 = nn.Conv2d(32, 40, (3, 3), (1, 1), (1, 1))
        self.fc8 = nn.Conv2d(40, 48, (3, 3), (1, 1), (1, 1))
        self.fc9 = nn.Conv2d(32, 48, (1, 1), (1, 1))  # Original stride (2, 2)

        self.batchNorm2d_40 = nn.BatchNorm2d(40)
        self.batchNorm2d_48 = nn.BatchNorm2d(48)
        self.batchNorm2d_48_2 = nn.BatchNorm2d(48)

        # Fourth convolution block
        self.fc10 = nn.Conv2d(48, 56, (3, 3), (1, 1), (1, 1))
        self.fc11 = nn.Conv2d(56, 64, (3, 3), (1, 1), (1, 1))
        self.fc12 = nn.Conv2d(48, 64, (1, 1), (1, 1))  # Original stride (2, 2)

        self.batchNorm2d_56 = nn.BatchNorm2d(56)
        self.batchNorm2d_64 = nn.BatchNorm2d(64)
        self.batchNorm2d_64_2 = nn.BatchNorm2d(64)

        # Fifth convolution block
        self.fc13 = nn.Conv2d(64, 72, (3, 3), (1, 1), (1, 1))
        self.fc14 = nn.Conv2d(72, 80, (3, 3), (1, 1), (1, 1))
        self.fc15 = nn.Conv2d(64, 80, (1, 1), (1, 1))  # Original stride (2, 2)

        self.batchNorm2d_72 = nn.BatchNorm2d(72)
        self.batchNorm2d_80 = nn.BatchNorm2d(80)
        self.batchNorm2d_80_2 = nn.BatchNorm2d(80)

        self.pool2d = nn.MaxPool2d((3, 3), (2, 2))  # Pool after conv. block
        self.dense = nn.Linear(80, 1)  # Fully connected layer

    @staticmethod
    def _conv_block(x, conv1, bn1, conv2, bn2, skip_conv, skip_bn):
        """Runs one stage: two conv/BN/ReLU layers averaged with a shortcut path."""
        out = F.relu(bn1(conv1(x)))  # 1st convolution
        out = F.relu(bn2(conv2(out)))  # 2nd convolution
        # Parallel shortcut on the block input; the two paths are averaged.
        return (out + F.relu(skip_bn(skip_conv(x)))) / 2

    def forward(self, x):
        """Computes one raw score (logit) per sample.

        Parameters
        ----------

        x
            Input batch. Presumably ``(batch, 1, height, width)`` after
            normalization, since ``fc1`` and ``fc3`` take one input channel
            (TODO confirm the normalizer preserves the channel count).

        Returns
        -------

        Tensor of shape ``(batch, 1)`` with pre-sigmoid scores.

        Raises
        ------

        TypeError
            If :py:meth:`set_normalizer` has not been called yet.
        """
        if self.normalizer is None:
            raise TypeError(
                "The normalizer has not been initialized. Make sure to call set_normalizer() after creation of the model."
            )

        x = self.normalizer(x)

        # First convolution block
        x = self._conv_block(
            x,
            self.fc1,
            self.batchNorm2d_4,
            self.fc2,
            self.batchNorm2d_16,
            self.fc3,
            self.batchNorm2d_16_2,
        )
        x = self.pool2d(x)

        # Second convolution block
        x = self._conv_block(
            x,
            self.fc4,
            self.batchNorm2d_24,
            self.fc5,
            self.batchNorm2d_32,
            self.fc6,
            self.batchNorm2d_32_2,
        )
        x = self.pool2d(x)

        # Third convolution block
        x = self._conv_block(
            x,
            self.fc7,
            self.batchNorm2d_40,
            self.fc8,
            self.batchNorm2d_48,
            self.fc9,
            self.batchNorm2d_48_2,
        )
        x = self.pool2d(x)

        # Fourth convolution block
        x = self._conv_block(
            x,
            self.fc10,
            self.batchNorm2d_56,
            self.fc11,
            self.batchNorm2d_64,
            self.fc12,
            self.batchNorm2d_64_2,
        )
        x = self.pool2d(x)

        # Fifth convolution block (no pooling afterwards)
        x = self._conv_block(
            x,
            self.fc13,
            self.batchNorm2d_72,
            self.fc14,
            self.batchNorm2d_80,
            self.fc15,
            self.batchNorm2d_80_2,
        )

        # Global average pooling over all spatial positions -> (batch, 80).
        # flatten() (unlike view()) does not require a contiguous tensor.
        x = torch.mean(x.flatten(start_dim=2), dim=2)

        # Dense layer -> (batch, 1)
        return self.dense(x)

    def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
        """Initializes the normalizer for the current model.

        Must be called before :py:meth:`forward` is used.

        Parameters
        ----------

        dataloader:
            A torch Dataloader from which to compute the mean and std
        """
        from .normalizer import get_znorm_normalizer

        self.normalizer = get_znorm_normalizer(dataloader)

    @staticmethod
    def _as_column(labels):
        """Reshapes 1-D labels to ``(batch, 1)``; other ranks are returned as-is.

        This allows both single and multiclass usage.
        """
        if labels.ndim == 1:
            labels = torch.reshape(labels, (labels.shape[0], 1))
        return labels

    def training_step(self, batch, batch_idx):
        """Computes the training loss for one batch.

        ``batch[0]`` holds the images and ``batch[1]["label"]`` the targets.
        The targets are cast to float64 before being passed to
        ``hparams.criterion``.

        Returns
        -------

        ``{"loss": training_loss}``
        """
        images = batch[0]
        labels = self._as_column(batch[1]["label"])

        # Forward pass on the network
        outputs = self(images)

        # Manually move criterion to selected device, since not part of the model.
        self.hparams.criterion = self.hparams.criterion.to(self.device)
        training_loss = self.hparams.criterion(outputs, labels.double())

        return {"loss": training_loss}

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """Computes the validation loss for one batch.

        Uses ``hparams.criterion_valid``, with the same batch layout as
        :py:meth:`training_step`.

        Returns
        -------

        ``{"validation_loss": loss}`` for the first dataloader, and
        ``{"extra_validation_loss_<idx>": loss}`` for any other dataloader.
        """
        images = batch[0]
        labels = self._as_column(batch[1]["label"])

        # data forwarding on the existing network
        outputs = self(images)

        # Manually move criterion to selected device, since not part of the model.
        self.hparams.criterion_valid = self.hparams.criterion_valid.to(
            self.device
        )
        validation_loss = self.hparams.criterion_valid(outputs, labels.double())

        if dataloader_idx == 0:
            return {"validation_loss": validation_loss}
        return {f"extra_validation_loss_{dataloader_idx}": validation_loss}

    def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
        """Runs inference on one batch and converts the outputs to probabilities.

        ``grad_cams`` is accepted for interface compatibility but is not used.

        Returns
        -------

        Tuple ``(name, probabilities, labels)``. ``probabilities`` is the
        flattened sigmoid of the model output and ``labels`` is the flattened
        ``batch[1]["label"]``.
        """
        images = batch[0]
        labels = batch[1]["label"]
        names = batch[1]["names"]

        outputs = self(images)

        # Necessary check for architectures (e.g. HED) that return several
        # outputs for loss calculation: keep only the last one. This has to
        # run before the sigmoid, which cannot be applied to a list.
        if isinstance(outputs, list):
            outputs = outputs[-1]

        probabilities = torch.sigmoid(outputs)

        # NOTE(review): only the first sample name is returned, while
        # probabilities/labels cover the whole batch. This presumably assumes
        # batch size 1 at prediction time; confirm.
        return (
            names[0],
            torch.flatten(probabilities),
            torch.flatten(labels),
        )

    # TODO: implement on_predict_epoch_end() to cache the predictions from
    # predict_step() per dataloader, reorder them by key, and clear the cache.

    def configure_optimizers(self):
        """Instantiates the optimizer named by ``hparams.optimizer``.

        The class is looked up in :py:mod:`torch.optim` and built over all
        model parameters, with ``hparams.optimizer_configs`` as keyword
        arguments.
        """
        optimizer_class = getattr(torch.optim, self.hparams.optimizer)
        return optimizer_class(self.parameters(), **self.hparams.optimizer_configs)