From 7eaac22c67c4063c5f2c1d633ecdf0598b2cb956 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Tue, 11 Jul 2023 10:58:26 +0200
Subject: [PATCH] Updated models

---
 src/ptbench/data/datamodule.py |   4 +-
 src/ptbench/models/alexnet.py  | 117 ++++++++++++++++++++++-----------
 src/ptbench/models/densenet.py | 108 +++++++++++++++++++-----------
 src/ptbench/models/pasa.py     |  46 ++++++++++++-
 4 files changed, 194 insertions(+), 81 deletions(-)

diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py
index 1ae9531e..abcf11d7 100644
--- a/src/ptbench/data/datamodule.py
+++ b/src/ptbench/data/datamodule.py
@@ -229,11 +229,11 @@ def _make_balanced_random_sampler(
     1. The probability of picking a sample from any target is the same (0.5
        in this case). To verify this, notice that the probability of picking
        a sample with ``target=0`` is :math:`1/4 x 1 + 1/12 x 3 = 0.5`.
-    2. The probabiility of picking a sample with ``target=0`` from Dataset 2 is
+    2. The probability of picking a sample with ``target=0`` from Dataset 2 is
        3 times higher than those from Dataset 1. As there are 3 times less
        samples in Dataset 2 with ``target=0``, this makes choosing samples from
        Dataset 1 proportionally less likely.
-    3. The probabiility of picking a sample with ``target=1`` from Dataset 2 is
+    3. The probability of picking a sample with ``target=1`` from Dataset 2 is
        3 times lower than those from Dataset 1. As there are 3 times less
        samples in Dataset 1 with ``target=1``, this makes choosing samples from
        Dataset 2 proportionally less likely.
diff --git a/src/ptbench/models/alexnet.py b/src/ptbench/models/alexnet.py
index e8643b46..a878a076 100644
--- a/src/ptbench/models/alexnet.py
+++ b/src/ptbench/models/alexnet.py
@@ -3,14 +3,18 @@
 # SPDX-License-Identifier: GPL-3.0-or-later
 
 import logging
+import typing
 
 import lightning.pytorch as pl
-import torch.nn as nn
+import torch
+import torch.nn
+import torch.nn.functional as F
+import torch.optim.optimizer
 import torch.utils.data
 import torchvision.models as models
 import torchvision.transforms
 
-from ..data.typing import DataLoader
+from ..data.typing import DataLoader, TransformSequence
 
 logger = logging.getLogger(__name__)
 
@@ -19,32 +23,66 @@ class Alexnet(pl.LightningModule):
     """Alexnet module.
 
     Note: only usable with a normalized dataset
+
+    Parameters
+    ----------
+
+    train_loss
+        The loss to be used during the training.
+
+        .. warning::
+
+           The loss should be set to always return batch averages (as opposed
+           to the batch sum), as this is what our logging system expects.
+
+    validation_loss
+        The loss to be used for validation (may be different from the training
+        loss). If extra-validation sets are provided, the same loss will be
+        used throughout.
+
+        .. warning::
+
+           The loss should be set to always return batch averages (as opposed
+           to the batch sum), as this is what our logging system expects.
+
+    optimizer_type
+        The type of optimizer to use for training.
+
+    optimizer_arguments
+        Arguments to the optimizer after ``params``.
+
+    augmentation_transforms
+        An optional sequence of torch modules containing transforms to be
+        applied on the input **before** it is fed into the network.
+
+    pretrained
+        If set to ``True``, loads pretrained model weights during initialization; otherwise, trains a new model.
""" def __init__( self, - criterion=None, - criterion_valid=None, - optimizer=None, - optimizer_configs=None, - pretrained=False, - augmentation_transforms=[], + train_loss: torch.nn.Module, + validation_loss: torch.nn.Module | None, + optimizer_type: type[torch.optim.Optimizer], + optimizer_arguments: dict[str, typing.Any], + augmentation_transforms: TransformSequence = [], + pretrained: bool = False, ): super().__init__() self.name = "alexnet" - self.augmentation_transforms = torchvision.transforms.Compose( - augmentation_transforms + self._train_loss = train_loss + self._validation_loss = ( + validation_loss if validation_loss is not None else train_loss ) + self._optimizer_type = optimizer_type + self._optimizer_arguments = optimizer_arguments - self.criterion = criterion - self.criterion_valid = criterion_valid - - self.optimizer = optimizer - self.optimizer_configs = optimizer_configs + self._augmentation_transforms = torchvision.transforms.Compose( + augmentation_transforms + ) - self.normalizer = None self.pretrained = pretrained # Load pretrained model @@ -57,11 +95,12 @@ class Alexnet(pl.LightningModule): self.model_ft = models.alexnet(weights=weights) # Adapt output features - self.model_ft.classifier[4] = nn.Linear(4096, 512) - self.model_ft.classifier[6] = nn.Linear(512, 1) + self.model_ft.classifier[4] = torch.nn.Linear(4096, 512) + self.model_ft.classifier[6] = torch.nn.Linear(512, 1) def forward(self, x): - x = self.normalizer(x) + x = self.normalizer(x) # type: ignore + x = self.model_ft(x) return x @@ -121,25 +160,25 @@ class Alexnet(pl.LightningModule): """ from .loss_weights import get_label_weights - if isinstance(self.criterion, torch.nn.BCEWithLogitsLoss): - logger.info("Reweighting BCEWithLogitsLoss training criterion.") + if isinstance(self._train_loss, torch.nn.BCEWithLogitsLoss): + logger.info("Reweighting BCEWithLogitsLoss training loss.") weights = get_label_weights(train_dataloader) - self.criterion = torch.nn.BCEWithLogitsLoss(pos_weight=weights) + self._train_loss = torch.nn.BCEWithLogitsLoss(pos_weight=weights) else: raise RuntimeError( "Training loss is not BCEWithLogitsLoss - dunno how to balance" ) - if isinstance(self.criterion_valid, torch.nn.BCEWithLogitsLoss): - logger.info("Reweighting BCEWithLogitsLoss validation criterion.") + if isinstance(self._validation_loss, torch.nn.BCEWithLogitsLoss): + logger.info("Reweighting BCEWithLogitsLoss validation loss.") weights = get_label_weights(valid_dataloader) - self.criterion_valid = torch.nn.BCEWithLogitsLoss(weights) + self._validation_loss = torch.nn.BCEWithLogitsLoss(weights) else: raise RuntimeError( "Validation loss is not BCEWithLogitsLoss - dunno how to balance" ) - def training_step(self, batch, batch_idx): + def training_step(self, batch, _): images = batch[0] labels = batch[1]["label"] @@ -150,15 +189,13 @@ class Alexnet(pl.LightningModule): # Forward pass on the network augmented_images = [ - self.augmentation_transforms(img).to(self.device) for img in images + self._augmentation_transforms(img).to(self.device) for img in images ] # Combine list of augmented images back into a tensor augmented_images = torch.cat(augmented_images, 0).view(images.shape) outputs = self(augmented_images) - training_loss = self.criterion(outputs, labels.float()) - - return {"loss": training_loss} + return self._train_loss(outputs, labels.float()) def validation_step(self, batch, batch_idx, dataloader_idx=0): images = batch[0] @@ -172,23 +209,23 @@ class Alexnet(pl.LightningModule): # data forwarding on the existing 
         # data forwarding on the existing network
         outputs = self(images)
-        validation_loss = self.criterion_valid(outputs, labels.float())
-
-        if dataloader_idx == 0:
-            return {"validation_loss": validation_loss}
-        else:
-            return {f"extra_validation_loss_{dataloader_idx}": validation_loss}
+        return self._validation_loss(outputs, labels.float())
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
         images = batch[0]
         labels = batch[1]["label"]
-        names = batch[1]["name"]
+        names = batch[1]["names"]
 
         outputs = self(images)
         probabilities = torch.sigmoid(outputs)
-        return names[0], torch.flatten(probabilities), torch.flatten(labels)
+        return (
+            names[0],
+            torch.flatten(probabilities),
+            torch.flatten(labels),
+        )
 
     def configure_optimizers(self):
-        optimizer = self.optimizer(self.parameters(), **self.optimizer_configs)
-        return optimizer
+        return self._optimizer_type(
+            self.parameters(), **self._optimizer_arguments
+        )
diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py
index 25d1d8ff..8eba3b53 100644
--- a/src/ptbench/models/densenet.py
+++ b/src/ptbench/models/densenet.py
@@ -3,15 +3,18 @@
 # SPDX-License-Identifier: GPL-3.0-or-later
 
 import logging
+import typing
 
 import lightning.pytorch as pl
 import torch
 import torch.nn
+import torch.nn.functional as F
+import torch.optim.optimizer
 import torch.utils.data
 import torchvision.models as models
 import torchvision.transforms
 
-from ..data.typing import DataLoader
+from ..data.typing import DataLoader, TransformSequence
 
 logger = logging.getLogger(__name__)
 
@@ -22,32 +25,61 @@ class Densenet(pl.LightningModule):
 
     Parameters
     ----------
 
-    criterion
-        A dictionary containing the criteria for the
+    train_loss
+        The loss to be used during the training.
+
+        .. warning::
+
+           The loss should be set to always return batch averages (as opposed
+           to the batch sum), as this is what our logging system expects.
+
+    validation_loss
+        The loss to be used for validation (may be different from the training
+        loss). If extra-validation sets are provided, the same loss will be
+        used throughout.
+
+        .. warning::
+
+           The loss should be set to always return batch averages (as opposed
+           to the batch sum), as this is what our logging system expects.
+
+    optimizer_type
+        The type of optimizer to use for training.
+
+    optimizer_arguments
+        Arguments to the optimizer after ``params``.
+
+    augmentation_transforms
+        An optional sequence of torch modules containing transforms to be
+        applied on the input **before** it is fed into the network.
+
+    pretrained
+        If set to ``True``, loads pretrained model weights during initialization; otherwise, trains a new model.
""" def __init__( self, - criterion=None, - criterion_valid=None, - optimizer=None, - optimizer_configs=None, - pretrained=False, - augmentation_transforms=[], + train_loss: torch.nn.Module, + validation_loss: torch.nn.Module | None, + optimizer_type: type[torch.optim.Optimizer], + optimizer_arguments: dict[str, typing.Any], + augmentation_transforms: TransformSequence = [], + pretrained: bool= False, ): super().__init__() self.name = "densenet-121" - self.augmentation_transforms = torchvision.transforms.Compose( - augmentation_transforms + self._train_loss = train_loss + self._validation_loss = ( + validation_loss if validation_loss is not None else train_loss ) + self._optimizer_type = optimizer_type + self._optimizer_arguments = optimizer_arguments - self.criterion = criterion - self.criterion_valid = criterion_valid - - self.optimizer = optimizer - self.optimizer_configs = optimizer_configs + self._augmentation_transforms = torchvision.transforms.Compose( + augmentation_transforms + ) self.pretrained = pretrained @@ -66,7 +98,9 @@ class Densenet(pl.LightningModule): ) def forward(self, x): - x = self.normalizer(x) + + x = self.normalizer(x) # type: ignore + x = self.model_ft(x) return x @@ -128,25 +162,25 @@ class Densenet(pl.LightningModule): """ from .loss_weights import get_label_weights - if isinstance(self.criterion, torch.nn.BCEWithLogitsLoss): - logger.info("Reweighting BCEWithLogitsLoss training criterion.") + if isinstance(self._train_loss, torch.nn.BCEWithLogitsLoss): + logger.info("Reweighting BCEWithLogitsLoss training loss.") weights = get_label_weights(train_dataloader) - self.criterion = torch.nn.BCEWithLogitsLoss(pos_weight=weights) + self._train_loss = torch.nn.BCEWithLogitsLoss(pos_weight=weights) else: raise RuntimeError( "Training loss is not BCEWithLogitsLoss - dunno how to balance" ) - if isinstance(self.criterion_valid, torch.nn.BCEWithLogitsLoss): - logger.info("Reweighting BCEWithLogitsLoss validation criterion.") + if isinstance(self._validation_loss, torch.nn.BCEWithLogitsLoss): + logger.info("Reweighting BCEWithLogitsLoss validation loss.") weights = get_label_weights(valid_dataloader) - self.criterion_valid = torch.nn.BCEWithLogitsLoss(weights) + self._validation_loss = torch.nn.BCEWithLogitsLoss(weights) else: raise RuntimeError( "Validation loss is not BCEWithLogitsLoss - dunno how to balance" ) - def training_step(self, batch, batch_idx): + def training_step(self, batch, _): images = batch[0] labels = batch[1]["label"] @@ -157,15 +191,13 @@ class Densenet(pl.LightningModule): # Forward pass on the network augmented_images = [ - self.augmentation_transforms(img).to(self.device) for img in images + self._augmentation_transforms(img).to(self.device) for img in images ] # Combine list of augmented images back into a tensor augmented_images = torch.cat(augmented_images, 0).view(images.shape) outputs = self(augmented_images) - training_loss = self.criterion(outputs, labels.float()) - - return {"loss": training_loss} + return self._train_loss(outputs, labels.float()) def validation_step(self, batch, batch_idx, dataloader_idx=0): images = batch[0] @@ -179,23 +211,23 @@ class Densenet(pl.LightningModule): # data forwarding on the existing network outputs = self(images) - validation_loss = self.criterion_valid(outputs, labels.float()) - - if dataloader_idx == 0: - return {"validation_loss": validation_loss} - else: - return {f"extra_validation_loss_{dataloader_idx}": validation_loss} + return self._validation_loss(outputs, labels.float()) def predict_step(self, 
     def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
         images = batch[0]
         labels = batch[1]["label"]
-        names = batch[1]["name"]
+        names = batch[1]["names"]
 
         outputs = self(images)
         probabilities = torch.sigmoid(outputs)
-        return names[0], torch.flatten(probabilities), torch.flatten(labels)
+        return (
+            names[0],
+            torch.flatten(probabilities),
+            torch.flatten(labels),
+        )
 
     def configure_optimizers(self):
-        optimizer = self.optimizer(self.parameters(), **self.optimizer_configs)
-        return optimizer
+        return self._optimizer_type(
+            self.parameters(), **self._optimizer_arguments
+        )
diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py
index 34e5c67f..20bbb0dd 100644
--- a/src/ptbench/models/pasa.py
+++ b/src/ptbench/models/pasa.py
@@ -13,7 +13,7 @@ import torch.optim.optimizer
 import torch.utils.data
 import torchvision.transforms
 
-from ..data.typing import TransformSequence
+from ..data.typing import DataLoader, TransformSequence
 
 logger = logging.getLogger(__name__)
 
@@ -202,6 +202,50 @@ class Pasa(pl.LightningModule):
         )
         self.normalizer = make_z_normalizer(dataloader)
 
+    def balance_losses_by_class(
+        self,
+        train_dataloader: DataLoader,
+        valid_dataloader: dict[str, DataLoader],
+    ):
+        """Reweights the loss weights, if possible.
+
+        Parameters
+        ----------
+
+        train_dataloader
+            The data loader to use for training.
+
+        valid_dataloader
+            The data loaders to use for each of the validation sets.
+
+
+        Raises
+        ------
+
+        RuntimeError
+            If train or validation losses are not of type
+            :py:class:`torch.nn.BCEWithLogitsLoss`.
+        """
+        from .loss_weights import get_label_weights
+
+        if isinstance(self._train_loss, torch.nn.BCEWithLogitsLoss):
+            logger.info("Reweighting BCEWithLogitsLoss training loss.")
+            weights = get_label_weights(train_dataloader)
+            self._train_loss = torch.nn.BCEWithLogitsLoss(pos_weight=weights)
+        else:
+            raise RuntimeError(
+                "Training loss is not BCEWithLogitsLoss - dunno how to balance"
+            )
+
+        if isinstance(self._validation_loss, torch.nn.BCEWithLogitsLoss):
+            logger.info("Reweighting BCEWithLogitsLoss validation loss.")
+            weights = get_label_weights(valid_dataloader)
+            self._validation_loss = torch.nn.BCEWithLogitsLoss(pos_weight=weights)
+        else:
+            raise RuntimeError(
+                "Validation loss is not BCEWithLogitsLoss - dunno how to balance"
+            )
+
     def training_step(self, batch, _):
         images = batch[0]
         labels = batch[1]["label"]
-- 
GitLab
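
Note on the balanced-sampler docstring fixed in datamodule.py: the quoted arithmetic can be checked directly. This is a minimal illustrative sketch, not part of the patch, assuming the docstring's setup of one ``target=0`` sample drawn with probability 1/4 and three ``target=0`` samples drawn with probability 1/12 each:

    # Illustrative check of the docstring's claim: one sample drawn with
    # probability 1/4 plus three samples drawn with probability 1/12 each
    # gives an overall probability of 0.5 for picking ``target=0``.
    p_target0 = (1 / 4) * 1 + (1 / 12) * 3
    assert abs(p_target0 - 0.5) < 1e-12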
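Since this patch changes the constructor signatures of Alexnet and Densenet, a usage sketch may help reviewers. Only the keyword names below come from the new ``__init__``; the loss choice, optimizer, and learning rate are illustrative assumptions, not values mandated by the patch:

    import torch

    from ptbench.models.alexnet import Alexnet

    model = Alexnet(
        train_loss=torch.nn.BCEWithLogitsLoss(),  # assumed loss choice
        validation_loss=None,                     # falls back to train_loss
        optimizer_type=torch.optim.Adam,          # assumed optimizer
        optimizer_arguments=dict(lr=1e-4),        # passed after ``params``
        augmentation_transforms=[],               # applied before the network
        pretrained=False,
    )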
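On the loss re-weighting in ``balance_losses_by_class``: ``torch.nn.BCEWithLogitsLoss`` takes the class-balancing factor through the ``pos_weight`` keyword (its first positional argument is the per-element ``weight``, which has different semantics), hence both the training and validation branches pass ``pos_weight=weights``. A short self-contained sketch of the keyword's effect, with an illustrative weight value:

    import torch

    # pos_weight multiplies the loss contribution of positive targets; a
    # value of 3.0 mimics a dataset with 3x more negatives than positives.
    loss = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([3.0]))
    logits = torch.tensor([0.2, -1.5])
    targets = torch.tensor([1.0, 0.0])
    print(loss(logits, targets))  # a batch average, as the docstrings require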