Commit 5fae86b8 authored by Daniel CARRON

[model] Use base model

parent 21937fb2
1 merge request: !38 "Replace sampler balancing by loss balancing"
@@ -5,7 +5,6 @@
import logging
import typing
import lightning.pytorch as pl
import torch
import torch.nn
import torch.optim.optimizer
@@ -14,14 +13,14 @@ import torchvision.models as models
import torchvision.transforms
from ..data.typing import TransformSequence
from .model import Model
from .separate import separate
from .transforms import RGB, SquareCenterPad
from .typing import Checkpoint
logger = logging.getLogger(__name__)
class Alexnet(pl.LightningModule):
class Alexnet(Model):
"""Alexnet module.
Note: only usable with a normalized dataset
@@ -68,7 +67,14 @@ class Alexnet(pl.LightningModule):
pretrained: bool = False,
num_classes: int = 1,
):
super().__init__()
super().__init__(
train_loss,
validation_loss,
optimizer_type,
optimizer_arguments,
augmentation_transforms,
num_classes,
)
self.name = "alexnet"
self.num_classes = num_classes
@@ -79,17 +85,6 @@ class Alexnet(pl.LightningModule):
RGB(),
]
self._train_loss = train_loss
self._validation_loss = (
validation_loss if validation_loss is not None else train_loss
)
self._optimizer_type = optimizer_type
self._optimizer_arguments = optimizer_arguments
self._augmentation_transforms = torchvision.transforms.Compose(
augmentation_transforms,
)
self.pretrained = pretrained
# Load pretrained model
@@ -109,36 +104,6 @@ class Alexnet(pl.LightningModule):
x = self.normalizer(x) # type: ignore
return self.model_ft(x)
def on_save_checkpoint(self, checkpoint: Checkpoint) -> None:
"""Perform actions during checkpoint saving (called by lightning).
Called by Lightning when saving a checkpoint to give you a chance to
store anything else you might want to save. Use on_load_checkpoint() to
restore what additional data is saved here.
Parameters
----------
checkpoint
The checkpoint to save.
"""
checkpoint["normalizer"] = self.normalizer
def on_load_checkpoint(self, checkpoint: Checkpoint) -> None:
"""Perform actions during model loading (called by lightning).
If you saved something with on_save_checkpoint() this is your chance to
restore this.
Parameters
----------
checkpoint
The loaded checkpoint.
"""
logger.info("Restoring normalizer from checkpoint.")
self.normalizer = checkpoint["normalizer"]
def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
"""Initialize the normalizer for the current model.
@@ -208,9 +173,3 @@ class Alexnet(pl.LightningModule):
outputs = self(batch[0])
probabilities = torch.sigmoid(outputs)
return separate((probabilities, batch[1]))
def configure_optimizers(self):
return self._optimizer_type(
self.parameters(),
**self._optimizer_arguments,
)
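
The arguments forwarded to super().__init__() above, together with the attribute assignments deleted from Alexnet, indicate what the shared Model base class (imported from .model but not shown in this commit) now takes over. Below is a minimal sketch of that constructor, assuming the base class simply stores what the subclasses previously stored themselves; parameter names, defaults, and types are inferred from this diff, not copied from the real base class.

    # Sketch of the assumed Model base class constructor (not part of this diff).
    import typing

    import lightning.pytorch as pl
    import torch.nn
    import torch.optim
    import torchvision.transforms


    class Model(pl.LightningModule):
        """Assumed shared base for Alexnet, Densenet and Pasa (sketch only)."""

        def __init__(
            self,
            train_loss: torch.nn.Module,
            validation_loss: torch.nn.Module | None,
            optimizer_type: type[torch.optim.Optimizer],
            optimizer_arguments: dict[str, typing.Any],
            augmentation_transforms,  # sequence of torchvision transforms
            num_classes: int = 1,
        ):
            super().__init__()

            self._train_loss = train_loss
            # Fall back to the training loss when no validation loss is given,
            # exactly as the removed per-subclass code did.
            self._validation_loss = (
                validation_loss if validation_loss is not None else train_loss
            )
            self._optimizer_type = optimizer_type
            self._optimizer_arguments = optimizer_arguments
            self._augmentation_transforms = torchvision.transforms.Compose(
                augmentation_transforms,
            )
            self.num_classes = num_classes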
@@ -5,7 +5,6 @@
import logging
import typing
import lightning.pytorch as pl
import torch
import torch.nn
import torch.optim.optimizer
@@ -14,14 +13,14 @@ import torchvision.models as models
import torchvision.transforms
from ..data.typing import TransformSequence
from .model import Model
from .separate import separate
from .transforms import RGB, SquareCenterPad
from .typing import Checkpoint
logger = logging.getLogger(__name__)
class Densenet(pl.LightningModule):
class Densenet(Model):
"""Densenet-121 module.
Parameters
@@ -69,7 +68,14 @@ class Densenet(pl.LightningModule):
dropout: float = 0.1,
num_classes: int = 1,
):
super().__init__()
super().__init__(
train_loss,
validation_loss,
optimizer_type,
optimizer_arguments,
augmentation_transforms,
num_classes,
)
self.name = "densenet-121"
self.num_classes = num_classes
@@ -80,17 +86,6 @@ class Densenet(pl.LightningModule):
RGB(),
]
self._train_loss = train_loss
self._validation_loss = (
validation_loss if validation_loss is not None else train_loss
)
self._optimizer_type = optimizer_type
self._optimizer_arguments = optimizer_arguments
self._augmentation_transforms = torchvision.transforms.Compose(
augmentation_transforms,
)
self.pretrained = pretrained
# Load pretrained model
@@ -112,36 +107,6 @@ class Densenet(pl.LightningModule):
x = self.normalizer(x) # type: ignore
return self.model_ft(x)
def on_save_checkpoint(self, checkpoint: Checkpoint) -> None:
"""Perform actions during checkpoint saving (called by lightning).
Called by Lightning when saving a checkpoint to give you a chance to
store anything else you might want to save. Use on_load_checkpoint() to
restore what additional data is saved here.
Parameters
----------
checkpoint
The checkpoint to save.
"""
checkpoint["normalizer"] = self.normalizer
def on_load_checkpoint(self, checkpoint: Checkpoint) -> None:
"""Perform actions during model loading (called by lightning).
If you saved something with on_save_checkpoint() this is your chance to
restore this.
Parameters
----------
checkpoint
The loaded checkpoint.
"""
logger.info("Restoring normalizer from checkpoint.")
self.normalizer = checkpoint["normalizer"]
def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
"""Initialize the normalizer for the current model.
@@ -205,9 +170,3 @@ class Densenet(pl.LightningModule):
outputs = self(batch[0])
probabilities = torch.sigmoid(outputs)
return separate((probabilities, batch[1]))
def configure_optimizers(self):
return self._optimizer_type(
self.parameters(),
**self._optimizer_arguments,
)
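
The same checkpoint hooks and configure_optimizers() method are deleted from Densenet here (and from Pasa below) without replacement, so they presumably now live once in the base class. The following sketch continues the assumed Model class from above; the method bodies mirror the removed code, but their exact form in the real base class is an assumption.

    # Sketch continued: shared hooks assumed to have moved into the Model base.
    import logging

    import lightning.pytorch as pl

    logger = logging.getLogger(__name__)


    class Model(pl.LightningModule):  # same assumed base class as sketched above
        def on_save_checkpoint(self, checkpoint) -> None:
            # Persist the fitted input normalizer alongside the model weights.
            checkpoint["normalizer"] = self.normalizer

        def on_load_checkpoint(self, checkpoint) -> None:
            # Restore whatever on_save_checkpoint() stored.
            logger.info("Restoring normalizer from checkpoint.")
            self.normalizer = checkpoint["normalizer"]

        def configure_optimizers(self):
            # Build the optimizer from the type/arguments recorded in __init__().
            return self._optimizer_type(
                self.parameters(),
                **self._optimizer_arguments,
            )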
@@ -5,7 +5,6 @@
import logging
import typing
import lightning.pytorch as pl
import torch
import torch.nn
import torch.nn.functional as F # noqa: N812
@@ -14,14 +13,14 @@ import torch.utils.data
import torchvision.transforms
from ..data.typing import TransformSequence
from .model import Model
from .separate import separate
from .transforms import Grayscale, SquareCenterPad
from .typing import Checkpoint
logger = logging.getLogger(__name__)
class Pasa(pl.LightningModule):
class Pasa(Model):
"""Implementation of CNN by Pasa and others.
Simple CNN for classification based on paper by [PASA-2019]_.
@@ -67,7 +66,14 @@ class Pasa(pl.LightningModule):
augmentation_transforms: TransformSequence = [],
num_classes: int = 1,
):
super().__init__()
super().__init__(
train_loss,
validation_loss,
optimizer_type,
optimizer_arguments,
augmentation_transforms,
num_classes,
)
self.name = "pasa"
self.num_classes = num_classes
@@ -82,17 +88,6 @@ class Pasa(pl.LightningModule):
),
]
self._train_loss = train_loss
self._validation_loss = (
validation_loss if validation_loss is not None else train_loss
)
self._optimizer_type = optimizer_type
self._optimizer_arguments = optimizer_arguments
self._augmentation_transforms = torchvision.transforms.Compose(
augmentation_transforms,
)
# First convolution block
self.fc1 = torch.nn.Conv2d(1, 4, (3, 3), (2, 2), (1, 1))
self.fc2 = torch.nn.Conv2d(4, 16, (3, 3), (2, 2), (1, 1))
@@ -213,53 +208,6 @@ class Pasa(pl.LightningModule):
# x = F.log_softmax(x, dim=1) # 0 is batch size
def on_save_checkpoint(self, checkpoint: Checkpoint) -> None:
"""Perform actions during checkpoint saving (called by lightning).
Called by Lightning when saving a checkpoint to give you a chance to
store anything else you might want to save. Use on_load_checkpoint() to
restore what additional data is saved here.
Parameters
----------
checkpoint
The checkpoint to save.
"""
checkpoint["normalizer"] = self.normalizer
def on_load_checkpoint(self, checkpoint: Checkpoint) -> None:
"""Perform actions during model loading (called by lightning).
If you saved something with on_save_checkpoint() this is your chance to
restore this.
Parameters
----------
checkpoint
The loaded checkpoint.
"""
logger.info("Restoring normalizer from checkpoint.")
self.normalizer = checkpoint["normalizer"]
def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
"""Initialize the input normalizer for the current model.
Parameters
----------
dataloader
A torch Dataloader from which to compute the mean and std.
"""
from .normalizer import make_z_normalizer
logger.info(
f"Uninitialised {self.name} model - "
f"computing z-norm factors from train dataloader.",
)
self.normalizer = make_z_normalizer(dataloader)
def training_step(self, batch, _):
images = batch[0]
labels = batch[1]["label"]
@@ -292,9 +240,3 @@ class Pasa(pl.LightningModule):
outputs = self(batch[0])
probabilities = torch.sigmoid(outputs)
return separate((probabilities, batch[1]))
def configure_optimizers(self):
return self._optimizer_type(
self.parameters(),
**self._optimizer_arguments,
)
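
With the shared state moved into the base class, instantiating a subclass is unchanged at the call site: the keyword arguments visible in this diff are simply forwarded to Model.__init__(). A hypothetical example follows; the import path, loss, and optimizer values are illustrative assumptions, not taken from this repository.

    # Hypothetical usage of the refactored Pasa model (path/values assumed).
    import torch

    from mednet.models.pasa import Pasa  # assumed module path

    model = Pasa(
        train_loss=torch.nn.BCEWithLogitsLoss(),  # binary task, one output class
        optimizer_type=torch.optim.Adam,
        optimizer_arguments={"lr": 1e-4},
        num_classes=1,
    )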