From 87285a8ea21ef1ba6d0c71ff0f753e48ef5870a7 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Wed, 5 Jul 2023 11:42:21 +0200
Subject: [PATCH] Average validation loss instead of adding it

---
 src/ptbench/engine/callbacks.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/ptbench/engine/callbacks.py b/src/ptbench/engine/callbacks.py
index 580cc26f..95288152 100644
--- a/src/ptbench/engine/callbacks.py
+++ b/src/ptbench/engine/callbacks.py
@@ -95,7 +95,7 @@ class LoggingCallback(Callback):
             self.log("eta", eta_seconds)
             self.log("loss", numpy.average(self.training_loss))
             self.log("learning_rate", pl_module.optimizer_configs["lr"])
-            self.log("validation_loss", numpy.sum(self.validation_loss))
+            self.log("validation_loss", numpy.average(self.validation_loss))
 
             if len(self.extra_validation_loss) > 0:
                 for (
@@ -103,7 +103,8 @@ class LoggingCallback(Callback):
                     extra_valid_loss_values,
                 ) in self.extra_validation_loss.items:
                     self.log(
-                        extra_valid_loss_key, numpy.sum(extra_valid_loss_values)
+                        extra_valid_loss_key,
+                        numpy.average(extra_valid_loss_values),
                     )
 
         queue_retries = 0
-- 
GitLab