From ae55d1841e492f8056a6db887f33b3bb676c5f07 Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Thu, 2 May 2024 13:00:11 +0200 Subject: [PATCH] [logging] Prevent PL from adding dataloader index to logged metrics --- src/mednet/engine/callbacks.py | 8 +++++++- src/mednet/engine/trainer.py | 2 -- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/mednet/engine/callbacks.py b/src/mednet/engine/callbacks.py index 501f2c25..b1b16f65 100644 --- a/src/mednet/engine/callbacks.py +++ b/src/mednet/engine/callbacks.py @@ -362,11 +362,17 @@ class LoggingCallback(lightning.pytorch.Callback): out which dataset was used for this validation epoch. """ + if dataloader_idx == 0: + key = "loss/validation" + else: + key = f"loss/validation-{dataloader_idx}" + pl_module.log( - "loss/validation", + key, outputs.item(), prog_bar=False, on_step=False, on_epoch=True, batch_size=batch[0].shape[0], + add_dataloader_idx=False, ) diff --git a/src/mednet/engine/trainer.py b/src/mednet/engine/trainer.py index 0993354f..e4a2f5ea 100644 --- a/src/mednet/engine/trainer.py +++ b/src/mednet/engine/trainer.py @@ -94,8 +94,6 @@ def run( ) monitor_key = "loss/validation" - if len(datamodule.val_dataset_keys()) > 1: - monitor_key = "loss/validation/dataloader_idx_0" # This checkpointer will operate at the end of every validation epoch # (which happens at each checkpoint period), it will then save the lowest -- GitLab