From 21937fb23c7c0beeffb27b61347827e700f1a6f6 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Mon, 29 Apr 2024 10:26:07 +0200
Subject: [PATCH] [model] Create base Model class

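This class concentrates the Lightning plumbing shared by all models:
loss and optimizer bookkeeping, augmentation handling, and saving and
restoring the input normalizer through checkpoints.  Concrete
architectures are expected to subclass it, define their layers, and
override forward() and the *_step() hooks.  A minimal, illustrative
sketch of such a subclass follows; the class name, layer sizes, and
batch layout (a dict with "image"/"target" keys, targets shaped like
the logits) are hypothetical and not part of this patch:

    import torch
    import torch.nn

    from mednet.models.model import Model


    class TinyClassifier(Model):
        """Toy fully-connected classifier built on top of Model."""

        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self.name = "tiny-classifier"
            # Hypothetical backbone: flatten a 28x28 input and project
            # it to ``num_classes`` logits.
            self.fc = torch.nn.Linear(28 * 28, self.num_classes)

        def forward(self, x):
            # ``self.normalizer`` is set by set_normalizer() before
            # training starts, or restored from a checkpoint by
            # on_load_checkpoint().
            x = self.normalizer(x)
            return self.fc(x.flatten(start_dim=1))

        def training_step(self, batch, _):
            x, y = batch["image"], batch["target"]
            # Augmentation transforms are applied to the input before
            # the forward pass; targets are assumed to be float tensors
            # shaped like the logits.
            logits = self(self._augmentation_transforms(x))
            return self._train_loss(logits, y.float())

        def validation_step(self, batch, batch_idx, dataloader_idx=0):
            x, y = batch["image"], batch["target"]
            return self._validation_loss(self(x), y.float())

The base class builds the optimizer in configure_optimizers() from the
constructor arguments, so using a custom optimizer only requires, for
example:

    model = TinyClassifier(
        optimizer_type=torch.optim.AdamW,
        optimizer_arguments=dict(lr=1e-4),
        num_classes=1,
    )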
---
 src/mednet/models/model.py | 143 +++++++++++++++++++++++++++++++++++++
 1 file changed, 143 insertions(+)
 create mode 100644 src/mednet/models/model.py

diff --git a/src/mednet/models/model.py b/src/mednet/models/model.py
new file mode 100644
index 00000000..50e314bb
--- /dev/null
+++ b/src/mednet/models/model.py
@@ -0,0 +1,143 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import logging
+import typing
+
+import lightning.pytorch as pl
+import torch
+import torch.nn
+import torch.optim.optimizer
+import torch.utils.data
+import torchvision.transforms
+
+from ..data.typing import TransformSequence
+from .typing import Checkpoint
+
+logger = logging.getLogger(__name__)
+
+
+class Model(pl.LightningModule):
+    """Base class for models.
+
+    Parameters
+    ----------
+    train_loss
+        The loss to be used during training.
+
+        .. warning::
+
+           The loss should be configured to return batch averages (as opposed
+           to batch sums), as this is what our logging system expects.
+    validation_loss
+        The loss to be used for validation (may be different from the
+        training loss).  If extra validation sets are provided, the same loss
+        is used for all of them.
+
+        .. warning::
+
+           The loss should be configured to return batch averages (as opposed
+           to batch sums), as this is what our logging system expects.
+    optimizer_type
+        The type of optimizer to use for training.
+    optimizer_arguments
+        Arguments to the optimizer after ``params``.
+    augmentation_transforms
+        An optional sequence of torch modules containing transforms to be
+        applied to the input **before** it is fed into the network.
+    num_classes
+        Number of outputs (classes) for this model.
+    """
+
+    def __init__(
+        self,
+        train_loss: torch.nn.Module = torch.nn.BCEWithLogitsLoss(),
+        validation_loss: torch.nn.Module | None = None,
+        optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
+        optimizer_arguments: dict[str, typing.Any] | None = None,
+        augmentation_transforms: TransformSequence | None = None,
+        num_classes: int = 1,
+    ):
+        super().__init__()
+
+        self.name = "model"
+        self.num_classes = num_classes
+
+        self.model_transforms: TransformSequence = []
+
+        self._train_loss = train_loss
+        self._validation_loss = (
+            validation_loss if validation_loss is not None else train_loss
+        )
+        self._optimizer_type = optimizer_type
+        self._optimizer_arguments = optimizer_arguments or {}
+
+        self._augmentation_transforms = torchvision.transforms.Compose(
+            augmentation_transforms or [],
+        )
+
+    def forward(self, x):
+        raise NotImplementedError
+
+    def on_save_checkpoint(self, checkpoint: Checkpoint) -> None:
+        """Perform actions during checkpoint saving (called by lightning).
+
+        Called by Lightning when saving a checkpoint to give you a chance to
+        store anything else you might want to save. Use on_load_checkpoint() to
+        restore what additional data is saved here.
+
+        Parameters
+        ----------
+        checkpoint
+            The checkpoint to save.
+        """
+
+        checkpoint["normalizer"] = self.normalizer
+
+    def on_load_checkpoint(self, checkpoint: Checkpoint) -> None:
+        """Perform actions during model loading (called by lightning).
+
+        If you saved something with on_save_checkpoint() this is your chance to
+        restore this.
+
+        Parameters
+        ----------
+        checkpoint
+            The loaded checkpoint.
+        """
+
+        logger.info("Restoring normalizer from checkpoint.")
+        self.normalizer = checkpoint["normalizer"]
+
+    def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
+        """Initialize the input normalizer for the current model.
+
+        Parameters
+        ----------
+        dataloader
+            A torch DataLoader from which to compute the mean and std.
+        """
+
+        from .normalizer import make_z_normalizer
+
+        logger.info(
+            f"Uninitialised {self.name} model - "
+            f"computing z-norm factors from train dataloader.",
+        )
+        self.normalizer = make_z_normalizer(dataloader)
+
+    def training_step(self, batch, _):
+        raise NotImplementedError
+
+    def validation_step(self, batch, batch_idx, dataloader_idx=0):
+        raise NotImplementedError
+
+    def predict_step(self, batch, batch_idx, dataloader_idx=0):
+        raise NotImplementedError
+
+    def configure_optimizers(self):
+        return self._optimizer_type(
+            self.parameters(),
+            **self._optimizer_arguments,
+        )
-- 
GitLab