From 7eaac22c67c4063c5f2c1d633ecdf0598b2cb956 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Tue, 11 Jul 2023 10:58:26 +0200
Subject: [PATCH] Update models: explicit loss/optimizer arguments, per-class loss balancing

---
 src/ptbench/data/datamodule.py |   4 +-
 src/ptbench/models/alexnet.py  | 117 ++++++++++++++++++++++-----------
 src/ptbench/models/densenet.py | 108 +++++++++++++++++++-----------
 src/ptbench/models/pasa.py     |  46 ++++++++++++-
 4 files changed, 194 insertions(+), 81 deletions(-)

diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py
index 1ae9531e..abcf11d7 100644
--- a/src/ptbench/data/datamodule.py
+++ b/src/ptbench/data/datamodule.py
@@ -229,11 +229,11 @@ def _make_balanced_random_sampler(
     1. The probability of picking a sample from any target is the same (0.5 in
        this case).  To verify this, notice that the probability of picking a
        sample with ``target=0`` is :math:`1/4 x 1 + 1/12 x 3 = 0.5`.
-    2. The probabiility of picking a sample with ``target=0`` from Dataset 2 is
+    2. The probability of picking a sample with ``target=0`` from Dataset 2 is
        3 times higher than those from Dataset 1.  As there are 3 times less
        samples in Dataset 2 with ``target=0``, this makes choosing samples from
        Dataset 1 proportionally less likely.
-    3. The probabiility of picking a sample with ``target=1`` from Dataset 2 is
+    3. The probability of picking a sample with ``target=1`` from Dataset 2 is
        3 times lower than those from Dataset 1.  As there are 3 times less
        samples in Dataset 1 with ``target=1``, this makes choosing samples from
        Dataset 2 proportionally less likely.
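
Note: a quick numeric check of the probabilities discussed above, as a
standalone sketch (the sample counts are the hypothetical ones from the
docstring's example: Dataset 1 has 3 samples with ``target=0`` and 1 with
``target=1``, Dataset 2 the mirror image):

    import torch
    import torch.utils.data

    # Per-sample picking probabilities: 1/12 for each majority-class
    # sample of a dataset, 1/4 for each minority-class sample.
    weights = torch.tensor(
        [1 / 12, 1 / 12, 1 / 12, 1 / 4]  # dataset 1: targets 0, 0, 0, 1
        + [1 / 4, 1 / 12, 1 / 12, 1 / 12]  # dataset 2: targets 0, 1, 1, 1
    )
    targets = torch.tensor([0, 0, 0, 1, 0, 1, 1, 1])

    # Point 1 above: both targets are equally likely (0.5 each).
    assert torch.isclose(weights[targets == 0].sum(), torch.tensor(0.5))

    sampler = torch.utils.data.WeightedRandomSampler(
        weights, num_samples=len(weights), replacement=True
    )
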
diff --git a/src/ptbench/models/alexnet.py b/src/ptbench/models/alexnet.py
index e8643b46..a878a076 100644
--- a/src/ptbench/models/alexnet.py
+++ b/src/ptbench/models/alexnet.py
@@ -3,14 +3,18 @@
 # SPDX-License-Identifier: GPL-3.0-or-later
 
 import logging
+import typing
 
 import lightning.pytorch as pl
-import torch.nn as nn
+import torch
+import torch.nn
+import torch.nn.functional as F
+import torch.optim.optimizer
 import torch.utils.data
 import torchvision.models as models
 import torchvision.transforms
 
-from ..data.typing import DataLoader
+from ..data.typing import DataLoader, TransformSequence
 
 logger = logging.getLogger(__name__)
 
@@ -19,32 +23,66 @@ class Alexnet(pl.LightningModule):
     """Alexnet module.
 
     Note: only usable with a normalized dataset
+
+    Parameters
+    ----------
+
+    train_loss
+        The loss to be used during the training.
+
+        .. warning::
+
+           The loss should be set to always return batch averages (as opposed
+           to the batch sum), as our logging system expects it so.
+
+    validation_loss
+        The loss to be used for validation (may be different from the training
+        loss).  If extra-validation sets are provided, the same loss will be
+        used throughout.
+
+        .. warning::
+
+           The loss should be set to always return batch averages (as opposed
+           to the batch sum), as our logging system expects it so.
+
+    optimizer_type
+        The type of optimizer to use for training.
+
+    optimizer_arguments
+        Arguments to the optimizer after ``params``.
+
+    augmentation_transforms
+        An optional sequence of torch modules containing transforms to be
+        applied on the input **before** it is fed into the network.
+
+    pretrained
+        If set to True, loads pretrained model weights during initialization; otherwise, trains a new model.
     """
 
     def __init__(
         self,
-        criterion=None,
-        criterion_valid=None,
-        optimizer=None,
-        optimizer_configs=None,
-        pretrained=False,
-        augmentation_transforms=[],
+        train_loss: torch.nn.Module,
+        validation_loss: torch.nn.Module | None,
+        optimizer_type: type[torch.optim.Optimizer],
+        optimizer_arguments: dict[str, typing.Any],
+        augmentation_transforms: TransformSequence = [],
+        pretrained: bool = False,
     ):
         super().__init__()
 
         self.name = "alexnet"
 
-        self.augmentation_transforms = torchvision.transforms.Compose(
-            augmentation_transforms
+        self._train_loss = train_loss
+        self._validation_loss = (
+            validation_loss if validation_loss is not None else train_loss
         )
+        self._optimizer_type = optimizer_type
+        self._optimizer_arguments = optimizer_arguments
 
-        self.criterion = criterion
-        self.criterion_valid = criterion_valid
-
-        self.optimizer = optimizer
-        self.optimizer_configs = optimizer_configs
+        self._augmentation_transforms = torchvision.transforms.Compose(
+            augmentation_transforms
+        )
 
-        self.normalizer = None
         self.pretrained = pretrained
 
         # Load pretrained model
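
Note: a minimal usage sketch of the new constructor signature (losses and
an optimizer type are now passed in, instead of pre-built
``criterion``/``optimizer`` instances; the import path follows this diff's
layout, all argument values are hypothetical):

    import torch

    from ptbench.models.alexnet import Alexnet

    model = Alexnet(
        train_loss=torch.nn.BCEWithLogitsLoss(),  # returns batch averages
        validation_loss=None,  # falls back to the training loss
        optimizer_type=torch.optim.Adam,
        optimizer_arguments={"lr": 1e-4},  # forwarded after ``params``
    )
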
@@ -57,11 +95,12 @@ class Alexnet(pl.LightningModule):
         self.model_ft = models.alexnet(weights=weights)
 
         # Adapt output features
-        self.model_ft.classifier[4] = nn.Linear(4096, 512)
-        self.model_ft.classifier[6] = nn.Linear(512, 1)
+        self.model_ft.classifier[4] = torch.nn.Linear(4096, 512)
+        self.model_ft.classifier[6] = torch.nn.Linear(512, 1)
 
     def forward(self, x):
-        x = self.normalizer(x)
+        x = self.normalizer(x)  # type: ignore
+
         x = self.model_ft(x)
 
         return x
@@ -121,25 +160,25 @@ class Alexnet(pl.LightningModule):
         """
         from .loss_weights import get_label_weights
 
-        if isinstance(self.criterion, torch.nn.BCEWithLogitsLoss):
-            logger.info("Reweighting BCEWithLogitsLoss training criterion.")
+        if isinstance(self._train_loss, torch.nn.BCEWithLogitsLoss):
+            logger.info("Reweighting BCEWithLogitsLoss training loss.")
             weights = get_label_weights(train_dataloader)
-            self.criterion = torch.nn.BCEWithLogitsLoss(pos_weight=weights)
+            self._train_loss = torch.nn.BCEWithLogitsLoss(pos_weight=weights)
         else:
             raise RuntimeError(
                 "Training loss is not BCEWithLogitsLoss - dunno how to balance"
             )
 
-        if isinstance(self.criterion_valid, torch.nn.BCEWithLogitsLoss):
-            logger.info("Reweighting BCEWithLogitsLoss validation criterion.")
+        if isinstance(self._validation_loss, torch.nn.BCEWithLogitsLoss):
+            logger.info("Reweighting BCEWithLogitsLoss validation loss.")
             weights = get_label_weights(valid_dataloader)
-            self.criterion_valid = torch.nn.BCEWithLogitsLoss(weights)
+            self._validation_loss = torch.nn.BCEWithLogitsLoss(pos_weight=weights)
         else:
             raise RuntimeError(
                 "Validation loss is not BCEWithLogitsLoss - dunno how to balance"
             )
 
-    def training_step(self, batch, batch_idx):
+    def training_step(self, batch, _):
         images = batch[0]
         labels = batch[1]["label"]
 
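
Note: for context on the ``pos_weight`` reweighting above, a
self-contained sketch (the ratio is hypothetical; ``get_label_weights`` is
assumed to return such a positive-class weight tensor):

    import torch

    # pos_weight = n_negatives / n_positives makes each positive sample
    # count proportionally more, e.g. for 90 negatives and 10 positives:
    loss = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([9.0]))

    logits = torch.zeros(4, 1)  # stand-in network outputs
    labels = torch.tensor([[1.0], [0.0], [1.0], [0.0]])
    print(loss(logits, labels))  # a batch average (reduction="mean")
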
@@ -150,15 +189,13 @@ class Alexnet(pl.LightningModule):
 
         # Forward pass on the network
         augmented_images = [
-            self.augmentation_transforms(img).to(self.device) for img in images
+            self._augmentation_transforms(img).to(self.device) for img in images
         ]
         # Combine list of augmented images back into a tensor
         augmented_images = torch.cat(augmented_images, 0).view(images.shape)
         outputs = self(augmented_images)
 
-        training_loss = self.criterion(outputs, labels.float())
-
-        return {"loss": training_loss}
+        return self._train_loss(outputs, labels.float())
 
     def validation_step(self, batch, batch_idx, dataloader_idx=0):
         images = batch[0]
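
Note: ``training_step`` applies the composed augmentations one image at a
time and stitches the results back into a batch tensor; the same pattern
in isolation, with hypothetical shapes:

    import torch
    import torchvision.transforms

    transforms = torchvision.transforms.Compose(
        [torchvision.transforms.RandomHorizontalFlip(p=0.5)]
    )

    images = torch.rand(8, 1, 512, 512)  # hypothetical batch
    augmented = [transforms(img) for img in images]  # per-image pass
    batch = torch.cat(augmented, 0).view(images.shape)  # re-batch
    assert batch.shape == images.shape
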
@@ -172,23 +209,23 @@ class Alexnet(pl.LightningModule):
         # data forwarding on the existing network
         outputs = self(images)
 
-        validation_loss = self.criterion_valid(outputs, labels.float())
-
-        if dataloader_idx == 0:
-            return {"validation_loss": validation_loss}
-        else:
-            return {f"extra_validation_loss_{dataloader_idx}": validation_loss}
+        return self._validation_loss(outputs, labels.float())
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
         images = batch[0]
         labels = batch[1]["label"]
-        names = batch[1]["name"]
+        names = batch[1]["names"]
 
         outputs = self(images)
         probabilities = torch.sigmoid(outputs)
 
-        return names[0], torch.flatten(probabilities), torch.flatten(labels)
+        return (
+            names[0],
+            torch.flatten(probabilities),
+            torch.flatten(labels),
+        )
 
     def configure_optimizers(self):
-        optimizer = self.optimizer(self.parameters(), **self.optimizer_configs)
-        return optimizer
+        return self._optimizer_type(
+            self.parameters(), **self._optimizer_arguments
+        )
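
Note: ``configure_optimizers`` now builds the optimizer lazily from the
stored type and keyword arguments; the equivalent pattern in isolation
(stand-in parameters, hypothetical settings):

    import torch

    optimizer_type = torch.optim.SGD
    optimizer_arguments = {"lr": 0.01, "momentum": 0.9}

    # instantiated against the model parameters only when needed
    params = [torch.nn.Parameter(torch.zeros(3))]
    optimizer = optimizer_type(params, **optimizer_arguments)
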
diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py
index 25d1d8ff..8eba3b53 100644
--- a/src/ptbench/models/densenet.py
+++ b/src/ptbench/models/densenet.py
@@ -3,15 +3,18 @@
 # SPDX-License-Identifier: GPL-3.0-or-later
 
 import logging
+import typing
 
 import lightning.pytorch as pl
 import torch
 import torch.nn
+import torch.nn.functional as F
+import torch.optim.optimizer
 import torch.utils.data
 import torchvision.models as models
 import torchvision.transforms
 
-from ..data.typing import DataLoader
+from ..data.typing import DataLoader, TransformSequence
 
 logger = logging.getLogger(__name__)
 
@@ -22,32 +25,61 @@ class Densenet(pl.LightningModule):
     Parameters
     ----------
 
-    criterion
-        A dictionary containing the criteria for the
+    train_loss
+        The loss to be used during the training.
+
+        .. warning::
+
+           The loss should be set to always return batch averages (as opposed
+           to the batch sum), as our logging system expects it so.
+
+    validation_loss
+        The loss to be used for validation (may be different from the training
+        loss).  If extra-validation sets are provided, the same loss will be
+        used throughout.
+
+        .. warning::
+
+           The loss should be set to always return batch averages (as opposed
+           to the batch sum), as our logging system expects it so.
+
+    optimizer_type
+        The type of optimizer to use for training.
+
+    optimizer_arguments
+        Arguments to the optimizer after ``params``.
+
+    augmentation_transforms
+        An optional sequence of torch modules containing transforms to be
+        applied on the input **before** it is fed into the network.
+
+    pretrained
+        If set to True, loads pretrained model weights during initialization; otherwise, trains a new model.
     """
 
     def __init__(
         self,
-        criterion=None,
-        criterion_valid=None,
-        optimizer=None,
-        optimizer_configs=None,
-        pretrained=False,
-        augmentation_transforms=[],
+        train_loss: torch.nn.Module,
+        validation_loss: torch.nn.Module | None,
+        optimizer_type: type[torch.optim.Optimizer],
+        optimizer_arguments: dict[str, typing.Any],
+        augmentation_transforms: TransformSequence = [],
+        pretrained: bool = False,
     ):
         super().__init__()
 
         self.name = "densenet-121"
 
-        self.augmentation_transforms = torchvision.transforms.Compose(
-            augmentation_transforms
+        self._train_loss = train_loss
+        self._validation_loss = (
+            validation_loss if validation_loss is not None else train_loss
         )
+        self._optimizer_type = optimizer_type
+        self._optimizer_arguments = optimizer_arguments
 
-        self.criterion = criterion
-        self.criterion_valid = criterion_valid
-
-        self.optimizer = optimizer
-        self.optimizer_configs = optimizer_configs
+        self._augmentation_transforms = torchvision.transforms.Compose(
+            augmentation_transforms
+        )
 
         self.pretrained = pretrained
 
@@ -66,7 +98,9 @@ class Densenet(pl.LightningModule):
         )
 
     def forward(self, x):
-        x = self.normalizer(x)
+
+        x = self.normalizer(x)  # type: ignore
+
         x = self.model_ft(x)
 
         return x
@@ -128,25 +162,25 @@ class Densenet(pl.LightningModule):
         """
         from .loss_weights import get_label_weights
 
-        if isinstance(self.criterion, torch.nn.BCEWithLogitsLoss):
-            logger.info("Reweighting BCEWithLogitsLoss training criterion.")
+        if isinstance(self._train_loss, torch.nn.BCEWithLogitsLoss):
+            logger.info("Reweighting BCEWithLogitsLoss training loss.")
             weights = get_label_weights(train_dataloader)
-            self.criterion = torch.nn.BCEWithLogitsLoss(pos_weight=weights)
+            self._train_loss = torch.nn.BCEWithLogitsLoss(pos_weight=weights)
         else:
             raise RuntimeError(
                 "Training loss is not BCEWithLogitsLoss - dunno how to balance"
             )
 
-        if isinstance(self.criterion_valid, torch.nn.BCEWithLogitsLoss):
-            logger.info("Reweighting BCEWithLogitsLoss validation criterion.")
+        if isinstance(self._validation_loss, torch.nn.BCEWithLogitsLoss):
+            logger.info("Reweighting BCEWithLogitsLoss validation loss.")
             weights = get_label_weights(valid_dataloader)
-            self.criterion_valid = torch.nn.BCEWithLogitsLoss(weights)
+            self._validation_loss = torch.nn.BCEWithLogitsLoss(pos_weight=weights)
         else:
             raise RuntimeError(
                 "Validation loss is not BCEWithLogitsLoss - dunno how to balance"
             )
 
-    def training_step(self, batch, batch_idx):
+    def training_step(self, batch, _):
         images = batch[0]
         labels = batch[1]["label"]
 
@@ -157,15 +191,13 @@ class Densenet(pl.LightningModule):
 
         # Forward pass on the network
         augmented_images = [
-            self.augmentation_transforms(img).to(self.device) for img in images
+            self._augmentation_transforms(img).to(self.device) for img in images
         ]
         # Combine list of augmented images back into a tensor
         augmented_images = torch.cat(augmented_images, 0).view(images.shape)
         outputs = self(augmented_images)
 
-        training_loss = self.criterion(outputs, labels.float())
-
-        return {"loss": training_loss}
+        return self._train_loss(outputs, labels.float())
 
     def validation_step(self, batch, batch_idx, dataloader_idx=0):
         images = batch[0]
@@ -179,23 +211,23 @@ class Densenet(pl.LightningModule):
         # data forwarding on the existing network
         outputs = self(images)
 
-        validation_loss = self.criterion_valid(outputs, labels.float())
-
-        if dataloader_idx == 0:
-            return {"validation_loss": validation_loss}
-        else:
-            return {f"extra_validation_loss_{dataloader_idx}": validation_loss}
+        return self._validation_loss(outputs, labels.float())
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
         images = batch[0]
         labels = batch[1]["label"]
-        names = batch[1]["name"]
+        names = batch[1]["names"]
 
         outputs = self(images)
         probabilities = torch.sigmoid(outputs)
 
-        return names[0], torch.flatten(probabilities), torch.flatten(labels)
+        return (
+            names[0],
+            torch.flatten(probabilities),
+            torch.flatten(labels),
+        )
 
     def configure_optimizers(self):
-        optimizer = self.optimizer(self.parameters(), **self.optimizer_configs)
-        return optimizer
+        return self._optimizer_type(
+            self.parameters(), **self._optimizer_arguments
+        )
diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py
index 34e5c67f..20bbb0dd 100644
--- a/src/ptbench/models/pasa.py
+++ b/src/ptbench/models/pasa.py
@@ -13,7 +13,7 @@ import torch.optim.optimizer
 import torch.utils.data
 import torchvision.transforms
 
-from ..data.typing import TransformSequence
+from ..data.typing import DataLoader, TransformSequence
 
 logger = logging.getLogger(__name__)
 
@@ -202,6 +202,50 @@ class Pasa(pl.LightningModule):
         )
         self.normalizer = make_z_normalizer(dataloader)
 
+    def balance_losses_by_class(
+        self,
+        train_dataloader: DataLoader,
+        valid_dataloader: dict[str, DataLoader],
+    ):
+        """Reweights loss weights if possible.
+
+        Parameters
+        ----------
+
+        train_dataloader
+            The data loader to use for training.
+
+        valid_dataloader
+            The data loaders to use for each of the validation sets.
+
+
+        Raises
+        ------
+
+        RuntimeError
+            If train or validation losses are not of type
+            :py:class:`torch.nn.BCEWithLogitsLoss`.
+        """
+        from .loss_weights import get_label_weights
+
+        if isinstance(self._train_loss, torch.nn.BCEWithLogitsLoss):
+            logger.info("Reweighting BCEWithLogitsLoss training loss.")
+            weights = get_label_weights(train_dataloader)
+            self._train_loss = torch.nn.BCEWithLogitsLoss(pos_weight=weights)
+        else:
+            raise RuntimeError(
+                "Training loss is not BCEWithLogitsLoss - dunno how to balance"
+            )
+
+        if isinstance(self._validation_loss, torch.nn.BCEWithLogitsLoss):
+            logger.info("Reweighting BCEWithLogitsLoss validation loss.")
+            weights = get_label_weights(valid_dataloader)
+            self._validation_loss = torch.nn.BCEWithLogitsLoss(pos_weight=weights)
+        else:
+            raise RuntimeError(
+                "Validation loss is not BCEWithLogitsLoss - dunno how to balance"
+            )
+
     def training_step(self, batch, _):
         images = batch[0]
         labels = batch[1]["label"]
-- 
GitLab