From 222a376213777262323c77f229f3a9561b8aa836 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Tue, 30 Apr 2024 10:37:09 +0200
Subject: [PATCH] [train] Add support for multiple validation dataloaders

During validation logging, Lightning appends "/dataloader_idx_n" to the
key we define whenever multiple validation dataloaders are used. Log a
single "loss/validation" key, monitor "loss/validation/dataloader_idx_0"
for checkpointing in the multi-dataloader case, and keep one validation
loss per dataloader so that each can be balanced for its own dataset.
---
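Notes (illustration only, not part of the commit): with Lightning's
default add_dataloader_idx=True, a single log() call is recorded under
per-dataloader keys when several validation dataloaders are configured.
A minimal sketch, assuming stock lightning.pytorch behaviour (the class
and loss value below are placeholders):

    import lightning.pytorch as pl
    import torch

    class Sketch(pl.LightningModule):
        def validation_step(self, batch, batch_idx, dataloader_idx=0):
            loss = torch.tensor(0.0)  # stand-in for a real loss value
            # With one validation dataloader this is recorded as
            # "loss/validation"; with several, Lightning renames it to
            # "loss/validation/dataloader_idx_0", "..._1", and so on.
            self.log("loss/validation", loss)
            return loss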
 src/mednet/data/datamodule.py  |  8 ++++----
 src/mednet/engine/callbacks.py |  7 +------
 src/mednet/engine/trainer.py   |  6 +++++-
 src/mednet/models/alexnet.py   |  3 +--
 src/mednet/models/densenet.py  |  2 +-
 src/mednet/models/model.py     | 33 +++++++++++++++++++++++++--------
 src/mednet/models/pasa.py      |  3 +--
 7 files changed, 38 insertions(+), 24 deletions(-)

diff --git a/src/mednet/data/datamodule.py b/src/mednet/data/datamodule.py
index 2de54fdf..6c7d759f 100644
--- a/src/mednet/data/datamodule.py
+++ b/src/mednet/data/datamodule.py
@@ -758,7 +758,7 @@ class ConcatDataModule(lightning.LightningDataModule):
         else:
             self._datasets[name] = _ConcatDataset(datasets)
 
-    def _val_dataset_keys(self) -> list[str]:
+    def val_dataset_keys(self) -> list[str]:
         """Return list of validation dataset names.
 
         Returns
@@ -796,11 +796,11 @@ class ConcatDataModule(lightning.LightningDataModule):
         """
 
         if stage == "fit":
-            for k in ["train"] + self._val_dataset_keys():
+            for k in ["train"] + self.val_dataset_keys():
                 self._setup_dataset(k)
 
         elif stage == "validate":
-            for k in self._val_dataset_keys():
+            for k in self.val_dataset_keys():
                 self._setup_dataset(k)
 
         elif stage == "test":
@@ -889,7 +889,7 @@ class ConcatDataModule(lightning.LightningDataModule):
                 self._datasets[k],
                 **validation_loader_opts,
             )
-            for k in self._val_dataset_keys()
+            for k in self.val_dataset_keys()
         }
 
     def test_dataloader(self) -> dict[str, DataLoader]:
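For reference, a sketch (dataset names are hypothetical, not from this
repository) of how the dict returned by val_dataloader() maps onto
Lightning's dataloader indices; the dict is iterated in insertion
order, so the first key becomes dataloader_idx 0:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    def val_dataloader() -> dict[str, DataLoader]:
        data = TensorDataset(torch.zeros(4, 1), torch.zeros(4))
        return {
            "validation": DataLoader(data),        # dataloader_idx 0
            "validation-extra": DataLoader(data),  # dataloader_idx 1
        }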
diff --git a/src/mednet/engine/callbacks.py b/src/mednet/engine/callbacks.py
index 4c19ac55..501f2c25 100644
--- a/src/mednet/engine/callbacks.py
+++ b/src/mednet/engine/callbacks.py
@@ -362,13 +362,8 @@ class LoggingCallback(lightning.pytorch.Callback):
             out which dataset was used for this validation epoch.
         """
 
-        if dataloader_idx == 0:
-            key = "loss/validation"
-        else:
-            key = f"loss/validation-{dataloader_idx}"
-
         pl_module.log(
-            key,
+            "loss/validation",
             outputs.item(),
             prog_bar=False,
             on_step=False,
diff --git a/src/mednet/engine/trainer.py b/src/mednet/engine/trainer.py
index 23df024b..d3a345b7 100644
--- a/src/mednet/engine/trainer.py
+++ b/src/mednet/engine/trainer.py
@@ -91,6 +91,10 @@ def run(
         main_pid=os.getpid(),
     )
 
+    monitor_key = "loss/validation"
+    if len(datamodule.val_dataset_keys()) > 1:
+        monitor_key = "loss/validation/dataloader_idx_0"
+
     # This checkpointer will operate at the end of every validation epoch
     # (which happens at each checkpoint period), it will then save the lowest
     # validation loss model observed.  It will also save the last trained model
@@ -98,7 +102,7 @@ def run(
         dirpath=output_folder,
         filename=CHECKPOINT_ALIASES["best"],
         save_last=True,  # will (re)create the last trained model, at every iteration
-        monitor="loss/validation",
+        monitor=monitor_key,
         mode="min",
         save_on_train_epoch_end=True,
         every_n_epochs=validation_period,  # frequency at which it checks the "monitor"
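The monitor-key switch in isolation (sketch only: n_val stands for
len(datamodule.val_dataset_keys()); the function name and signature are
hypothetical):

    from lightning.pytorch.callbacks import ModelCheckpoint

    def best_checkpoint(output_folder: str, validation_period: int,
                        n_val: int) -> ModelCheckpoint:
        # With several validation dataloaders the logged key gains the
        # "/dataloader_idx_0" suffix, so the monitored key must match.
        monitor_key = "loss/validation"
        if n_val > 1:
            monitor_key = "loss/validation/dataloader_idx_0"
        return ModelCheckpoint(
            dirpath=output_folder,
            monitor=monitor_key,
            mode="min",
            save_last=True,
            every_n_epochs=validation_period,
        )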
diff --git a/src/mednet/models/alexnet.py b/src/mednet/models/alexnet.py
index 22b98baa..eada55b8 100644
--- a/src/mednet/models/alexnet.py
+++ b/src/mednet/models/alexnet.py
@@ -166,8 +166,7 @@ class Alexnet(Model):
 
         # data forwarding on the existing network
         outputs = self(images)
-
-        return self._validation_loss(outputs, labels.float())
+        return self._validation_loss[dataloader_idx](outputs, labels.float())
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
         outputs = self(batch[0])
diff --git a/src/mednet/models/densenet.py b/src/mednet/models/densenet.py
index fcdb9f95..15da7f4e 100644
--- a/src/mednet/models/densenet.py
+++ b/src/mednet/models/densenet.py
@@ -164,7 +164,7 @@ class Densenet(Model):
         # data forwarding on the existing network
         outputs = self(images)
 
-        return self._validation_loss(outputs, labels.float())
+        return self._validation_loss[dataloader_idx](outputs, labels.float())
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
         outputs = self(batch[0])
diff --git a/src/mednet/models/model.py b/src/mednet/models/model.py
index a0b3701e..155ff19c 100644
--- a/src/mednet/models/model.py
+++ b/src/mednet/models/model.py
@@ -68,9 +68,10 @@ class Model(pl.LightningModule):
         self.model_transforms: TransformSequence = []
 
         self._train_loss = train_loss
-        self._validation_loss = (
-            validation_loss if validation_loss is not None else train_loss
-        )
+        self._validation_loss = [
+            (validation_loss if validation_loss is not None else train_loss)
+        ]
+
         self._optimizer_type = optimizer_type
         self._optimizer_arguments = optimizer_arguments
 
@@ -163,16 +164,32 @@ class Model(pl.LightningModule):
             setattr(self._train_loss, "pos_weight", train_weights)
 
         logger.info(
-            f"Balancing validation loss function {self._validation_loss}."
+            f"Balancing validation loss function {self._validation_loss[0]}."
         )
         try:
-            getattr(self._validation_loss, "pos_weight")
+            getattr(self._validation_loss[0], "pos_weight")
         except AttributeError:
             logger.warning(
                 "Validation loss does not posess a 'pos_weight' attribute and will not be balanced."
             )
         else:
-            validation_weights = _get_label_weights(
-                datamodule.val_dataloader()["validation"]
+            # If multiple validation DataLoaders are used, each needs its own
+            # loss, balanced for that specific DataLoader.
+
+            new_validation_losses = []
+            loss_class = self._validation_loss[0].__class__
+
+            datamodule_validation_keys = datamodule.val_dataset_keys()
+            logger.info(
+                f"Found {len(datamodule_validation_keys)} keys in the validation datamodule. A balanced loss will be created for each key."
             )
-            setattr(self._validation_loss, "pos_weight", validation_weights)
+
+            for val_dataset_key in datamodule_validation_keys:
+                validation_weights = _get_label_weights(
+                    datamodule.val_dataloader()[val_dataset_key]
+                )
+                new_validation_losses.append(
+                    loss_class(pos_weight=validation_weights)
+                )
+
+            self._validation_loss = new_validation_losses
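A self-contained sketch of the balancing idea above (illustrative: the
function name and the pos_weight formula stand in for this repository's
_get_label_weights, and batches are assumed to be (images, labels)
pairs):

    import torch

    def make_balanced_losses(datamodule) -> list[torch.nn.Module]:
        # One loss per validation dataloader, each re-balanced for the
        # label distribution of its own dataset.
        losses = []
        loaders = datamodule.val_dataloader()
        for key in datamodule.val_dataset_keys():
            labels = torch.cat([batch[1] for batch in loaders[key]])
            # Classic BCE re-balancing: weight positives by the
            # negative-to-positive ratio.
            pos_weight = (labels == 0).sum() / labels.sum().clamp(min=1)
            losses.append(
                torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
            )
        return losses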
diff --git a/src/mednet/models/pasa.py b/src/mednet/models/pasa.py
index 389eac8c..54032eda 100644
--- a/src/mednet/models/pasa.py
+++ b/src/mednet/models/pasa.py
@@ -233,8 +233,7 @@ class Pasa(Model):
 
         # data forwarding on the existing network
         outputs = self(images)
-
-        return self._validation_loss(outputs, labels.float())
+        return self._validation_loss[dataloader_idx](outputs, labels.float())
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
         outputs = self(batch[0])
-- 
GitLab