Skip to content
Snippets Groups Projects
Commit 21937fb2 authored by Daniel CARRON's avatar Daniel CARRON :b:
Browse files

[model] Create base Model class

parent 9d570ed0
No related branches found
No related tags found
1 merge request!38Replace sampler balancing by loss balancing
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import logging
import typing
import lightning.pytorch as pl
import torch
import torch.nn
import torch.optim.optimizer
import torch.utils.data
import torchvision.transforms
from ..data.typing import TransformSequence
from .typing import Checkpoint
logger = logging.getLogger(__name__)
class Model(pl.LightningModule):
"""Base class for models.
Parameters
----------
train_loss
The loss to be used during the training.
.. warning::
The loss should be set to always return batch averages (as opposed
to the batch sum), as our logging system expects it so.
validation_loss
The loss to be used for validation (may be different from the training
loss). If extra-validation sets are provided, the same loss will be
used throughout.
.. warning::
The loss should be set to always return batch averages (as opposed
to the batch sum), as our logging system expects it so.
optimizer_type
The type of optimizer to use for training.
optimizer_arguments
Arguments to the optimizer after ``params``.
augmentation_transforms
An optional sequence of torch modules containing transforms to be
applied on the input **before** it is fed into the network.
num_classes
Number of outputs (classes) for this model.
"""
def __init__(
self,
train_loss: torch.nn.Module = torch.nn.BCEWithLogitsLoss(),
validation_loss: torch.nn.Module | None = None,
optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
optimizer_arguments: dict[str, typing.Any] = {},
augmentation_transforms: TransformSequence = [],
num_classes: int = 1,
):
super().__init__()
self.name = "model"
self.num_classes = num_classes
self.model_transforms: TransformSequence = []
self._train_loss = train_loss
self._validation_loss = (
validation_loss if validation_loss is not None else train_loss
)
self._optimizer_type = optimizer_type
self._optimizer_arguments = optimizer_arguments
self._augmentation_transforms = torchvision.transforms.Compose(
augmentation_transforms,
)
def forward(self, x):
raise NotImplementedError
def on_save_checkpoint(self, checkpoint: Checkpoint) -> None:
"""Perform actions during checkpoint saving (called by lightning).
Called by Lightning when saving a checkpoint to give you a chance to
store anything else you might want to save. Use on_load_checkpoint() to
restore what additional data is saved here.
Parameters
----------
checkpoint
The checkpoint to save.
"""
checkpoint["normalizer"] = self.normalizer
def on_load_checkpoint(self, checkpoint: Checkpoint) -> None:
"""Perform actions during model loading (called by lightning).
If you saved something with on_save_checkpoint() this is your chance to
restore this.
Parameters
----------
checkpoint
The loaded checkpoint.
"""
logger.info("Restoring normalizer from checkpoint.")
self.normalizer = checkpoint["normalizer"]
def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
"""Initialize the input normalizer for the current model.
Parameters
----------
dataloader
A torch Dataloader from which to compute the mean and std.
"""
from .normalizer import make_z_normalizer
logger.info(
f"Uninitialised {self.name} model - "
f"computing z-norm factors from train dataloader.",
)
self.normalizer = make_z_normalizer(dataloader)
def training_step(self, batch, _):
raise NotImplementedError
def validation_step(self, batch, batch_idx, dataloader_idx=0):
raise NotImplementedError
def predict_step(self, batch, batch_idx, dataloader_idx=0):
raise NotImplementedError
def configure_optimizers(self):
return self._optimizer_type(
self.parameters(),
**self._optimizer_arguments,
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment