From 21937fb23c7c0beeffb27b61347827e700f1a6f6 Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Mon, 29 Apr 2024 10:26:07 +0200 Subject: [PATCH] [model] Create base Model class --- src/mednet/models/model.py | 143 +++++++++++++++++++++++++++++++++++++ 1 file changed, 143 insertions(+) create mode 100644 src/mednet/models/model.py diff --git a/src/mednet/models/model.py b/src/mednet/models/model.py new file mode 100644 index 00000000..50e314bb --- /dev/null +++ b/src/mednet/models/model.py @@ -0,0 +1,143 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import logging +import typing + +import lightning.pytorch as pl +import torch +import torch.nn +import torch.optim.optimizer +import torch.utils.data +import torchvision.transforms + +from ..data.typing import TransformSequence +from .typing import Checkpoint + +logger = logging.getLogger(__name__) + + +class Model(pl.LightningModule): + """Base class for models. + + Parameters + ---------- + train_loss + The loss to be used during the training. + + .. warning:: + + The loss should be set to always return batch averages (as opposed + to the batch sum), as our logging system expects it so. + validation_loss + The loss to be used for validation (may be different from the training + loss). If extra-validation sets are provided, the same loss will be + used throughout. + + .. warning:: + + The loss should be set to always return batch averages (as opposed + to the batch sum), as our logging system expects it so. + optimizer_type + The type of optimizer to use for training. + optimizer_arguments + Arguments to the optimizer after ``params``. + augmentation_transforms + An optional sequence of torch modules containing transforms to be + applied on the input **before** it is fed into the network. + num_classes + Number of outputs (classes) for this model. + """ + + def __init__( + self, + train_loss: torch.nn.Module = torch.nn.BCEWithLogitsLoss(), + validation_loss: torch.nn.Module | None = None, + optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam, + optimizer_arguments: dict[str, typing.Any] = {}, + augmentation_transforms: TransformSequence = [], + num_classes: int = 1, + ): + super().__init__() + + self.name = "model" + self.num_classes = num_classes + + self.model_transforms: TransformSequence = [] + + self._train_loss = train_loss + self._validation_loss = ( + validation_loss if validation_loss is not None else train_loss + ) + self._optimizer_type = optimizer_type + self._optimizer_arguments = optimizer_arguments + + self._augmentation_transforms = torchvision.transforms.Compose( + augmentation_transforms, + ) + + def forward(self, x): + raise NotImplementedError + + def on_save_checkpoint(self, checkpoint: Checkpoint) -> None: + """Perform actions during checkpoint saving (called by lightning). + + Called by Lightning when saving a checkpoint to give you a chance to + store anything else you might want to save. Use on_load_checkpoint() to + restore what additional data is saved here. + + Parameters + ---------- + checkpoint + The checkpoint to save. + """ + + checkpoint["normalizer"] = self.normalizer + + def on_load_checkpoint(self, checkpoint: Checkpoint) -> None: + """Perform actions during model loading (called by lightning). + + If you saved something with on_save_checkpoint() this is your chance to + restore this. + + Parameters + ---------- + checkpoint + The loaded checkpoint. + """ + + logger.info("Restoring normalizer from checkpoint.") + self.normalizer = checkpoint["normalizer"] + + def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None: + """Initialize the input normalizer for the current model. + + Parameters + ---------- + dataloader + A torch Dataloader from which to compute the mean and std. + """ + + from .normalizer import make_z_normalizer + + logger.info( + f"Uninitialised {self.name} model - " + f"computing z-norm factors from train dataloader.", + ) + self.normalizer = make_z_normalizer(dataloader) + + def training_step(self, batch, _): + raise NotImplementedError + + def validation_step(self, batch, batch_idx, dataloader_idx=0): + raise NotImplementedError + + def predict_step(self, batch, batch_idx, dataloader_idx=0): + raise NotImplementedError + + def configure_optimizers(self): + return self._optimizer_type( + self.parameters(), + **self._optimizer_arguments, + ) -- GitLab