From 8769c32708572dfd15f57e5c56abc9bf4f2b451a Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Thu, 2 May 2024 12:27:32 +0200
Subject: [PATCH] [model] Use a single type of loss for train and validation

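The models previously took separate train_loss_type/validation_loss_type
parameters (each with its own arguments dict), but every configuration
shipped with mednet passes the same loss to both. Collapse the four
parameters into a single loss_type/loss_arguments pair. The loss is still
instantiated separately for training and validation in configure_losses(),
so balance_losses() can continue to assign a distinct pos_weight per split.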
---
 src/mednet/config/models/alexnet.py           |  3 +-
 .../config/models/alexnet_pretrained.py       |  3 +-
 src/mednet/config/models/densenet.py          |  3 +-
 .../config/models/densenet_pretrained.py      |  3 +-
 src/mednet/config/models/densenet_rs.py       |  3 +-
 src/mednet/config/models/pasa.py              |  3 +-
 src/mednet/models/alexnet.py                  | 31 +++-------
 src/mednet/models/densenet.py                 | 31 +++-------
 src/mednet/models/model.py                    | 58 ++++++-------------
 src/mednet/models/pasa.py                     | 31 +++-------
 10 files changed, 47 insertions(+), 122 deletions(-)

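For reference, existing configurations migrate with a one-line change. A
minimal sketch, mirroring the pasa.py config updated below (import paths
assumed to match the existing config files):

    from torch.nn import BCEWithLogitsLoss
    from torch.optim import Adam

    from mednet.data.augmentations import ElasticDeformation
    from mednet.models.pasa import Pasa

    # Before: two loss parameters that were, in practice, always identical:
    #     train_loss_type=BCEWithLogitsLoss,
    #     validation_loss_type=BCEWithLogitsLoss,
    #
    # After: a single loss serves both splits; balance_losses() still
    # derives a separate pos_weight for training and for validation.
    model = Pasa(
        loss_type=BCEWithLogitsLoss,
        optimizer_type=Adam,
        optimizer_arguments=dict(lr=8e-5),
        augmentation_transforms=[ElasticDeformation(p=0.8)],
    )
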
diff --git a/src/mednet/config/models/alexnet.py b/src/mednet/config/models/alexnet.py
index 9a1cf542..7f281867 100644
--- a/src/mednet/config/models/alexnet.py
+++ b/src/mednet/config/models/alexnet.py
@@ -15,8 +15,7 @@ from mednet.data.augmentations import ElasticDeformation
 from mednet.models.alexnet import Alexnet
 
 model = Alexnet(
-    train_loss_type=BCEWithLogitsLoss,
-    validation_loss_type=BCEWithLogitsLoss,
+    loss_type=BCEWithLogitsLoss,
     optimizer_type=SGD,
     optimizer_arguments=dict(lr=0.01, momentum=0.1),
     augmentation_transforms=[ElasticDeformation(p=0.8)],
diff --git a/src/mednet/config/models/alexnet_pretrained.py b/src/mednet/config/models/alexnet_pretrained.py
index ea9198ab..a9356555 100644
--- a/src/mednet/config/models/alexnet_pretrained.py
+++ b/src/mednet/config/models/alexnet_pretrained.py
@@ -17,8 +17,7 @@ from mednet.data.augmentations import ElasticDeformation
 from mednet.models.alexnet import Alexnet
 
 model = Alexnet(
-    train_loss_type=BCEWithLogitsLoss,
-    validation_loss_type=BCEWithLogitsLoss,
+    loss_type=BCEWithLogitsLoss,
     optimizer_type=SGD,
     optimizer_arguments=dict(lr=0.01, momentum=0.1),
     augmentation_transforms=[ElasticDeformation(p=0.8)],
diff --git a/src/mednet/config/models/densenet.py b/src/mednet/config/models/densenet.py
index 7154bb74..9ee510ac 100644
--- a/src/mednet/config/models/densenet.py
+++ b/src/mednet/config/models/densenet.py
@@ -15,8 +15,7 @@ from mednet.data.augmentations import ElasticDeformation
 from mednet.models.densenet import Densenet
 
 model = Densenet(
-    train_loss_type=BCEWithLogitsLoss,
-    validation_loss_type=BCEWithLogitsLoss,
+    loss_type=BCEWithLogitsLoss,
     optimizer_type=Adam,
     optimizer_arguments=dict(lr=0.0001),
     augmentation_transforms=[ElasticDeformation(p=0.2)],
diff --git a/src/mednet/config/models/densenet_pretrained.py b/src/mednet/config/models/densenet_pretrained.py
index 1025a689..b7e2efcd 100644
--- a/src/mednet/config/models/densenet_pretrained.py
+++ b/src/mednet/config/models/densenet_pretrained.py
@@ -17,8 +17,7 @@ from mednet.data.augmentations import ElasticDeformation
 from mednet.models.densenet import Densenet
 
 model = Densenet(
-    train_loss_type=BCEWithLogitsLoss,
-    validation_loss_type=BCEWithLogitsLoss,
+    loss_type=BCEWithLogitsLoss,
     optimizer_type=Adam,
     optimizer_arguments=dict(lr=0.0001),
     augmentation_transforms=[ElasticDeformation(p=0.2)],
diff --git a/src/mednet/config/models/densenet_rs.py b/src/mednet/config/models/densenet_rs.py
index 23c1a064..813bb76c 100644
--- a/src/mednet/config/models/densenet_rs.py
+++ b/src/mednet/config/models/densenet_rs.py
@@ -16,8 +16,7 @@ from mednet.data.augmentations import ElasticDeformation
 from mednet.models.densenet import Densenet
 
 model = Densenet(
-    train_loss_type=BCEWithLogitsLoss,
-    validation_loss_type=BCEWithLogitsLoss,
+    loss_type=BCEWithLogitsLoss,
     optimizer_type=Adam,
     optimizer_arguments=dict(lr=0.0001),
     augmentation_transforms=[ElasticDeformation(p=0.2)],
diff --git a/src/mednet/config/models/pasa.py b/src/mednet/config/models/pasa.py
index 16457d1a..7787d10e 100644
--- a/src/mednet/config/models/pasa.py
+++ b/src/mednet/config/models/pasa.py
@@ -17,8 +17,7 @@ from mednet.data.augmentations import ElasticDeformation
 from mednet.models.pasa import Pasa
 
 model = Pasa(
-    train_loss_type=BCEWithLogitsLoss,
-    validation_loss_type=BCEWithLogitsLoss,
+    loss_type=BCEWithLogitsLoss,
     optimizer_type=Adam,
     optimizer_arguments=dict(lr=8e-5),
     augmentation_transforms=[ElasticDeformation(p=0.8)],
diff --git a/src/mednet/models/alexnet.py b/src/mednet/models/alexnet.py
index d4e2586d..75223c9a 100644
--- a/src/mednet/models/alexnet.py
+++ b/src/mednet/models/alexnet.py
@@ -27,26 +27,15 @@ class Alexnet(Model):
 
     Parameters
     ----------
-    train_loss_type
-        The loss to be used during the training.
+    loss_type
+        The loss to be used for both training and validation.
 
         .. warning::
 
            The loss should be set to always return batch averages (as opposed
            to the batch sum), as our logging system expects it so.
-    train_loss_arguments
-        Arguments to the training loss.
-    validation_loss_type
-        The loss to be used for validation (may be different from the training
-        loss).  If extra-validation sets are provided, the same loss will be
-        used throughout.
-
-        .. warning::
-
-           The loss should be set to always return batch averages (as opposed
-           to the batch sum), as our logging system expects it so.
-    validation_loss_arguments
-        Arguments to the validation loss.
+    loss_arguments
+        Arguments to the loss.
     optimizer_type
         The type of optimizer to use for training.
     optimizer_arguments
@@ -63,10 +52,8 @@ class Alexnet(Model):
 
     def __init__(
         self,
-        train_loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss,
-        train_loss_arguments: dict[str, typing.Any] = {},
-        validation_loss_type: torch.nn.Module | None = None,
-        validation_loss_arguments: dict[str, typing.Any] = {},
+        loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
+        loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
         augmentation_transforms: TransformSequence = [],
@@ -74,10 +61,8 @@ class Alexnet(Model):
         num_classes: int = 1,
     ):
         super().__init__(
-            train_loss_type,
-            train_loss_arguments,
-            validation_loss_type,
-            validation_loss_arguments,
+            loss_type,
+            loss_arguments,
             optimizer_type,
             optimizer_arguments,
             augmentation_transforms,
diff --git a/src/mednet/models/densenet.py b/src/mednet/models/densenet.py
index c90ab10c..76df1ed6 100644
--- a/src/mednet/models/densenet.py
+++ b/src/mednet/models/densenet.py
@@ -25,26 +25,15 @@ class Densenet(Model):
 
     Parameters
     ----------
-    train_loss_type
-        The loss to be used during the training.
+    loss_type
+        The loss to be used for both training and validation.
 
         .. warning::
 
            The loss should be set to always return batch averages (as opposed
            to the batch sum), as our logging system expects it so.
-    train_loss_arguments
-        Arguments to the training loss.
-    validation_loss_type
-        The loss to be used for validation (may be different from the training
-        loss).  If extra-validation sets are provided, the same loss will be
-        used throughout.
-
-        .. warning::
-
-           The loss should be set to always return batch averages (as opposed
-           to the batch sum), as our logging system expects it so.
-    validation_loss_arguments
-        Arguments to the validation loss.
+    loss_arguments
+        Arguments to the loss.
     optimizer_type
         The type of optimizer to use for training.
     optimizer_arguments
@@ -63,10 +52,8 @@ class Densenet(Model):
 
     def __init__(
         self,
-        train_loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss,
-        train_loss_arguments: dict[str, typing.Any] = {},
-        validation_loss_type: torch.nn.Module | None = None,
-        validation_loss_arguments: dict[str, typing.Any] = {},
+        loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
+        loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
         augmentation_transforms: TransformSequence = [],
@@ -75,10 +62,8 @@ class Densenet(Model):
         num_classes: int = 1,
     ):
         super().__init__(
-            train_loss_type,
-            train_loss_arguments,
-            validation_loss_type,
-            validation_loss_arguments,
+            loss_type,
+            loss_arguments,
             optimizer_type,
             optimizer_arguments,
             augmentation_transforms,
diff --git a/src/mednet/models/model.py b/src/mednet/models/model.py
index 27e5c32d..57c1b618 100644
--- a/src/mednet/models/model.py
+++ b/src/mednet/models/model.py
@@ -24,26 +24,15 @@ class Model(pl.LightningModule):
 
     Parameters
     ----------
-    train_loss_type
-        The loss to be used during the training.
+    loss_type
+        The loss to be used for both training and validation.
 
         .. warning::
 
            The loss should be set to always return batch averages (as opposed
            to the batch sum), as our logging system expects it so.
-    train_loss_arguments
-        Arguments to the training loss.
-    validation_loss_type
-        The loss to be used for validation (may be different from the training
-        loss).  If extra-validation sets are provided, the same loss will be
-        used throughout.
-
-        .. warning::
-
-           The loss should be set to always return batch averages (as opposed
-           to the batch sum), as our logging system expects it so.
-    validation_loss_arguments
-        Arguments to the validation loss.
+    loss_arguments
+        Arguments to the loss.
     optimizer_type
         The type of optimizer to use for training.
     optimizer_arguments
@@ -57,10 +46,8 @@ class Model(pl.LightningModule):
 
     def __init__(
         self,
-        train_loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss,
-        train_loss_arguments: dict[str, typing.Any] = {},
-        validation_loss_type: torch.nn.Module | None = None,
-        validation_loss_arguments: dict[str, typing.Any] = {},
+        loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
+        loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
         augmentation_transforms: TransformSequence = [],
@@ -73,13 +60,13 @@ class Model(pl.LightningModule):
 
         self.model_transforms: TransformSequence = []
 
-        self._train_loss_type = train_loss_type
-        self._train_loss_arguments = train_loss_arguments
+        self._loss_type = loss_type
+
         self._train_loss = None
+        self._train_loss_arguments = dict(loss_arguments)  # copy: balance_losses() sets a per-split pos_weight
 
-        self._validation_loss_type = validation_loss_type
-        self._validation_loss_arguments = validation_loss_arguments
         self.validation_loss = None
+        self._validation_loss_arguments = dict(loss_arguments)  # copy: balance_losses() sets a per-split pos_weight
 
         self._optimizer_type = optimizer_type
         self._optimizer_arguments = optimizer_arguments
@@ -148,8 +135,8 @@ class Model(pl.LightningModule):
         raise NotImplementedError
 
     def configure_losses(self):
-        self._train_loss = self._train_loss_type(**self._train_loss_arguments)
-        self._validation_loss = self._validation_loss_type(
+        self._train_loss = self._loss_type(**self._train_loss_arguments)
+        self._validation_loss = self._loss_type(
             **self._validation_loss_arguments
         )
 
@@ -160,7 +147,7 @@ class Model(pl.LightningModule):
         )
 
     def balance_losses(self, datamodule) -> None:
-        """Balance the loss based on the distribution of targets in the datamodule, if the loss function supports it.
+        """Balance the loss based on the distribution of targets in the datamodule, if the loss supports it (contains a 'pos_weight' attribute).
 
         Parameters
         ----------
@@ -168,29 +155,18 @@ class Model(pl.LightningModule):
             Instance of a datamodule.
         """
 
-        logger.info(
-            f"Balancing training loss function {self._train_loss_type}."
-        )
         try:
-            getattr(self._train_loss_type(), "pos_weight")
+            getattr(self._loss_type(), "pos_weight")
         except AttributeError:
             logger.warning(
-                "Training loss does not posess a 'pos_weight' attribute and will not be balanced."
+                f"Loss {self._loss_type} does not posess a 'pos_weight' attribute and will not be balanced."
             )
         else:
+            logger.info(f"Balancing training loss {self._loss_type}.")
             train_weights = get_positive_weights(datamodule.train_dataloader())
             self._train_loss_arguments["pos_weight"] = train_weights
 
-        logger.info(
-            f"Balancing validation loss function {self._validation_loss_type}."
-        )
-        try:
-            getattr(self._validation_loss_type(), "pos_weight")
-        except AttributeError:
-            logger.warning(
-                "Validation loss does not posess a 'pos_weight' attribute and will not be balanced."
-            )
-        else:
+            logger.info(f"Balancing validation loss {self._loss_type}.")
             validation_weights = get_positive_weights(
                 datamodule.val_dataloader()["validation"]
             )
diff --git a/src/mednet/models/pasa.py b/src/mednet/models/pasa.py
index 5e5cd743..e9147683 100644
--- a/src/mednet/models/pasa.py
+++ b/src/mednet/models/pasa.py
@@ -30,26 +30,15 @@ class Pasa(Model):
 
     Parameters
     ----------
-    train_loss_type
-        The loss to be used during the training.
+    loss_type
+        The loss to be used for both training and validation.
 
         .. warning::
 
            The loss should be set to always return batch averages (as opposed
            to the batch sum), as our logging system expects it so.
-    train_loss_arguments
-        Arguments to the training loss.
-    validation_loss_type
-        The loss to be used for validation (may be different from the training
-        loss).  If extra-validation sets are provided, the same loss will be
-        used throughout.
-
-        .. warning::
-
-           The loss should be set to always return batch averages (as opposed
-           to the batch sum), as our logging system expects it so.
-    validation_loss_arguments
-        Arguments to the validation loss.
+    loss_arguments
+        Arguments to the loss.
     optimizer_type
         The type of optimizer to use for training.
     optimizer_arguments
@@ -63,20 +52,16 @@ class Pasa(Model):
 
     def __init__(
         self,
-        train_loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss,
-        train_loss_arguments: dict[str, typing.Any] = {},
-        validation_loss_type: torch.nn.Module | None = None,
-        validation_loss_arguments: dict[str, typing.Any] = {},
+        loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
+        loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
         augmentation_transforms: TransformSequence = [],
         num_classes: int = 1,
     ):
         super().__init__(
-            train_loss_type,
-            train_loss_arguments,
-            validation_loss_type,
-            validation_loss_arguments,
+            loss_type,
+            loss_arguments,
             optimizer_type,
             optimizer_arguments,
             augmentation_transforms,
-- 
GitLab