From 1c03b05b0f528673e730b816e45c21b5d1bb1425 Mon Sep 17 00:00:00 2001 From: Andre Anjos <andre.dos.anjos@gmail.com> Date: Tue, 2 Jul 2024 10:31:20 +0200 Subject: [PATCH] [libs.common.models.model] Do not assume pos_weight is a scalar --- src/mednet/libs/common/models/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/mednet/libs/common/models/model.py b/src/mednet/libs/common/models/model.py index cad52097..cc4bf033 100644 --- a/src/mednet/libs/common/models/model.py +++ b/src/mednet/libs/common/models/model.py @@ -228,7 +228,7 @@ class Model(pl.LightningModule): self._train_loss_arguments["pos_weight"] = train_weights logger.info( f"Balanced training loss {self._loss_type}: " - f"`pos_weight={train_weights.item():.3f}`." + f"`pos_weight={train_weights}`." ) if "validation" in datamodule.val_dataloader().keys(): @@ -245,7 +245,7 @@ class Model(pl.LightningModule): self._validation_loss_arguments["pos_weight"] = validation_weights logger.info( f"Balanced validation loss {self._loss_type}: " - f"`pos_weight={validation_weights.item():.3f}`." + f"`pos_weight={validation_weights}`." ) # re-instantiates losses for train and validation -- GitLab