Skip to content
Snippets Groups Projects
Commit ae55d184 authored by Daniel CARRON's avatar Daniel CARRON :b:
Browse files

[logging] Prevent PL from adding dataloader index to logged metrics

parent 8769c327
No related branches found
No related tags found
1 merge request: !38 "Replace sampler balancing by loss balancing"
@@ -362,11 +362,17 @@ class LoggingCallback(lightning.pytorch.Callback):
         out which dataset was used for this validation epoch.
         """
+        if dataloader_idx == 0:
+            key = "loss/validation"
+        else:
+            key = f"loss/validation-{dataloader_idx}"
         pl_module.log(
-            "loss/validation",
+            key,
             outputs.item(),
             prog_bar=False,
             on_step=False,
             on_epoch=True,
             batch_size=batch[0].shape[0],
+            add_dataloader_idx=False,
         )
@@ -94,8 +94,6 @@ def run(
     )
     monitor_key = "loss/validation"
-    if len(datamodule.val_dataset_keys()) > 1:
-        monitor_key = "loss/validation/dataloader_idx_0"
     # This checkpointer will operate at the end of every validation epoch
     # (which happens at each checkpoint period), it will then save the lowest
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment