diff --git a/src/mednet/libs/common/models/model.py b/src/mednet/libs/common/models/model.py index e1b264bee83eec12c46f6286a2c2add78e15f69e..d01cc14a4c7e58d10cb8adb5fe875c9198761d9e 100644 --- a/src/mednet/libs/common/models/model.py +++ b/src/mednet/libs/common/models/model.py @@ -16,7 +16,7 @@ from medbase.data.typing import TransformSequence from .loss_weights import get_positive_weights from .typing import Checkpoint -logger = logging.getLogger(__name__) +logger = logging.getLogger("mednet") class Model(pl.LightningModule): @@ -216,7 +216,15 @@ class Model(pl.LightningModule): self._train_loss_arguments["pos_weight"] = train_weights logger.info(f"Balancing validation loss {self._loss_type}.") - validation_weights = get_positive_weights( - datamodule.val_dataloader()["validation"] - ) + if "validation" in datamodule.val_dataloader().keys(): + validation_weights = get_positive_weights( + datamodule.val_dataloader()["validation"] + ) + else: + logger.warning( + "Datamodule does not contain a validation dataloader. The training dataloader will be used instead." + ) + validation_weights = get_positive_weights( + datamodule.train_dataloader() + ) self._validation_loss_arguments["pos_weight"] = validation_weights