From 24c7fb6809c53015f7d2874548ff7007fe745cb7 Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Mon, 13 May 2024 12:07:32 +0200 Subject: [PATCH] [model] Balance loss on train DataModule if no validation one --- src/mednet/libs/common/models/model.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/mednet/libs/common/models/model.py b/src/mednet/libs/common/models/model.py index e1b264be..d01cc14a 100644 --- a/src/mednet/libs/common/models/model.py +++ b/src/mednet/libs/common/models/model.py @@ -16,7 +16,7 @@ from medbase.data.typing import TransformSequence from .loss_weights import get_positive_weights from .typing import Checkpoint -logger = logging.getLogger(__name__) +logger = logging.getLogger("mednet") class Model(pl.LightningModule): @@ -216,7 +216,15 @@ class Model(pl.LightningModule): self._train_loss_arguments["pos_weight"] = train_weights logger.info(f"Balancing validation loss {self._loss_type}.") - validation_weights = get_positive_weights( - datamodule.val_dataloader()["validation"] - ) + if "validation" in datamodule.val_dataloader().keys(): + validation_weights = get_positive_weights( + datamodule.val_dataloader()["validation"] + ) + else: + logger.warning( + "Datamodule does not contain a validation dataloader. The training dataloader will be used instead." + ) + validation_weights = get_positive_weights( + datamodule.train_dataloader() + ) self._validation_loss_arguments["pos_weight"] = validation_weights -- GitLab