From 24c7fb6809c53015f7d2874548ff7007fe745cb7 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Mon, 13 May 2024 12:07:32 +0200
Subject: [PATCH] [model] Balance loss on train DataModule if no validation one

---
 src/mednet/libs/common/models/model.py | 16 ++++++++++++----
 1 file changed, 12 insertions(+), 4 deletions(-)

diff --git a/src/mednet/libs/common/models/model.py b/src/mednet/libs/common/models/model.py
index e1b264be..d01cc14a 100644
--- a/src/mednet/libs/common/models/model.py
+++ b/src/mednet/libs/common/models/model.py
@@ -16,7 +16,7 @@ from medbase.data.typing import TransformSequence
 from .loss_weights import get_positive_weights
 from .typing import Checkpoint
 
-logger = logging.getLogger(__name__)
+logger = logging.getLogger("mednet")
 
 
 class Model(pl.LightningModule):
@@ -216,7 +216,15 @@ class Model(pl.LightningModule):
             self._train_loss_arguments["pos_weight"] = train_weights
 
             logger.info(f"Balancing validation loss {self._loss_type}.")
-            validation_weights = get_positive_weights(
-                datamodule.val_dataloader()["validation"]
-            )
+            if "validation" in datamodule.val_dataloader().keys():
+                validation_weights = get_positive_weights(
+                    datamodule.val_dataloader()["validation"]
+                )
+            else:
+                logger.warning(
+                    "Datamodule does not contain a validation dataloader. The training dataloader will be used instead."
+                )
+                validation_weights = get_positive_weights(
+                    datamodule.train_dataloader()
+                )
             self._validation_loss_arguments["pos_weight"] = validation_weights
-- 
GitLab