From 5fae86b86ed7adb9687a6971178e9a8eeab9532d Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Mon, 29 Apr 2024 10:57:43 +0200 Subject: [PATCH] [model] Use base model --- src/mednet/models/alexnet.py | 61 +++++---------------------- src/mednet/models/densenet.py | 61 +++++---------------------- src/mednet/models/pasa.py | 78 +++++------------------------------ 3 files changed, 30 insertions(+), 170 deletions(-) diff --git a/src/mednet/models/alexnet.py b/src/mednet/models/alexnet.py index b4b9e723..22b98baa 100644 --- a/src/mednet/models/alexnet.py +++ b/src/mednet/models/alexnet.py @@ -5,7 +5,6 @@ import logging import typing -import lightning.pytorch as pl import torch import torch.nn import torch.optim.optimizer @@ -14,14 +13,14 @@ import torchvision.models as models import torchvision.transforms from ..data.typing import TransformSequence +from .model import Model from .separate import separate from .transforms import RGB, SquareCenterPad -from .typing import Checkpoint logger = logging.getLogger(__name__) -class Alexnet(pl.LightningModule): +class Alexnet(Model): """Alexnet module. Note: only usable with a normalized dataset @@ -68,7 +67,14 @@ class Alexnet(pl.LightningModule): pretrained: bool = False, num_classes: int = 1, ): - super().__init__() + super().__init__( + train_loss, + validation_loss, + optimizer_type, + optimizer_arguments, + augmentation_transforms, + num_classes, + ) self.name = "alexnet" self.num_classes = num_classes @@ -79,17 +85,6 @@ class Alexnet(pl.LightningModule): RGB(), ] - self._train_loss = train_loss - self._validation_loss = ( - validation_loss if validation_loss is not None else train_loss - ) - self._optimizer_type = optimizer_type - self._optimizer_arguments = optimizer_arguments - - self._augmentation_transforms = torchvision.transforms.Compose( - augmentation_transforms, - ) - self.pretrained = pretrained # Load pretrained model @@ -109,36 +104,6 @@ class Alexnet(pl.LightningModule): x = self.normalizer(x) # type: ignore return self.model_ft(x) - def on_save_checkpoint(self, checkpoint: Checkpoint) -> None: - """Perform actions during checkpoint saving (called by lightning). - - Called by Lightning when saving a checkpoint to give you a chance to - store anything else you might want to save. Use on_load_checkpoint() to - restore what additional data is saved here. - - Parameters - ---------- - checkpoint - The checkpoint to save. - """ - - checkpoint["normalizer"] = self.normalizer - - def on_load_checkpoint(self, checkpoint: Checkpoint) -> None: - """Perform actions during model loading (called by lightning). - - If you saved something with on_save_checkpoint() this is your chance to - restore this. - - Parameters - ---------- - checkpoint - The loaded checkpoint. - """ - - logger.info("Restoring normalizer from checkpoint.") - self.normalizer = checkpoint["normalizer"] - def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None: """Initialize the normalizer for the current model. @@ -208,9 +173,3 @@ class Alexnet(pl.LightningModule): outputs = self(batch[0]) probabilities = torch.sigmoid(outputs) return separate((probabilities, batch[1])) - - def configure_optimizers(self): - return self._optimizer_type( - self.parameters(), - **self._optimizer_arguments, - ) diff --git a/src/mednet/models/densenet.py b/src/mednet/models/densenet.py index f7d15441..fcdb9f95 100644 --- a/src/mednet/models/densenet.py +++ b/src/mednet/models/densenet.py @@ -5,7 +5,6 @@ import logging import typing -import lightning.pytorch as pl import torch import torch.nn import torch.optim.optimizer @@ -14,14 +13,14 @@ import torchvision.models as models import torchvision.transforms from ..data.typing import TransformSequence +from .model import Model from .separate import separate from .transforms import RGB, SquareCenterPad -from .typing import Checkpoint logger = logging.getLogger(__name__) -class Densenet(pl.LightningModule): +class Densenet(Model): """Densenet-121 module. Parameters @@ -69,7 +68,14 @@ class Densenet(pl.LightningModule): dropout: float = 0.1, num_classes: int = 1, ): - super().__init__() + super().__init__( + train_loss, + validation_loss, + optimizer_type, + optimizer_arguments, + augmentation_transforms, + num_classes, + ) self.name = "densenet-121" self.num_classes = num_classes @@ -80,17 +86,6 @@ class Densenet(pl.LightningModule): RGB(), ] - self._train_loss = train_loss - self._validation_loss = ( - validation_loss if validation_loss is not None else train_loss - ) - self._optimizer_type = optimizer_type - self._optimizer_arguments = optimizer_arguments - - self._augmentation_transforms = torchvision.transforms.Compose( - augmentation_transforms, - ) - self.pretrained = pretrained # Load pretrained model @@ -112,36 +107,6 @@ class Densenet(pl.LightningModule): x = self.normalizer(x) # type: ignore return self.model_ft(x) - def on_save_checkpoint(self, checkpoint: Checkpoint) -> None: - """Perform actions during checkpoint saving (called by lightning). - - Called by Lightning when saving a checkpoint to give you a chance to - store anything else you might want to save. Use on_load_checkpoint() to - restore what additional data is saved here. - - Parameters - ---------- - checkpoint - The checkpoint to save. - """ - - checkpoint["normalizer"] = self.normalizer - - def on_load_checkpoint(self, checkpoint: Checkpoint) -> None: - """Perform actions during model loading (called by lightning). - - If you saved something with on_save_checkpoint() this is your chance to - restore this. - - Parameters - ---------- - checkpoint - The loaded checkpoint. - """ - - logger.info("Restoring normalizer from checkpoint.") - self.normalizer = checkpoint["normalizer"] - def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None: """Initialize the normalizer for the current model. @@ -205,9 +170,3 @@ class Densenet(pl.LightningModule): outputs = self(batch[0]) probabilities = torch.sigmoid(outputs) return separate((probabilities, batch[1])) - - def configure_optimizers(self): - return self._optimizer_type( - self.parameters(), - **self._optimizer_arguments, - ) diff --git a/src/mednet/models/pasa.py b/src/mednet/models/pasa.py index 16a71f73..389eac8c 100644 --- a/src/mednet/models/pasa.py +++ b/src/mednet/models/pasa.py @@ -5,7 +5,6 @@ import logging import typing -import lightning.pytorch as pl import torch import torch.nn import torch.nn.functional as F # noqa: N812 @@ -14,14 +13,14 @@ import torch.utils.data import torchvision.transforms from ..data.typing import TransformSequence +from .model import Model from .separate import separate from .transforms import Grayscale, SquareCenterPad -from .typing import Checkpoint logger = logging.getLogger(__name__) -class Pasa(pl.LightningModule): +class Pasa(Model): """Implementation of CNN by Pasa and others. Simple CNN for classification based on paper by [PASA-2019]_. @@ -67,7 +66,14 @@ class Pasa(pl.LightningModule): augmentation_transforms: TransformSequence = [], num_classes: int = 1, ): - super().__init__() + super().__init__( + train_loss, + validation_loss, + optimizer_type, + optimizer_arguments, + augmentation_transforms, + num_classes, + ) self.name = "pasa" self.num_classes = num_classes @@ -82,17 +88,6 @@ class Pasa(pl.LightningModule): ), ] - self._train_loss = train_loss - self._validation_loss = ( - validation_loss if validation_loss is not None else train_loss - ) - self._optimizer_type = optimizer_type - self._optimizer_arguments = optimizer_arguments - - self._augmentation_transforms = torchvision.transforms.Compose( - augmentation_transforms, - ) - # First convolution block self.fc1 = torch.nn.Conv2d(1, 4, (3, 3), (2, 2), (1, 1)) self.fc2 = torch.nn.Conv2d(4, 16, (3, 3), (2, 2), (1, 1)) @@ -213,53 +208,6 @@ class Pasa(pl.LightningModule): # x = F.log_softmax(x, dim=1) # 0 is batch size - def on_save_checkpoint(self, checkpoint: Checkpoint) -> None: - """Perform actions during checkpoint saving (called by lightning). - - Called by Lightning when saving a checkpoint to give you a chance to - store anything else you might want to save. Use on_load_checkpoint() to - restore what additional data is saved here. - - Parameters - ---------- - checkpoint - The checkpoint to save. - """ - - checkpoint["normalizer"] = self.normalizer - - def on_load_checkpoint(self, checkpoint: Checkpoint) -> None: - """Perform actions during model loading (called by lightning). - - If you saved something with on_save_checkpoint() this is your chance to - restore this. - - Parameters - ---------- - checkpoint - The loaded checkpoint. - """ - - logger.info("Restoring normalizer from checkpoint.") - self.normalizer = checkpoint["normalizer"] - - def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None: - """Initialize the input normalizer for the current model. - - Parameters - ---------- - dataloader - A torch Dataloader from which to compute the mean and std. - """ - - from .normalizer import make_z_normalizer - - logger.info( - f"Uninitialised {self.name} model - " - f"computing z-norm factors from train dataloader.", - ) - self.normalizer = make_z_normalizer(dataloader) - def training_step(self, batch, _): images = batch[0] labels = batch[1]["label"] @@ -292,9 +240,3 @@ class Pasa(pl.LightningModule): outputs = self(batch[0]) probabilities = torch.sigmoid(outputs) return separate((probabilities, batch[1])) - - def configure_optimizers(self): - return self._optimizer_type( - self.parameters(), - **self._optimizer_arguments, - ) -- GitLab