From 5fae86b86ed7adb9687a6971178e9a8eeab9532d Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Mon, 29 Apr 2024 10:57:43 +0200
Subject: [PATCH] [model] Use shared Model base class

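Move the logic duplicated by all three models (storage of the training
and validation losses, optimizer type and arguments, composition of the
augmentation transforms, the normalizer checkpoint save/load hooks, and
configure_optimizers()) into the new Model base class, and derive
Alexnet, Densenet and Pasa from it.  Pasa additionally drops its local
set_normalizer() in favour of the inherited implementation.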
---
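Notes (not part of the commit): the new src/mednet/models/model.py that
these classes now derive from is not included in this excerpt.  Judging
from the duplicated code removed below and from the super().__init__()
calls added to each subclass, the base class presumably looks roughly
like the following sketch; names, signatures and defaults are inferred
from this patch, not copied from the actual file.

    import lightning.pytorch as pl
    import torch.utils.data
    import torchvision.transforms

    from .typing import Checkpoint


    class Model(pl.LightningModule):
        """Shared base for mednet models (sketch inferred from this patch)."""

        def __init__(
            self,
            train_loss,
            validation_loss,
            optimizer_type,
            optimizer_arguments,
            augmentation_transforms,
            num_classes=1,
        ):
            super().__init__()

            # Fall back to the training loss when no validation loss is
            # given, exactly as each subclass did before this patch.
            self._train_loss = train_loss
            self._validation_loss = (
                validation_loss if validation_loss is not None else train_loss
            )
            self._optimizer_type = optimizer_type
            self._optimizer_arguments = optimizer_arguments
            self._augmentation_transforms = torchvision.transforms.Compose(
                augmentation_transforms,
            )
            self.num_classes = num_classes

        def on_save_checkpoint(self, checkpoint: Checkpoint) -> None:
            # Persist the input normalizer alongside the model weights.
            checkpoint["normalizer"] = self.normalizer

        def on_load_checkpoint(self, checkpoint: Checkpoint) -> None:
            # Restore the normalizer saved by on_save_checkpoint().
            self.normalizer = checkpoint["normalizer"]

        def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
            # Compute z-normalization constants from the training data.
            # Moved here from Pasa; Alexnet and Densenet keep their own
            # overrides (their set_normalizer methods are unchanged below).
            from .normalizer import make_z_normalizer

            self.normalizer = make_z_normalizer(dataloader)

        def configure_optimizers(self):
            # Instantiate the optimizer requested at construction time.
            return self._optimizer_type(
                self.parameters(),
                **self._optimizer_arguments,
            )

With this in place, each subclass constructor reduces to forwarding its
arguments to super().__init__(), as the hunks below show.
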
 src/mednet/models/alexnet.py  | 61 +++++----------------------
 src/mednet/models/densenet.py | 61 +++++----------------------
 src/mednet/models/pasa.py     | 78 +++++------------------------------
 3 files changed, 30 insertions(+), 170 deletions(-)

diff --git a/src/mednet/models/alexnet.py b/src/mednet/models/alexnet.py
index b4b9e723..22b98baa 100644
--- a/src/mednet/models/alexnet.py
+++ b/src/mednet/models/alexnet.py
@@ -5,7 +5,6 @@
 import logging
 import typing
 
-import lightning.pytorch as pl
 import torch
 import torch.nn
 import torch.optim.optimizer
@@ -14,14 +13,14 @@ import torchvision.models as models
 import torchvision.transforms
 
 from ..data.typing import TransformSequence
+from .model import Model
 from .separate import separate
 from .transforms import RGB, SquareCenterPad
-from .typing import Checkpoint
 
 logger = logging.getLogger(__name__)
 
 
-class Alexnet(pl.LightningModule):
+class Alexnet(Model):
     """Alexnet module.
 
     Note: only usable with a normalized dataset
@@ -68,7 +67,14 @@ class Alexnet(pl.LightningModule):
         pretrained: bool = False,
         num_classes: int = 1,
     ):
-        super().__init__()
+        super().__init__(
+            train_loss,
+            validation_loss,
+            optimizer_type,
+            optimizer_arguments,
+            augmentation_transforms,
+            num_classes,
+        )
 
         self.name = "alexnet"
         self.num_classes = num_classes
@@ -79,17 +85,6 @@ class Alexnet(pl.LightningModule):
             RGB(),
         ]
 
-        self._train_loss = train_loss
-        self._validation_loss = (
-            validation_loss if validation_loss is not None else train_loss
-        )
-        self._optimizer_type = optimizer_type
-        self._optimizer_arguments = optimizer_arguments
-
-        self._augmentation_transforms = torchvision.transforms.Compose(
-            augmentation_transforms,
-        )
-
         self.pretrained = pretrained
 
         # Load pretrained model
@@ -109,36 +104,6 @@ class Alexnet(pl.LightningModule):
         x = self.normalizer(x)  # type: ignore
         return self.model_ft(x)
 
-    def on_save_checkpoint(self, checkpoint: Checkpoint) -> None:
-        """Perform actions during checkpoint saving (called by lightning).
-
-        Called by Lightning when saving a checkpoint to give you a chance to
-        store anything else you might want to save. Use on_load_checkpoint() to
-        restore what additional data is saved here.
-
-        Parameters
-        ----------
-        checkpoint
-            The checkpoint to save.
-        """
-
-        checkpoint["normalizer"] = self.normalizer
-
-    def on_load_checkpoint(self, checkpoint: Checkpoint) -> None:
-        """Perform actions during model loading (called by lightning).
-
-        If you saved something with on_save_checkpoint() this is your chance to
-        restore this.
-
-        Parameters
-        ----------
-        checkpoint
-            The loaded checkpoint.
-        """
-
-        logger.info("Restoring normalizer from checkpoint.")
-        self.normalizer = checkpoint["normalizer"]
-
     def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
         """Initialize the normalizer for the current model.
 
@@ -208,9 +173,3 @@ class Alexnet(pl.LightningModule):
         outputs = self(batch[0])
         probabilities = torch.sigmoid(outputs)
         return separate((probabilities, batch[1]))
-
-    def configure_optimizers(self):
-        return self._optimizer_type(
-            self.parameters(),
-            **self._optimizer_arguments,
-        )
diff --git a/src/mednet/models/densenet.py b/src/mednet/models/densenet.py
index f7d15441..fcdb9f95 100644
--- a/src/mednet/models/densenet.py
+++ b/src/mednet/models/densenet.py
@@ -5,7 +5,6 @@
 import logging
 import typing
 
-import lightning.pytorch as pl
 import torch
 import torch.nn
 import torch.optim.optimizer
@@ -14,14 +13,14 @@ import torchvision.models as models
 import torchvision.transforms
 
 from ..data.typing import TransformSequence
+from .model import Model
 from .separate import separate
 from .transforms import RGB, SquareCenterPad
-from .typing import Checkpoint
 
 logger = logging.getLogger(__name__)
 
 
-class Densenet(pl.LightningModule):
+class Densenet(Model):
     """Densenet-121 module.
 
     Parameters
@@ -69,7 +68,14 @@ class Densenet(pl.LightningModule):
         dropout: float = 0.1,
         num_classes: int = 1,
     ):
-        super().__init__()
+        super().__init__(
+            train_loss,
+            validation_loss,
+            optimizer_type,
+            optimizer_arguments,
+            augmentation_transforms,
+            num_classes,
+        )
 
         self.name = "densenet-121"
         self.num_classes = num_classes
@@ -80,17 +86,6 @@ class Densenet(pl.LightningModule):
             RGB(),
         ]
 
-        self._train_loss = train_loss
-        self._validation_loss = (
-            validation_loss if validation_loss is not None else train_loss
-        )
-        self._optimizer_type = optimizer_type
-        self._optimizer_arguments = optimizer_arguments
-
-        self._augmentation_transforms = torchvision.transforms.Compose(
-            augmentation_transforms,
-        )
-
         self.pretrained = pretrained
 
         # Load pretrained model
@@ -112,36 +107,6 @@ class Densenet(pl.LightningModule):
         x = self.normalizer(x)  # type: ignore
         return self.model_ft(x)
 
-    def on_save_checkpoint(self, checkpoint: Checkpoint) -> None:
-        """Perform actions during checkpoint saving (called by lightning).
-
-        Called by Lightning when saving a checkpoint to give you a chance to
-        store anything else you might want to save. Use on_load_checkpoint() to
-        restore what additional data is saved here.
-
-        Parameters
-        ----------
-        checkpoint
-            The checkpoint to save.
-        """
-
-        checkpoint["normalizer"] = self.normalizer
-
-    def on_load_checkpoint(self, checkpoint: Checkpoint) -> None:
-        """Perform actions during model loading (called by lightning).
-
-        If you saved something with on_save_checkpoint() this is your chance to
-        restore this.
-
-        Parameters
-        ----------
-        checkpoint
-            The loaded checkpoint.
-        """
-
-        logger.info("Restoring normalizer from checkpoint.")
-        self.normalizer = checkpoint["normalizer"]
-
     def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
         """Initialize the normalizer for the current model.
 
@@ -205,9 +170,3 @@ class Densenet(pl.LightningModule):
         outputs = self(batch[0])
         probabilities = torch.sigmoid(outputs)
         return separate((probabilities, batch[1]))
-
-    def configure_optimizers(self):
-        return self._optimizer_type(
-            self.parameters(),
-            **self._optimizer_arguments,
-        )
diff --git a/src/mednet/models/pasa.py b/src/mednet/models/pasa.py
index 16a71f73..389eac8c 100644
--- a/src/mednet/models/pasa.py
+++ b/src/mednet/models/pasa.py
@@ -5,7 +5,6 @@
 import logging
 import typing
 
-import lightning.pytorch as pl
 import torch
 import torch.nn
 import torch.nn.functional as F  # noqa: N812
@@ -14,14 +13,14 @@ import torch.utils.data
 import torchvision.transforms
 
 from ..data.typing import TransformSequence
+from .model import Model
 from .separate import separate
 from .transforms import Grayscale, SquareCenterPad
-from .typing import Checkpoint
 
 logger = logging.getLogger(__name__)
 
 
-class Pasa(pl.LightningModule):
+class Pasa(Model):
     """Implementation of CNN by Pasa and others.
 
     Simple CNN for classification based on paper by [PASA-2019]_.
@@ -67,7 +66,14 @@ class Pasa(pl.LightningModule):
         augmentation_transforms: TransformSequence = [],
         num_classes: int = 1,
     ):
-        super().__init__()
+        super().__init__(
+            train_loss,
+            validation_loss,
+            optimizer_type,
+            optimizer_arguments,
+            augmentation_transforms,
+            num_classes,
+        )
 
         self.name = "pasa"
         self.num_classes = num_classes
@@ -82,17 +88,6 @@ class Pasa(pl.LightningModule):
             ),
         ]
 
-        self._train_loss = train_loss
-        self._validation_loss = (
-            validation_loss if validation_loss is not None else train_loss
-        )
-        self._optimizer_type = optimizer_type
-        self._optimizer_arguments = optimizer_arguments
-
-        self._augmentation_transforms = torchvision.transforms.Compose(
-            augmentation_transforms,
-        )
-
         # First convolution block
         self.fc1 = torch.nn.Conv2d(1, 4, (3, 3), (2, 2), (1, 1))
         self.fc2 = torch.nn.Conv2d(4, 16, (3, 3), (2, 2), (1, 1))
@@ -213,53 +208,6 @@ class Pasa(pl.LightningModule):
 
         # x = F.log_softmax(x, dim=1) # 0 is batch size
 
-    def on_save_checkpoint(self, checkpoint: Checkpoint) -> None:
-        """Perform actions during checkpoint saving (called by lightning).
-
-        Called by Lightning when saving a checkpoint to give you a chance to
-        store anything else you might want to save. Use on_load_checkpoint() to
-        restore what additional data is saved here.
-
-        Parameters
-        ----------
-        checkpoint
-            The checkpoint to save.
-        """
-
-        checkpoint["normalizer"] = self.normalizer
-
-    def on_load_checkpoint(self, checkpoint: Checkpoint) -> None:
-        """Perform actions during model loading (called by lightning).
-
-        If you saved something with on_save_checkpoint() this is your chance to
-        restore this.
-
-        Parameters
-        ----------
-        checkpoint
-            The loaded checkpoint.
-        """
-
-        logger.info("Restoring normalizer from checkpoint.")
-        self.normalizer = checkpoint["normalizer"]
-
-    def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
-        """Initialize the input normalizer for the current model.
-
-        Parameters
-        ----------
-        dataloader
-            A torch Dataloader from which to compute the mean and std.
-        """
-
-        from .normalizer import make_z_normalizer
-
-        logger.info(
-            f"Uninitialised {self.name} model - "
-            f"computing z-norm factors from train dataloader.",
-        )
-        self.normalizer = make_z_normalizer(dataloader)
-
     def training_step(self, batch, _):
         images = batch[0]
         labels = batch[1]["label"]
@@ -292,9 +240,3 @@ class Pasa(pl.LightningModule):
         outputs = self(batch[0])
         probabilities = torch.sigmoid(outputs)
         return separate((probabilities, batch[1]))
-
-    def configure_optimizers(self):
-        return self._optimizer_type(
-            self.parameters(),
-            **self._optimizer_arguments,
-        )
-- 
GitLab