diff --git a/src/ptbench/engine/callbacks.py b/src/ptbench/engine/callbacks.py index 580cc26fd24efc6ebb98fbcbb02d51348c6c11ec..952881525f16d7f9bf663e79fa500de538dec64f 100644 --- a/src/ptbench/engine/callbacks.py +++ b/src/ptbench/engine/callbacks.py @@ -95,7 +95,7 @@ class LoggingCallback(Callback): self.log("eta", eta_seconds) self.log("loss", numpy.average(self.training_loss)) self.log("learning_rate", pl_module.optimizer_configs["lr"]) - self.log("validation_loss", numpy.sum(self.validation_loss)) + self.log("validation_loss", numpy.average(self.validation_loss)) if len(self.extra_validation_loss) > 0: for ( @@ -103,7 +103,8 @@ class LoggingCallback(Callback): extra_valid_loss_values, ) in self.extra_validation_loss.items: self.log( - extra_valid_loss_key, numpy.sum(extra_valid_loss_values) + extra_valid_loss_key, + numpy.average(extra_valid_loss_values), ) queue_retries = 0