From 2bd2a302ae7e82096a1985acfea7feb0d6a80e9a Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Thu, 13 Jul 2023 18:56:06 +0200
Subject: [PATCH] [models] Remove loss-weight balancing pending resolution of
 issue #6

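Loss balancing was implemented three times, once per model, and the
validation branch passed its weights positionally, hitting the "weight"
argument of BCEWithLogitsLoss instead of "pos_weight".  This removes the
per-model balance_losses_by_class() helpers until issue #6 settles how
balancing should be wired into training.  The weight computation stays
in loss_weights.py: get_label_weights() becomes the private
_get_label_weights(), wrapped by the new make_balanced_bcewithlogitsloss(),
so a balanced loss can still be built by hand.

A minimal usage sketch (the exact call site is left open pending issue
#6; "train_dataloader" stands for whatever DataLoader feeds training):

    from ptbench.models.loss_weights import make_balanced_bcewithlogitsloss

    # Build a BCEWithLogitsLoss whose pos_weight reflects the label
    # balance observed on the given data loader.
    criterion = make_balanced_bcewithlogitsloss(train_dataloader)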
---
 src/ptbench/models/alexnet.py      | 52 ++++--------
 src/ptbench/models/densenet.py     | 54 ++++---------
 src/ptbench/models/loss_weights.py | 31 ++++++++++++++++-
 src/ptbench/models/pasa.py         | 51 +++--------
 4 files changed, 46 insertions(+), 142 deletions(-)

diff --git a/src/ptbench/models/alexnet.py b/src/ptbench/models/alexnet.py
index c567b9b3..cd391e3d 100644
--- a/src/ptbench/models/alexnet.py
+++ b/src/ptbench/models/alexnet.py
@@ -13,7 +13,7 @@ import torch.utils.data
 import torchvision.models as models
 import torchvision.transforms
 
-from ..data.typing import DataLoader, TransformSequence
+from ..data.typing import TransformSequence
 from .typing import Checkpoint
 
 logger = logging.getLogger(__name__)
@@ -56,7 +56,8 @@ class Alexnet(pl.LightningModule):
         applied on the input **before** it is fed into the network.
 
     pretrained
-        If set to True, loads pretrained model weights during initialization, else trains a new model.
+        If set to True, loads pretrained model weights during initialization;
+        otherwise trains a new model from scratch.
     """
 
     def __init__(
@@ -108,7 +109,8 @@ class Alexnet(pl.LightningModule):
     def on_save_checkpoint(self, checkpoint: Checkpoint) -> None:
         """Called by Lightning to restore your model.
 
-        If you saved something with on_save_checkpoint() this is your chance to restore this.
+        If you saved something with on_save_checkpoint() this is your chance to
+        restore this.
 
         Parameters
         ----------
@@ -162,48 +164,6 @@ class Alexnet(pl.LightningModule):
             )
             self.normalizer = make_z_normalizer(dataloader)
 
-    def balance_losses_by_class(
-        self, train_dataloader: DataLoader, valid_dataloader: DataLoader
-    ):
-        """Reweights loss weights if possible.
-
-        Parameters
-        ----------
-
-        train_dataloader
-            The data loader to use for training
-
-        valid_dataloader
-            The data loader to use for validation
-
-
-        Raises
-        ------
-
-        RuntimeError
-            If train or validation losses are not of type
-            :py:class:`torch.nn.BCEWithLogitsLoss`.
-        """
-        from .loss_weights import get_label_weights
-
-        if isinstance(self._train_loss, torch.nn.BCEWithLogitsLoss):
-            logger.info("Reweighting BCEWithLogitsLoss training loss.")
-            weights = get_label_weights(train_dataloader)
-            self._train_loss = torch.nn.BCEWithLogitsLoss(pos_weight=weights)
-        else:
-            raise RuntimeError(
-                "Training loss is not BCEWithLogitsLoss - dunno how to balance"
-            )
-
-        if isinstance(self._validation_loss, torch.nn.BCEWithLogitsLoss):
-            logger.info("Reweighting BCEWithLogitsLoss validation loss.")
-            weights = get_label_weights(valid_dataloader)
-            self._validation_loss = torch.nn.BCEWithLogitsLoss(weights)
-        else:
-            raise RuntimeError(
-                "Validation loss is not BCEWithLogitsLoss - dunno how to balance"
-            )
-
     def training_step(self, batch, _):
         images = batch[0]
         labels = batch[1]["label"]
diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py
index ea2cab00..ba1d71fa 100644
--- a/src/ptbench/models/densenet.py
+++ b/src/ptbench/models/densenet.py
@@ -13,7 +13,7 @@ import torch.utils.data
 import torchvision.models as models
 import torchvision.transforms
 
-from ..data.typing import DataLoader, TransformSequence
+from ..data.typing import TransformSequence
 from .typing import Checkpoint
 
 logger = logging.getLogger(__name__)
@@ -54,7 +54,8 @@ class Densenet(pl.LightningModule):
         applied on the input **before** it is fed into the network.
 
     pretrained
-        If set to True, loads pretrained model weights during initialization, else trains a new model.
+        If set to True, loads pretrained model weights during initialization;
+        otherwise trains a new model from scratch.
     """
 
     def __init__(
@@ -107,7 +108,8 @@ class Densenet(pl.LightningModule):
     def on_save_checkpoint(self, checkpoint: Checkpoint) -> None:
         """Called by Lightning to restore your model.
 
-        If you saved something with on_save_checkpoint() this is your chance to restore this.
+        If you saved something with on_save_checkpoint() this is your chance to
+        restore this.
 
         Parameters
         ----------
@@ -161,50 +163,6 @@ class Densenet(pl.LightningModule):
             )
             self.normalizer = make_z_normalizer(dataloader)
 
-    def balance_losses_by_class(
-        self,
-        train_dataloader: DataLoader,
-        valid_dataloader: dict[str, DataLoader],
-    ):
-        """Reweights loss weights if possible.
-
-        Parameters
-        ----------
-
-        train_dataloader
-            The data loader to use for training
-
-        valid_dataloader
-            The data loaders to use for each of the validation sets
-
-
-        Raises
-        ------
-
-        RuntimeError
-            If train or validation losses are not of type
-            :py:class:`torch.nn.BCEWithLogitsLoss`.
-        """
-        from .loss_weights import get_label_weights
-
-        if isinstance(self._train_loss, torch.nn.BCEWithLogitsLoss):
-            logger.info("Reweighting BCEWithLogitsLoss training loss.")
-            weights = get_label_weights(train_dataloader)
-            self._train_loss = torch.nn.BCEWithLogitsLoss(pos_weight=weights)
-        else:
-            raise RuntimeError(
-                "Training loss is not BCEWithLogitsLoss - dunno how to balance"
-            )
-
-        if isinstance(self._validation_loss, torch.nn.BCEWithLogitsLoss):
-            logger.info("Reweighting BCEWithLogitsLoss validation loss.")
-            weights = get_label_weights(valid_dataloader)
-            self._validation_loss = torch.nn.BCEWithLogitsLoss(weights)
-        else:
-            raise RuntimeError(
-                "Validation loss is not BCEWithLogitsLoss - dunno how to balance"
-            )
-
     def training_step(self, batch, _):
         images = batch[0]
         labels = batch[1]["label"]
diff --git a/src/ptbench/models/loss_weights.py b/src/ptbench/models/loss_weights.py
index 6889b253..8cf79b77 100644
--- a/src/ptbench/models/loss_weights.py
+++ b/src/ptbench/models/loss_weights.py
@@ -7,10 +7,12 @@ import logging
 import torch
 import torch.utils.data
 
+from ..data.typing import DataLoader
+
 logger = logging.getLogger(__name__)
 
 
-def get_label_weights(
+def _get_label_weights(
     dataloader: torch.utils.data.DataLoader,
 ) -> torch.Tensor:
     """Computes the weights of each class of a DataLoader.
@@ -68,3 +70,30 @@
         )
 
     return positive_weights
+
+
+def make_balanced_bcewithlogitsloss(
+    dataloader: DataLoader,
+) -> torch.nn.BCEWithLogitsLoss:
+    """Returns a balanced binary-cross-entropy loss.
+
+    The loss is weighted using the ratio between positives and total examples
+    available.
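+
+
+    Parameters
+    ----------
+
+    dataloader
+        The data loader on which to compute the label weights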
+
+
+    Returns
+    -------
+
+    loss
+        An instance of the weighted loss
+    """
+
+    weights = _get_label_weights(dataloader)
+    return torch.nn.BCEWithLogitsLoss(pos_weight=weights)
diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py
index aaa5a2b0..479ec8f2 100644
--- a/src/ptbench/models/pasa.py
+++ b/src/ptbench/models/pasa.py
@@ -13,7 +13,7 @@ import torch.optim.optimizer
 import torch.utils.data
 import torchvision.transforms
 
-from ..data.typing import DataLoader, TransformSequence
+from ..data.typing import TransformSequence
 from .typing import Checkpoint
 
 logger = logging.getLogger(__name__)
@@ -192,7 +192,8 @@ class Pasa(pl.LightningModule):
     def on_save_checkpoint(self, checkpoint: Checkpoint) -> None:
         """Called by Lightning to restore your model.
 
-        If you saved something with on_save_checkpoint() this is your chance to restore this.
+        If you saved something with on_save_checkpoint() this is your chance to
+        restore this.
 
         Parameters
         ----------
@@ -232,50 +233,6 @@ class Pasa(pl.LightningModule):
         )
         self.normalizer = make_z_normalizer(dataloader)
 
-    def balance_losses_by_class(
-        self,
-        train_dataloader: DataLoader,
-        valid_dataloader: dict[str, DataLoader],
-    ):
-        """Reweights loss weights if possible.
-
-        Parameters
-        ----------
-
-        train_dataloader
-            The data loader to use for training
-
-        valid_dataloader
-            The data loaders to use for each of the validation sets
-
-
-        Raises
-        ------
-
-        RuntimeError
-            If train or validation losses are not of type
-            :py:class:`torch.nn.BCEWithLogitsLoss`.
-        """
-        from .loss_weights import get_label_weights
-
-        if isinstance(self._train_loss, torch.nn.BCEWithLogitsLoss):
-            logger.info("Reweighting BCEWithLogitsLoss training loss.")
-            weights = get_label_weights(train_dataloader)
-            self._train_loss = torch.nn.BCEWithLogitsLoss(pos_weight=weights)
-        else:
-            raise RuntimeError(
-                "Training loss is not BCEWithLogitsLoss - dunno how to balance"
-            )
-
-        if isinstance(self._validation_loss, torch.nn.BCEWithLogitsLoss):
-            logger.info("Reweighting BCEWithLogitsLoss validation loss.")
-            weights = get_label_weights(valid_dataloader)
-            self._validation_loss = torch.nn.BCEWithLogitsLoss(weights)
-        else:
-            raise RuntimeError(
-                "Validation loss is not BCEWithLogitsLoss - dunno how to balance"
-            )
-
     def training_step(self, batch, _):
         images = batch[0]
         labels = batch[1]["label"]
-- 
GitLab