Skip to content
Snippets Groups Projects
Commit 24c7fb68 authored by Daniel CARRON's avatar Daniel CARRON :b: Committed by André Anjos
Browse files

[model] Balance loss on train DataModule if no validation one

parent 7aaec166
No related branches found
No related tags found
1 merge request!46Create common library
...@@ -16,7 +16,7 @@ from medbase.data.typing import TransformSequence ...@@ -16,7 +16,7 @@ from medbase.data.typing import TransformSequence
from .loss_weights import get_positive_weights from .loss_weights import get_positive_weights
from .typing import Checkpoint from .typing import Checkpoint
logger = logging.getLogger(__name__) logger = logging.getLogger("mednet")
class Model(pl.LightningModule): class Model(pl.LightningModule):
...@@ -216,7 +216,15 @@ class Model(pl.LightningModule): ...@@ -216,7 +216,15 @@ class Model(pl.LightningModule):
self._train_loss_arguments["pos_weight"] = train_weights self._train_loss_arguments["pos_weight"] = train_weights
logger.info(f"Balancing validation loss {self._loss_type}.") logger.info(f"Balancing validation loss {self._loss_type}.")
validation_weights = get_positive_weights( if "validation" in datamodule.val_dataloader().keys():
datamodule.val_dataloader()["validation"] validation_weights = get_positive_weights(
) datamodule.val_dataloader()["validation"]
)
else:
logger.warning(
"Datamodule does not contain a validation dataloader. The training dataloader will be used instead."
)
validation_weights = get_positive_weights(
datamodule.train_dataloader()
)
self._validation_loss_arguments["pos_weight"] = validation_weights self._validation_loss_arguments["pos_weight"] = validation_weights
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment