From 2bd2a302ae7e82096a1985acfea7feb0d6a80e9a Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Thu, 13 Jul 2023 18:56:06 +0200
Subject: [PATCH] [models] Remove loss-weight-balancing pending on issue #6
 resolution

---
 src/ptbench/models/alexnet.py      | 50 +++--------------------------
 src/ptbench/models/densenet.py     | 52 +++---------------------------
 src/ptbench/models/loss_weights.py | 24 +++++++++++++-
 src/ptbench/models/pasa.py         | 49 ++---------------------------
 4 files changed, 36 insertions(+), 139 deletions(-)

diff --git a/src/ptbench/models/alexnet.py b/src/ptbench/models/alexnet.py
index c567b9b3..cd391e3d 100644
--- a/src/ptbench/models/alexnet.py
+++ b/src/ptbench/models/alexnet.py
@@ -13,7 +13,7 @@ import torch.utils.data
 import torchvision.models as models
 import torchvision.transforms
 
-from ..data.typing import DataLoader, TransformSequence
+from ..data.typing import TransformSequence
 from .typing import Checkpoint
 
 logger = logging.getLogger(__name__)
@@ -56,7 +56,8 @@ class Alexnet(pl.LightningModule):
         applied on the input **before** it is fed into the network.
 
     pretrained
-        If set to True, loads pretrained model weights during initialization, else trains a new model.
+        If set to True, loads pretrained model weights during initialization,
+        else trains a new model.
     """
 
     def __init__(
@@ -108,7 +109,8 @@ class Alexnet(pl.LightningModule):
     def on_save_checkpoint(self, checkpoint: Checkpoint) -> None:
         """Called by Lightning to restore your model.
 
-        If you saved something with on_save_checkpoint() this is your chance to restore this.
+        If you saved something with on_save_checkpoint() this is your chance to
+        restore this.
 
         Parameters
         ----------
@@ -162,48 +164,6 @@ class Alexnet(pl.LightningModule):
             )
             self.normalizer = make_z_normalizer(dataloader)
 
-    def balance_losses_by_class(
-        self, train_dataloader: DataLoader, valid_dataloader: DataLoader
-    ):
-        """Reweights loss weights if possible.
-
-        Parameters
-        ----------
-
-        train_dataloader
-            The data loader to use for training
-
-        valid_dataloader
-            The data loader to use for validation
-
-
-        Raises
-        ------
-
-        RuntimeError
-            If train or validation losses are not of type
-            :py:class:`torch.nn.BCEWithLogitsLoss`.
-        """
-        from .loss_weights import get_label_weights
-
-        if isinstance(self._train_loss, torch.nn.BCEWithLogitsLoss):
-            logger.info("Reweighting BCEWithLogitsLoss training loss.")
-            weights = get_label_weights(train_dataloader)
-            self._train_loss = torch.nn.BCEWithLogitsLoss(pos_weight=weights)
-        else:
-            raise RuntimeError(
-                "Training loss is not BCEWithLogitsLoss - dunno how to balance"
-            )
-
-        if isinstance(self._validation_loss, torch.nn.BCEWithLogitsLoss):
-            logger.info("Reweighting BCEWithLogitsLoss validation loss.")
-            weights = get_label_weights(valid_dataloader)
-            self._validation_loss = torch.nn.BCEWithLogitsLoss(weights)
-        else:
-            raise RuntimeError(
-                "Validation loss is not BCEWithLogitsLoss - dunno how to balance"
-            )
-
     def training_step(self, batch, _):
         images = batch[0]
         labels = batch[1]["label"]
diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py
index ea2cab00..ba1d71fa 100644
--- a/src/ptbench/models/densenet.py
+++ b/src/ptbench/models/densenet.py
@@ -13,7 +13,7 @@ import torch.utils.data
 import torchvision.models as models
 import torchvision.transforms
 
-from ..data.typing import DataLoader, TransformSequence
+from ..data.typing import TransformSequence
 from .typing import Checkpoint
 
 logger = logging.getLogger(__name__)
@@ -54,7 +54,8 @@ class Densenet(pl.LightningModule):
         applied on the input **before** it is fed into the network.
 
     pretrained
-        If set to True, loads pretrained model weights during initialization, else trains a new model.
+        If set to True, loads pretrained model weights during initialization,
+        else trains a new model.
     """
 
     def __init__(
@@ -107,7 +108,8 @@ class Densenet(pl.LightningModule):
     def on_save_checkpoint(self, checkpoint: Checkpoint) -> None:
         """Called by Lightning to restore your model.
 
-        If you saved something with on_save_checkpoint() this is your chance to restore this.
+        If you saved something with on_save_checkpoint() this is your chance to
+        restore this.
 
         Parameters
         ----------
@@ -161,50 +163,6 @@ class Densenet(pl.LightningModule):
             )
             self.normalizer = make_z_normalizer(dataloader)
 
-    def balance_losses_by_class(
-        self,
-        train_dataloader: DataLoader,
-        valid_dataloader: dict[str, DataLoader],
-    ):
-        """Reweights loss weights if possible.
-
-        Parameters
-        ----------
-
-        train_dataloader
-            The data loader to use for training
-
-        valid_dataloader
-            The data loaders to use for each of the validation sets
-
-
-        Raises
-        ------
-
-        RuntimeError
-            If train or validation losses are not of type
-            :py:class:`torch.nn.BCEWithLogitsLoss`.
-        """
-        from .loss_weights import get_label_weights
-
-        if isinstance(self._train_loss, torch.nn.BCEWithLogitsLoss):
-            logger.info("Reweighting BCEWithLogitsLoss training loss.")
-            weights = get_label_weights(train_dataloader)
-            self._train_loss = torch.nn.BCEWithLogitsLoss(pos_weight=weights)
-        else:
-            raise RuntimeError(
-                "Training loss is not BCEWithLogitsLoss - dunno how to balance"
-            )
-
-        if isinstance(self._validation_loss, torch.nn.BCEWithLogitsLoss):
-            logger.info("Reweighting BCEWithLogitsLoss validation loss.")
-            weights = get_label_weights(valid_dataloader)
-            self._validation_loss = torch.nn.BCEWithLogitsLoss(weights)
-        else:
-            raise RuntimeError(
-                "Validation loss is not BCEWithLogitsLoss - dunno how to balance"
-            )
-
     def training_step(self, batch, _):
         images = batch[0]
         labels = batch[1]["label"]
diff --git a/src/ptbench/models/loss_weights.py b/src/ptbench/models/loss_weights.py
index 6889b253..8cf79b77 100644
--- a/src/ptbench/models/loss_weights.py
+++ b/src/ptbench/models/loss_weights.py
@@ -7,10 +7,12 @@ import logging
 import torch
 import torch.utils.data
 
+from ..data.typing import DataLoader
+
 logger = logging.getLogger(__name__)
 
 
-def get_label_weights(
+def _get_label_weights(
     dataloader: torch.utils.data.DataLoader,
 ) -> torch.Tensor:
     """Computes the weights of each class of a DataLoader.
@@ -68,3 +70,23 @@ def get_label_weights(
     )
 
     return positive_weights
+
+
+def make_balanced_bcewithlogitsloss(
+    dataloader: DataLoader,
+) -> torch.nn.BCEWithLogitsLoss:
+    """Returns a balanced binary-cross-entropy loss.
+
+    The loss is weighted using the ratio between positives and total examples
+    available.
+
+
+    Returns
+    -------
+
+    loss
+        An instance of the weighted loss
+    """
+
+    weights = _get_label_weights(dataloader)
+    return torch.nn.BCEWithLogitsLoss(pos_weight=weights)
diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py
index aaa5a2b0..479ec8f2 100644
--- a/src/ptbench/models/pasa.py
+++ b/src/ptbench/models/pasa.py
@@ -13,7 +13,7 @@ import torch.optim.optimizer
 import torch.utils.data
 import torchvision.transforms
 
-from ..data.typing import DataLoader, TransformSequence
+from ..data.typing import TransformSequence
 from .typing import Checkpoint
 
 logger = logging.getLogger(__name__)
@@ -192,7 +192,8 @@ class Pasa(pl.LightningModule):
     def on_save_checkpoint(self, checkpoint: Checkpoint) -> None:
         """Called by Lightning to restore your model.
 
-        If you saved something with on_save_checkpoint() this is your chance to restore this.
+        If you saved something with on_save_checkpoint() this is your chance to
+        restore this.
 
         Parameters
         ----------
@@ -232,50 +233,6 @@ class Pasa(pl.LightningModule):
             )
             self.normalizer = make_z_normalizer(dataloader)
 
-    def balance_losses_by_class(
-        self,
-        train_dataloader: DataLoader,
-        valid_dataloader: dict[str, DataLoader],
-    ):
-        """Reweights loss weights if possible.
-
-        Parameters
-        ----------
-
-        train_dataloader
-            The data loader to use for training
-
-        valid_dataloader
-            The data loaders to use for each of the validation sets
-
-
-        Raises
-        ------
-
-        RuntimeError
-            If train or validation losses are not of type
-            :py:class:`torch.nn.BCEWithLogitsLoss`.
-        """
-        from .loss_weights import get_label_weights
-
-        if isinstance(self._train_loss, torch.nn.BCEWithLogitsLoss):
-            logger.info("Reweighting BCEWithLogitsLoss training loss.")
-            weights = get_label_weights(train_dataloader)
-            self._train_loss = torch.nn.BCEWithLogitsLoss(pos_weight=weights)
-        else:
-            raise RuntimeError(
-                "Training loss is not BCEWithLogitsLoss - dunno how to balance"
-            )
-
-        if isinstance(self._validation_loss, torch.nn.BCEWithLogitsLoss):
-            logger.info("Reweighting BCEWithLogitsLoss validation loss.")
-            weights = get_label_weights(valid_dataloader)
-            self._validation_loss = torch.nn.BCEWithLogitsLoss(weights)
-        else:
-            raise RuntimeError(
-                "Validation loss is not BCEWithLogitsLoss - dunno how to balance"
-            )
-
     def training_step(self, batch, _):
         images = batch[0]
         labels = batch[1]["label"]
-- 
GitLab