diff --git a/src/mednet/engine/callbacks.py b/src/mednet/engine/callbacks.py
index 501f2c254b2d511993adadab2735f12d44fe73c1..b1b16f65ddca03d1d4fdd5a980999fca6aa18349 100644
--- a/src/mednet/engine/callbacks.py
+++ b/src/mednet/engine/callbacks.py
@@ -362,11 +362,20 @@ class LoggingCallback(lightning.pytorch.Callback):
             out which dataset was used for this validation epoch.
         """
 
+        # Keep the plain "loss/validation" key for the first validation
+        # dataloader; suffix any extra dataloaders with their index.
+        if dataloader_idx == 0:
+            key = "loss/validation"
+        else:
+            key = f"loss/validation-{dataloader_idx}"
+
         pl_module.log(
-            "loss/validation",
+            key,
             outputs.item(),
             prog_bar=False,
             on_step=False,
             on_epoch=True,
             batch_size=batch[0].shape[0],
+            # log the key verbatim (no "/dataloader_idx_N" suffix from Lightning)
+            add_dataloader_idx=False,
         )
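
For context, a minimal standalone Python sketch of what this change produces
(the validation_loss_key helper and the ModelCheckpoint wiring below are
illustrative assumptions for review purposes, not code from mednet):

    # Sketch only -- not part of the patch.
    from lightning.pytorch.callbacks import ModelCheckpoint

    def validation_loss_key(dataloader_idx: int) -> str:
        # mirrors the naming rule added to LoggingCallback above
        if dataloader_idx == 0:
            return "loss/validation"
        return f"loss/validation-{dataloader_idx}"

    assert validation_loss_key(0) == "loss/validation"
    assert validation_loss_key(2) == "loss/validation-2"

    # With add_dataloader_idx=False the logged keys stay exactly as above, so
    # a checkpointer can always monitor the first dataloader's loss, no matter
    # how many validation dataloaders are configured (see trainer.py below):
    checkpoint = ModelCheckpoint(monitor="loss/validation", mode="min")
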
diff --git a/src/mednet/engine/trainer.py b/src/mednet/engine/trainer.py
index 0993354f8dc2d84b30f0ab18f95751621a6076fb..e4a2f5eac865ae870207657219b120b21cc9be64 100644
--- a/src/mednet/engine/trainer.py
+++ b/src/mednet/engine/trainer.py
@@ -94,8 +94,6 @@ def run(
     )
 
     monitor_key = "loss/validation"
-    if len(datamodule.val_dataset_keys()) > 1:
-        monitor_key = "loss/validation/dataloader_idx_0"
 
     # This checkpointer will operate at the end of every validation epoch
     # (which happens at each checkpoint period), it will then save the lowest