diff --git a/doc/api.rst b/doc/api.rst index 186f9eeab1aebde1081615f24fabd9971a70777d..f493d80b319435caabdd2e9f409c6bdcf0c2102f 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -45,6 +45,7 @@ CNN and other models implemented. mednet.models.logistic_regression mednet.models.loss_weights mednet.models.mlp + mednet.models.model mednet.models.normalizer mednet.models.separate mednet.models.transforms diff --git a/src/mednet/config/models/alexnet.py b/src/mednet/config/models/alexnet.py index 9703f964a476b53d5aa076242bb0b02cedfe75ff..7f28186750e1e21d1f801cbb7d21d17da5ca2011 100644 --- a/src/mednet/config/models/alexnet.py +++ b/src/mednet/config/models/alexnet.py @@ -15,8 +15,7 @@ from mednet.data.augmentations import ElasticDeformation from mednet.models.alexnet import Alexnet model = Alexnet( - train_loss=BCEWithLogitsLoss(), - validation_loss=BCEWithLogitsLoss(), + loss_type=BCEWithLogitsLoss, optimizer_type=SGD, optimizer_arguments=dict(lr=0.01, momentum=0.1), augmentation_transforms=[ElasticDeformation(p=0.8)], diff --git a/src/mednet/config/models/alexnet_pretrained.py b/src/mednet/config/models/alexnet_pretrained.py index 8887db8f6f006cd2580dabac44202055a5cdacab..a935655555a004cbe3c5b8e2b19f77458e952e40 100644 --- a/src/mednet/config/models/alexnet_pretrained.py +++ b/src/mednet/config/models/alexnet_pretrained.py @@ -17,8 +17,7 @@ from mednet.data.augmentations import ElasticDeformation from mednet.models.alexnet import Alexnet model = Alexnet( - train_loss=BCEWithLogitsLoss(), - validation_loss=BCEWithLogitsLoss(), + loss_type=BCEWithLogitsLoss, optimizer_type=SGD, optimizer_arguments=dict(lr=0.01, momentum=0.1), augmentation_transforms=[ElasticDeformation(p=0.8)], diff --git a/src/mednet/config/models/densenet.py b/src/mednet/config/models/densenet.py index f28dd23cd12c72e5fc6713e706f0e9c05158759c..9ee510ac8df93713b995f857bf5afe2cb68b89a6 100644 --- a/src/mednet/config/models/densenet.py +++ b/src/mednet/config/models/densenet.py @@ -15,8 +15,7 @@ from mednet.data.augmentations import ElasticDeformation from mednet.models.densenet import Densenet model = Densenet( - train_loss=BCEWithLogitsLoss(), - validation_loss=BCEWithLogitsLoss(), + loss_type=BCEWithLogitsLoss, optimizer_type=Adam, optimizer_arguments=dict(lr=0.0001), augmentation_transforms=[ElasticDeformation(p=0.2)], diff --git a/src/mednet/config/models/densenet_pretrained.py b/src/mednet/config/models/densenet_pretrained.py index 274a564601094a8ecb51e67c87f19f1f8197a30a..b7e2efcdfa83e1b70a466dbca0ddca02cf4695dc 100644 --- a/src/mednet/config/models/densenet_pretrained.py +++ b/src/mednet/config/models/densenet_pretrained.py @@ -17,8 +17,7 @@ from mednet.data.augmentations import ElasticDeformation from mednet.models.densenet import Densenet model = Densenet( - train_loss=BCEWithLogitsLoss(), - validation_loss=BCEWithLogitsLoss(), + loss_type=BCEWithLogitsLoss, optimizer_type=Adam, optimizer_arguments=dict(lr=0.0001), augmentation_transforms=[ElasticDeformation(p=0.2)], diff --git a/src/mednet/config/models/densenet_rs.py b/src/mednet/config/models/densenet_rs.py index e7db48850d0e8d2b959b39ee93bae3b78dccfa80..813bb76cf92b3abe105e7095085d7e01de4fbecd 100644 --- a/src/mednet/config/models/densenet_rs.py +++ b/src/mednet/config/models/densenet_rs.py @@ -16,8 +16,7 @@ from mednet.data.augmentations import ElasticDeformation from mednet.models.densenet import Densenet model = Densenet( - train_loss=BCEWithLogitsLoss(), - validation_loss=BCEWithLogitsLoss(), + loss_type=BCEWithLogitsLoss, optimizer_type=Adam, 
optimizer_arguments=dict(lr=0.0001), augmentation_transforms=[ElasticDeformation(p=0.2)], diff --git a/src/mednet/config/models/pasa.py b/src/mednet/config/models/pasa.py index 227b9b426568bebf327c1ac3f206a5fc3f2b44b6..7787d10e32cfad9ece6d42cee8be6bc0bb86124f 100644 --- a/src/mednet/config/models/pasa.py +++ b/src/mednet/config/models/pasa.py @@ -17,8 +17,7 @@ from mednet.data.augmentations import ElasticDeformation from mednet.models.pasa import Pasa model = Pasa( - train_loss=BCEWithLogitsLoss(), - validation_loss=BCEWithLogitsLoss(), + loss_type=BCEWithLogitsLoss, optimizer_type=Adam, optimizer_arguments=dict(lr=8e-5), augmentation_transforms=[ElasticDeformation(p=0.8)], diff --git a/src/mednet/data/datamodule.py b/src/mednet/data/datamodule.py index c73bb27932c52e2061134a7f007256c7f6292161..6c7d759f4a8890f576af327450ac89e144f6fbfa 100644 --- a/src/mednet/data/datamodule.py +++ b/src/mednet/data/datamodule.py @@ -481,10 +481,6 @@ class ConcatDataModule(lightning.LightningDataModule): for CPU memory. Sufficient CPU memory must be available before you set this attribute to ``True``. It is typically useful for relatively small datasets. - balance_sampler_by_class - If set, then modifies the random sampler used during training and - validation to balance sample picking probability, making sample - across classes **and** datasets equitable. batch_size Number of samples in every **training** batch (this parameter affects memory requirements for the network). If the number of samples in the @@ -529,7 +525,6 @@ class ConcatDataModule(lightning.LightningDataModule): database_name: str = "", split_name: str = "", cache_samples: bool = False, - balance_sampler_by_class: bool = False, batch_size: int = 1, batch_chunk_count: int = 1, drop_incomplete_batch: bool = False, @@ -552,7 +547,6 @@ class ConcatDataModule(lightning.LightningDataModule): self.cache_samples = cache_samples self._train_sampler = None - self.balance_sampler_by_class = balance_sampler_by_class self._model_transforms: list[Transform] | None = None @@ -667,40 +661,6 @@ class ConcatDataModule(lightning.LightningDataModule): ) self._datasets = {} - @property - def balance_sampler_by_class(self) -> bool: - """Whether to balance samples across labels/datasets. - - If set, then modifies the random sampler used during training - and validation to balance sample picking probability, making - sample across classes **and** datasets equitable. - - .. warning:: - - This method does **NOT** balance the sampler per dataset, in case - multiple datasets compose the same training set. It only balances - samples acording to their ground-truth (labels). If you'd like to - have samples balanced per dataset, then implement your own data - module inheriting from this one. - - Returns - ------- - bool - True if self._train_sample is set, else False. - """ - return self._train_sampler is not None - - @balance_sampler_by_class.setter - def balance_sampler_by_class(self, value: bool): - if value: - if "train" not in self._datasets: - self._setup_dataset("train") - self._train_sampler = _make_balanced_random_sampler( - self._datasets["train"], - ) - else: - self._train_sampler = None - def set_chunk_size(self, batch_size: int, batch_chunk_count: int) -> None: """Coherently set the batch-chunk-size after validation. 
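The `balance_sampler_by_class` property removed in the hunk above is superseded by loss-based balancing, added further down in this patch as `Model.balance_losses()` in `src/mednet/models/model.py` and invoked from the training script. A minimal sketch of the new call sequence, assuming `datamodule` is an already-constructed `ConcatDataModule` instance (its creation is omitted here); only the imports and method calls shown are part of this patch:

    # Sketch under the stated assumptions -- `datamodule` is a pre-built ConcatDataModule.
    from mednet.config.models.pasa import model  # a Model subclass configured in this patch

    datamodule.setup("fit")            # materialises the train/validation datasets
    model.balance_losses(datamodule)   # computes pos_weight values from the train/validation loaders
    model.configure_losses()           # instantiates the losses (the training engine also calls this)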
@@ -798,7 +758,7 @@ class ConcatDataModule(lightning.LightningDataModule): else: self._datasets[name] = _ConcatDataset(datasets) - def _val_dataset_keys(self) -> list[str]: + def val_dataset_keys(self) -> list[str]: """Return list of validation dataset names. Returns @@ -836,11 +796,11 @@ class ConcatDataModule(lightning.LightningDataModule): """ if stage == "fit": - for k in ["train"] + self._val_dataset_keys(): + for k in ["train"] + self.val_dataset_keys(): self._setup_dataset(k) elif stage == "validate": - for k in self._val_dataset_keys(): + for k in self.val_dataset_keys(): self._setup_dataset(k) elif stage == "test": @@ -929,7 +889,7 @@ class ConcatDataModule(lightning.LightningDataModule): self._datasets[k], **validation_loader_opts, ) - for k in self._val_dataset_keys() + for k in self.val_dataset_keys() } def test_dataloader(self) -> dict[str, DataLoader]: diff --git a/src/mednet/engine/callbacks.py b/src/mednet/engine/callbacks.py index 4c19ac55e5d325c6506dff393f70f0503e0d5682..b1b16f65ddca03d1d4fdd5a980999fca6aa18349 100644 --- a/src/mednet/engine/callbacks.py +++ b/src/mednet/engine/callbacks.py @@ -374,4 +374,5 @@ class LoggingCallback(lightning.pytorch.Callback): on_step=False, on_epoch=True, batch_size=batch[0].shape[0], + add_dataloader_idx=False, ) diff --git a/src/mednet/engine/trainer.py b/src/mednet/engine/trainer.py index 23df024bcf5efbc73c2c2b6bc207df3d5889a77e..5ea8ccae4cfecb6987fb7affca1274e6bb474a4e 100644 --- a/src/mednet/engine/trainer.py +++ b/src/mednet/engine/trainer.py @@ -72,6 +72,8 @@ def run( output_folder.mkdir(parents=True, exist_ok=True) + model.configure_losses() + from .loggers import CustomTensorboardLogger log_dir = "logs" diff --git a/src/mednet/models/alexnet.py b/src/mednet/models/alexnet.py index b4b9e723ad2a327f56e3051522654c3b6166ac89..75223c9a4b78e196d0c81dab20b566c4b5d32d4b 100644 --- a/src/mednet/models/alexnet.py +++ b/src/mednet/models/alexnet.py @@ -5,7 +5,6 @@ import logging import typing -import lightning.pytorch as pl import torch import torch.nn import torch.optim.optimizer @@ -14,36 +13,29 @@ import torchvision.models as models import torchvision.transforms from ..data.typing import TransformSequence +from .model import Model from .separate import separate from .transforms import RGB, SquareCenterPad -from .typing import Checkpoint logger = logging.getLogger(__name__) -class Alexnet(pl.LightningModule): +class Alexnet(Model): """Alexnet module. Note: only usable with a normalized dataset Parameters ---------- - train_loss - The loss to be used during the training. - - .. warning:: - - The loss should be set to always return batch averages (as opposed - to the batch sum), as our logging system expects it so. - validation_loss - The loss to be used for validation (may be different from the training - loss). If extra-validation sets are provided, the same loss will be - used throughout. + loss_type + The loss to be used for training and evaluation. .. warning:: The loss should be set to always return batch averages (as opposed to the batch sum), as our logging system expects it so. + loss_arguments + Arguments to the loss. optimizer_type The type of optimizer to use for training. 
optimizer_arguments @@ -60,15 +52,22 @@ class Alexnet(pl.LightningModule): def __init__( self, - train_loss: torch.nn.Module = torch.nn.BCEWithLogitsLoss(), - validation_loss: torch.nn.Module | None = None, + loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss, + loss_arguments: dict[str, typing.Any] = {}, optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam, optimizer_arguments: dict[str, typing.Any] = {}, augmentation_transforms: TransformSequence = [], pretrained: bool = False, num_classes: int = 1, ): - super().__init__() + super().__init__( + loss_type, + loss_arguments, + optimizer_type, + optimizer_arguments, + augmentation_transforms, + num_classes, + ) self.name = "alexnet" self.num_classes = num_classes @@ -79,17 +78,6 @@ class Alexnet(pl.LightningModule): RGB(), ] - self._train_loss = train_loss - self._validation_loss = ( - validation_loss if validation_loss is not None else train_loss - ) - self._optimizer_type = optimizer_type - self._optimizer_arguments = optimizer_arguments - - self._augmentation_transforms = torchvision.transforms.Compose( - augmentation_transforms, - ) - self.pretrained = pretrained # Load pretrained model @@ -109,36 +97,6 @@ class Alexnet(pl.LightningModule): x = self.normalizer(x) # type: ignore return self.model_ft(x) - def on_save_checkpoint(self, checkpoint: Checkpoint) -> None: - """Perform actions during checkpoint saving (called by lightning). - - Called by Lightning when saving a checkpoint to give you a chance to - store anything else you might want to save. Use on_load_checkpoint() to - restore what additional data is saved here. - - Parameters - ---------- - checkpoint - The checkpoint to save. - """ - - checkpoint["normalizer"] = self.normalizer - - def on_load_checkpoint(self, checkpoint: Checkpoint) -> None: - """Perform actions during model loading (called by lightning). - - If you saved something with on_save_checkpoint() this is your chance to - restore this. - - Parameters - ---------- - checkpoint - The loaded checkpoint. - """ - - logger.info("Restoring normalizer from checkpoint.") - self.normalizer = checkpoint["normalizer"] - def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None: """Initialize the normalizer for the current model. @@ -201,16 +159,9 @@ class Alexnet(pl.LightningModule): # data forwarding on the existing network outputs = self(images) - return self._validation_loss(outputs, labels.float()) def predict_step(self, batch, batch_idx, dataloader_idx=0): outputs = self(batch[0]) probabilities = torch.sigmoid(outputs) return separate((probabilities, batch[1])) - - def configure_optimizers(self): - return self._optimizer_type( - self.parameters(), - **self._optimizer_arguments, - ) diff --git a/src/mednet/models/densenet.py b/src/mednet/models/densenet.py index f7d1544164cf17c6f33f123d6d4bd02435722eb5..76df1ed64a7a73e99b86d18145169dd600601044 100644 --- a/src/mednet/models/densenet.py +++ b/src/mednet/models/densenet.py @@ -5,7 +5,6 @@ import logging import typing -import lightning.pytorch as pl import torch import torch.nn import torch.optim.optimizer @@ -14,34 +13,27 @@ import torchvision.models as models import torchvision.transforms from ..data.typing import TransformSequence +from .model import Model from .separate import separate from .transforms import RGB, SquareCenterPad -from .typing import Checkpoint logger = logging.getLogger(__name__) -class Densenet(pl.LightningModule): +class Densenet(Model): """Densenet-121 module. 
Parameters ---------- - train_loss - The loss to be used during the training. - - .. warning:: - - The loss should be set to always return batch averages (as opposed - to the batch sum), as our logging system expects it so. - validation_loss - The loss to be used for validation (may be different from the training - loss). If extra-validation sets are provided, the same loss will be - used throughout. + loss_type + The loss to be used for training and evaluation. .. warning:: The loss should be set to always return batch averages (as opposed to the batch sum), as our logging system expects it so. + loss_arguments + Arguments to the loss. optimizer_type The type of optimizer to use for training. optimizer_arguments @@ -60,8 +52,8 @@ class Densenet(pl.LightningModule): def __init__( self, - train_loss: torch.nn.Module = torch.nn.BCEWithLogitsLoss(), - validation_loss: torch.nn.Module | None = None, + loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss, + loss_arguments: dict[str, typing.Any] = {}, optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam, optimizer_arguments: dict[str, typing.Any] = {}, augmentation_transforms: TransformSequence = [], @@ -69,7 +61,14 @@ class Densenet(pl.LightningModule): dropout: float = 0.1, num_classes: int = 1, ): - super().__init__() + super().__init__( + loss_type, + loss_arguments, + optimizer_type, + optimizer_arguments, + augmentation_transforms, + num_classes, + ) self.name = "densenet-121" self.num_classes = num_classes @@ -80,17 +79,6 @@ class Densenet(pl.LightningModule): RGB(), ] - self._train_loss = train_loss - self._validation_loss = ( - validation_loss if validation_loss is not None else train_loss - ) - self._optimizer_type = optimizer_type - self._optimizer_arguments = optimizer_arguments - - self._augmentation_transforms = torchvision.transforms.Compose( - augmentation_transforms, - ) - self.pretrained = pretrained # Load pretrained model @@ -112,36 +100,6 @@ class Densenet(pl.LightningModule): x = self.normalizer(x) # type: ignore return self.model_ft(x) - def on_save_checkpoint(self, checkpoint: Checkpoint) -> None: - """Perform actions during checkpoint saving (called by lightning). - - Called by Lightning when saving a checkpoint to give you a chance to - store anything else you might want to save. Use on_load_checkpoint() to - restore what additional data is saved here. - - Parameters - ---------- - checkpoint - The checkpoint to save. - """ - - checkpoint["normalizer"] = self.normalizer - - def on_load_checkpoint(self, checkpoint: Checkpoint) -> None: - """Perform actions during model loading (called by lightning). - - If you saved something with on_save_checkpoint() this is your chance to - restore this. - - Parameters - ---------- - checkpoint - The loaded checkpoint. - """ - - logger.info("Restoring normalizer from checkpoint.") - self.normalizer = checkpoint["normalizer"] - def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None: """Initialize the normalizer for the current model. 
@@ -205,9 +163,3 @@ class Densenet(pl.LightningModule): outputs = self(batch[0]) probabilities = torch.sigmoid(outputs) return separate((probabilities, batch[1])) - - def configure_optimizers(self): - return self._optimizer_type( - self.parameters(), - **self._optimizer_arguments, - ) diff --git a/src/mednet/models/loss_weights.py b/src/mednet/models/loss_weights.py index bf965790cade10d68e19c0bf372c9fa7bf4d5409..d04bdfea67ea391b7a18a066b6aa2b563a345aea 100644 --- a/src/mednet/models/loss_weights.py +++ b/src/mednet/models/loss_weights.py @@ -3,87 +3,180 @@ # SPDX-License-Identifier: GPL-3.0-or-later import logging +import typing +from collections import Counter import torch import torch.utils.data -from ..data.typing import DataLoader - logger = logging.getLogger(__name__) -def _get_label_weights( - dataloader: torch.utils.data.DataLoader, -) -> torch.Tensor: - """Compute the weights of each class of a DataLoader. +def compute_binary_weights(targets): + """Compute the positive weights when using binary targets. - This function inputs a pytorch DataLoader and computes the ratio between - number of negative and positive samples (scalar). The weight can be used - to adjust minimisation criteria to in cases there is a huge data imbalance. + Parameters + ---------- + targets + A tensor of integer values of length n. - It returns a vector with weights (inverse counts) for each label. + Returns + ------- + The positive weights per class. + """ + class_sample_count = [ + float((targets == t).sum().item()) + for t in torch.unique(targets, sorted=True) + ] + + # Divide negatives by positives + return torch.tensor( + [class_sample_count[0] / class_sample_count[1]], + ).reshape(-1) + + +def compute_multiclass_weights(targets): + """Compute the positive weights when using exclusive, multiclass targets. Parameters ---------- - dataloader - A DataLoader from which to compute the positive weights. Entries must - be a dictionary which must contain a ``label`` key. + targets + A [C x n] tensor of integer values, where `C` is the number of target classes and `n` the number of samples. Returns ------- - torch.Tensor - The positive weight of each class in the dataset given as input. + The positive weights per class. """ - targets = torch.tensor( - [sample for batch in dataloader for sample in batch[1]["label"]], + class_sample_count = torch.sum(targets, dim=1) + negative_class_sample_count = ( + torch.full((targets.size()[0],), float(targets.size()[1])) + - class_sample_count ) - # Binary labels - if len(list(targets.shape)) == 1: - class_sample_count = [ - float((targets == t).sum().item()) - for t in torch.unique(targets, sorted=True) - ] + return negative_class_sample_count / ( + class_sample_count + negative_class_sample_count + ) - # Divide negatives by positives - positive_weights = torch.tensor( - [class_sample_count[0] / class_sample_count[1]], - ).reshape(-1) - # Multiclass labels - else: - class_sample_count = torch.sum(targets, dim=0) - negative_class_sample_count = ( - torch.full((targets.size()[1],), float(targets.size()[0])) - - class_sample_count - ) +def compute_non_exclusive_multiclass_weights(targets): + """Compute the positive weights when using non-exclusive, multiclass targets. - positive_weights = negative_class_sample_count / ( - class_sample_count + negative_class_sample_count - ) + Parameters + ---------- + targets + A [C x n] tensor of integer values, where `C` is the number of target classes and `n` the number of samples. 
- return positive_weights + Returns + ------- + The positive weights per class. + """ + raise ValueError( + "Computing weights of multi-class, non-exclusive labels is not yet supported." + ) + + +def is_multiclass_exclusive(targets: torch.Tensor) -> bool: + """Given a [C x n] tensor of integer targets, checks whether samples can only belong to a single class. + + Parameters + ---------- + targets + A [C x n] tensor of integer values, where `C` is the number of target classes and `n` the number of samples. + + Returns + ------- + True if all samples belong to a single class, False otherwise (a sample can belong to multiple classes). + """ + max_counts = [] + transposed_targets = torch.transpose(targets, 0, 1) + for t in transposed_targets: + filtered_list = [i for i in t.tolist() if i != 2] + counts = Counter(filtered_list) + max_counts.append(max(counts.values())) + + if set(max_counts) == {1}: + return True + + return False -def make_balanced_bcewithlogitsloss( - dataloader: DataLoader, -) -> torch.nn.BCEWithLogitsLoss: - """Return a balanced binary-cross-entropy loss. +def tensor_to_list(tensor) -> list[typing.Any]: + """Convert a torch.Tensor to a list. - The loss is weighted using the ratio between positives and total examples - available. + This is necessary, as torch.tolist returns an int when the tensor contains a single value. + + Parameters + ---------- + tensor + The tensor to convert to a list. + + Returns + ------- + The tensor converted to a list. + """ + + tensor = tensor.tolist() + if isinstance(tensor, int): + return [tensor] + return tensor + + +def get_positive_weights( + dataloader: torch.utils.data.DataLoader, +) -> torch.Tensor: + """Compute the weights of each class of a DataLoader. + + This function inputs a pytorch DataLoader and computes the ratio between + number of negative and positive samples (scalar). The weight can be used + to adjust minimisation criteria in cases where there is a huge data imbalance. + + It returns a vector with weights (inverse counts) for each label. Parameters ---------- dataloader - The DataLoader to use to compute the BCE weights. + A DataLoader from which to compute the positive weights. Entries must + be a dictionary which must contain a ``label`` key. Returns ------- - torch.nn.BCEWithLogitsLoss - An instance of the weighted loss. + The positive weight of each class in the dataset given as input. """ - weights = _get_label_weights(dataloader) - return torch.nn.BCEWithLogitsLoss(pos_weight=weights) + from collections import defaultdict + + targets = defaultdict(list) + + for batch in dataloader: + for class_idx, class_targets in enumerate(batch[1]["label"]): + # Targets are either a single tensor (binary case) or a list of tensors (multilabel) + if isinstance(batch[1]["label"], list): + targets[class_idx].extend(tensor_to_list(class_targets)) + else: + targets[0].extend(tensor_to_list(class_targets)) + + targets_list = [] + for k in sorted(list(targets.keys())): + targets_list.append(targets[k]) + + targets_tensor = torch.tensor(targets_list) + + if targets_tensor.shape[0] == 1: + logger.info("Computing positive weights assuming binary labels.") + positive_weights = compute_binary_weights(targets_tensor) + else: + if is_multiclass_exclusive(targets_tensor): + logger.info( + "Computing positive weights assuming multiclass, exclusive labels." + ) + positive_weights = compute_multiclass_weights(targets_tensor) + else: + logger.info( + "Computing positive weights assuming multiclass, non-exclusive labels."
+ ) + positive_weights = compute_non_exclusive_multiclass_weights( + targets_tensor + ) + + return positive_weights diff --git a/src/mednet/models/model.py b/src/mednet/models/model.py new file mode 100644 index 0000000000000000000000000000000000000000..57c1b618b1c0943496a4e799552989c03f2e3532 --- /dev/null +++ b/src/mednet/models/model.py @@ -0,0 +1,173 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import logging +import typing + +import lightning.pytorch as pl +import torch +import torch.nn +import torch.optim.optimizer +import torch.utils.data +import torchvision.transforms + +from ..data.typing import TransformSequence +from .loss_weights import get_positive_weights +from .typing import Checkpoint + +logger = logging.getLogger(__name__) + + +class Model(pl.LightningModule): + """Base class for models. + + Parameters + ---------- + loss_type + The loss to be used for training and evaluation. + + .. warning:: + + The loss should be set to always return batch averages (as opposed + to the batch sum), as our logging system expects it so. + loss_arguments + Arguments to the loss. + optimizer_type + The type of optimizer to use for training. + optimizer_arguments + Arguments to the optimizer after ``params``. + augmentation_transforms + An optional sequence of torch modules containing transforms to be + applied on the input **before** it is fed into the network. + num_classes + Number of outputs (classes) for this model. + """ + + def __init__( + self, + loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss, + loss_arguments: dict[str, typing.Any] = {}, + optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam, + optimizer_arguments: dict[str, typing.Any] = {}, + augmentation_transforms: TransformSequence = [], + num_classes: int = 1, + ): + super().__init__() + + self.name = "model" + self.num_classes = num_classes + + self.model_transforms: TransformSequence = [] + + self._loss_type = loss_type + + self._train_loss = None + self._train_loss_arguments = loss_arguments + + self.validation_loss = None + self._validation_loss_arguments = loss_arguments + + self._optimizer_type = optimizer_type + self._optimizer_arguments = optimizer_arguments + + self._augmentation_transforms = torchvision.transforms.Compose( + augmentation_transforms, + ) + + def forward(self, x): + raise NotImplementedError + + def on_save_checkpoint(self, checkpoint: Checkpoint) -> None: + """Perform actions during checkpoint saving (called by lightning). + + Called by Lightning when saving a checkpoint to give you a chance to + store anything else you might want to save. Use on_load_checkpoint() to + restore what additional data is saved here. + + Parameters + ---------- + checkpoint + The checkpoint to save. + """ + + checkpoint["normalizer"] = self.normalizer + + def on_load_checkpoint(self, checkpoint: Checkpoint) -> None: + """Perform actions during model loading (called by lightning). + + If you saved something with on_save_checkpoint() this is your chance to + restore this. + + Parameters + ---------- + checkpoint + The loaded checkpoint. + """ + + logger.info("Restoring normalizer from checkpoint.") + self.normalizer = checkpoint["normalizer"] + + def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None: + """Initialize the input normalizer for the current model. + + Parameters + ---------- + dataloader + A torch Dataloader from which to compute the mean and std. 
+ """ + + from .normalizer import make_z_normalizer + + logger.info( + f"Uninitialised {self.name} model - " + f"computing z-norm factors from train dataloader.", + ) + self.normalizer = make_z_normalizer(dataloader) + + def training_step(self, batch, _): + raise NotImplementedError + + def validation_step(self, batch, batch_idx, dataloader_idx=0): + raise NotImplementedError + + def predict_step(self, batch, batch_idx, dataloader_idx=0): + raise NotImplementedError + + def configure_losses(self): + self._train_loss = self._loss_type(**self._train_loss_arguments) + self._validation_loss = self._loss_type( + **self._validation_loss_arguments + ) + + def configure_optimizers(self): + return self._optimizer_type( + self.parameters(), + **self._optimizer_arguments, + ) + + def balance_losses(self, datamodule) -> None: + """Balance the loss based on the distribution of targets in the datamodule, if the loss supports it (contains a 'pos_weight' attribute). + + Parameters + ---------- + datamodule + Instance of a datamodule. + """ + + try: + getattr(self._loss_type(), "pos_weight") + except AttributeError: + logger.warning( + f"Loss {self._loss_type} does not possess a 'pos_weight' attribute and will not be balanced." + ) + else: + logger.info(f"Balancing training loss {self._loss_type}.") + train_weights = get_positive_weights(datamodule.train_dataloader()) + self._train_loss_arguments["pos_weight"] = train_weights + + logger.info(f"Balancing validation loss {self._loss_type}.") + validation_weights = get_positive_weights( + datamodule.val_dataloader()["validation"] + ) + self._validation_loss_arguments["pos_weight"] = validation_weights diff --git a/src/mednet/models/pasa.py b/src/mednet/models/pasa.py index 16a71f73c93bc9ab4dd8a65d2c9da7d96d27a36e..e9147683b08f8d8b532396544eece7829c23fd92 100644 --- a/src/mednet/models/pasa.py +++ b/src/mednet/models/pasa.py @@ -5,7 +5,6 @@ import logging import typing -import lightning.pytorch as pl import torch import torch.nn import torch.nn.functional as F # noqa: N812 @@ -14,14 +13,14 @@ import torch.utils.data import torchvision.transforms from ..data.typing import TransformSequence +from .model import Model from .separate import separate from .transforms import Grayscale, SquareCenterPad -from .typing import Checkpoint logger = logging.getLogger(__name__) -class Pasa(pl.LightningModule): +class Pasa(Model): """Implementation of CNN by Pasa and others. Simple CNN for classification based on paper by [PASA-2019]_. @@ -31,22 +30,15 @@ class Pasa(pl.LightningModule): Parameters ---------- - train_loss - The loss to be used during the training. - - .. warning:: - - The loss should be set to always return batch averages (as opposed - to the batch sum), as our logging system expects it so. - validation_loss - The loss to be used for validation (may be different from the training - loss). If extra-validation sets are provided, the same loss will be - used throughout. + loss_type + The loss to be used for training and evaluation. .. warning:: The loss should be set to always return batch averages (as opposed to the batch sum), as our logging system expects it so. + loss_arguments + Arguments to the loss. optimizer_type The type of optimizer to use for training. 
optimizer_arguments @@ -60,14 +52,21 @@ class Pasa(pl.LightningModule): def __init__( self, - train_loss: torch.nn.Module = torch.nn.BCEWithLogitsLoss(), - validation_loss: torch.nn.Module | None = None, + loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss, + loss_arguments: dict[str, typing.Any] = {}, optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam, optimizer_arguments: dict[str, typing.Any] = {}, augmentation_transforms: TransformSequence = [], num_classes: int = 1, ): - super().__init__() + super().__init__( + loss_type, + loss_arguments, + optimizer_type, + optimizer_arguments, + augmentation_transforms, + num_classes, + ) self.name = "pasa" self.num_classes = num_classes @@ -82,17 +81,6 @@ class Pasa(pl.LightningModule): ), ] - self._train_loss = train_loss - self._validation_loss = ( - validation_loss if validation_loss is not None else train_loss - ) - self._optimizer_type = optimizer_type - self._optimizer_arguments = optimizer_arguments - - self._augmentation_transforms = torchvision.transforms.Compose( - augmentation_transforms, - ) - # First convolution block self.fc1 = torch.nn.Conv2d(1, 4, (3, 3), (2, 2), (1, 1)) self.fc2 = torch.nn.Conv2d(4, 16, (3, 3), (2, 2), (1, 1)) @@ -213,53 +201,6 @@ class Pasa(pl.LightningModule): # x = F.log_softmax(x, dim=1) # 0 is batch size - def on_save_checkpoint(self, checkpoint: Checkpoint) -> None: - """Perform actions during checkpoint saving (called by lightning). - - Called by Lightning when saving a checkpoint to give you a chance to - store anything else you might want to save. Use on_load_checkpoint() to - restore what additional data is saved here. - - Parameters - ---------- - checkpoint - The checkpoint to save. - """ - - checkpoint["normalizer"] = self.normalizer - - def on_load_checkpoint(self, checkpoint: Checkpoint) -> None: - """Perform actions during model loading (called by lightning). - - If you saved something with on_save_checkpoint() this is your chance to - restore this. - - Parameters - ---------- - checkpoint - The loaded checkpoint. - """ - - logger.info("Restoring normalizer from checkpoint.") - self.normalizer = checkpoint["normalizer"] - - def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None: - """Initialize the input normalizer for the current model. - - Parameters - ---------- - dataloader - A torch Dataloader from which to compute the mean and std. - """ - - from .normalizer import make_z_normalizer - - logger.info( - f"Uninitialised {self.name} model - " - f"computing z-norm factors from train dataloader.", - ) - self.normalizer = make_z_normalizer(dataloader) - def training_step(self, batch, _): images = batch[0] labels = batch[1]["label"] @@ -285,16 +226,9 @@ class Pasa(pl.LightningModule): # data forwarding on the existing network outputs = self(images) - return self._validation_loss(outputs, labels.float()) def predict_step(self, batch, batch_idx, dataloader_idx=0): outputs = self(batch[0]) probabilities = torch.sigmoid(outputs) return separate((probabilities, batch[1])) - - def configure_optimizers(self): - return self._optimizer_type( - self.parameters(), - **self._optimizer_arguments, - ) diff --git a/src/mednet/scripts/train.py b/src/mednet/scripts/train.py index 68a4e7e721f412fbed430774491147f9c1130577..e83ae359b0be523565f669c57bf1de41520ccac7 100644 --- a/src/mednet/scripts/train.py +++ b/src/mednet/scripts/train.py @@ -296,10 +296,8 @@ def train( # of class examples available in the training set. 
Also affects the # validation loss if a validation set is available on the DataModule. if balance_classes: - logger.info("Applying DataModule train sampler balancing...") - datamodule.balance_sampler_by_class = True - # logger.info("Applying train/valid loss balancing...") - # model.balance_losses_by_class(datamodule) + logger.info("Applying train/valid loss balancing...") + model.balance_losses(datamodule) else: logger.info( "Skipping sample class/dataset ownership balancing on user request", diff --git a/tests/test_cli.py b/tests/test_cli.py index 4ee0e2c6ab454161dbf871030b30b989593d1c85..a512d6add392dc09439daaca7e0d8f9b01fbc443 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -241,8 +241,7 @@ def test_train_pasa_montgomery(temporary_basedir): keywords = { r"^Loading dataset:`train` without caching. Trade-off: CPU RAM usage: less | Disk I/O: more.$": 1, r"^Loading dataset:`validation` without caching. Trade-off: CPU RAM usage: less | Disk I/O: more.$": 1, - r"^Applying DataModule train sampler balancing...$": 1, - r"^Balancing samples from dataset using metadata targets `label`$": 1, + r"^Applying train/valid loss balancing...$": 1, r"^Training for at most 1 epochs.$": 1, r"^Uninitialised pasa model - computing z-norm factors from train dataloader.$": 1, r"^Writing run metadata at.*$": 1, @@ -323,8 +322,7 @@ def test_train_pasa_montgomery_from_checkpoint(temporary_basedir): keywords = { r"^Loading dataset:`train` without caching. Trade-off: CPU RAM usage: less | Disk I/O: more.$": 1, r"^Loading dataset:`validation` without caching. Trade-off: CPU RAM usage: less | Disk I/O: more.$": 1, - r"^Applying DataModule train sampler balancing...$": 1, - r"^Balancing samples from dataset using metadata targets `label`$": 1, + r"^Applying train/valid loss balancing...$": 1, r"^Training for at most 2 epochs.$": 1, r"^Resuming from epoch 0 \(checkpoint file: .*$": 1, r"^Writing run metadata at.*$": 1,
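For reference, the model-configuration style that this patch converts all bundled configs to; a minimal sketch mirroring `src/mednet/config/models/pasa.py` above (argument names and values are taken from that file, with `loss_arguments` and `num_classes` spelled out at their defaults):

    from torch.nn import BCEWithLogitsLoss
    from torch.optim import Adam

    from mednet.data.augmentations import ElasticDeformation
    from mednet.models.pasa import Pasa

    # The loss is passed as a type plus keyword arguments and is only instantiated
    # later by Model.configure_losses(), which lets balance_losses() inject a
    # computed `pos_weight` into the loss arguments beforehand.
    model = Pasa(
        loss_type=BCEWithLogitsLoss,
        loss_arguments=dict(),
        optimizer_type=Adam,
        optimizer_arguments=dict(lr=8e-5),
        augmentation_transforms=[ElasticDeformation(p=0.8)],
        num_classes=1,
    )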