diff --git a/src/mednet/engine/callbacks.py b/src/mednet/engine/callbacks.py index 501f2c254b2d511993adadab2735f12d44fe73c1..b1b16f65ddca03d1d4fdd5a980999fca6aa18349 100644 --- a/src/mednet/engine/callbacks.py +++ b/src/mednet/engine/callbacks.py @@ -362,11 +362,17 @@ class LoggingCallback(lightning.pytorch.Callback): out which dataset was used for this validation epoch. """ + if dataloader_idx == 0: + key = "loss/validation" + else: + key = f"loss/validation-{dataloader_idx}" + pl_module.log( - "loss/validation", + key, outputs.item(), prog_bar=False, on_step=False, on_epoch=True, batch_size=batch[0].shape[0], + add_dataloader_idx=False, ) diff --git a/src/mednet/engine/trainer.py b/src/mednet/engine/trainer.py index 0993354f8dc2d84b30f0ab18f95751621a6076fb..e4a2f5eac865ae870207657219b120b21cc9be64 100644 --- a/src/mednet/engine/trainer.py +++ b/src/mednet/engine/trainer.py @@ -94,8 +94,6 @@ def run( ) monitor_key = "loss/validation" - if len(datamodule.val_dataset_keys()) > 1: - monitor_key = "loss/validation/dataloader_idx_0" # This checkpointer will operate at the end of every validation epoch # (which happens at each checkpoint period), it will then save the lowest