From 73f436fd33efddc257ea83b007e5f495108aeae2 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Fri, 15 Dec 2023 13:35:07 +0100
Subject: [PATCH] [engine.callbacks] Refactor callbacks to delegate most work
 to lightning

---
 src/ptbench/engine/callbacks.py | 135 ++++++++++++--------------------
 src/ptbench/engine/trainer.py   |   2 +-
 2 files changed, 51 insertions(+), 86 deletions(-)

diff --git a/src/ptbench/engine/callbacks.py b/src/ptbench/engine/callbacks.py
index 8669718e..6966f6fe 100644
--- a/src/ptbench/engine/callbacks.py
+++ b/src/ptbench/engine/callbacks.py
@@ -18,8 +18,12 @@ logger = logging.getLogger(__name__)
 class LoggingCallback(lightning.pytorch.Callback):
     """Callback to log various training metrics and device information.
 
-    It ensures CSVLogger logs training and evaluation metrics on the same line
-    Note that a CSVLogger only accepts numerical values, and not strings.
+    Rationale:
+
+    1. Losses are logged at the end of every batch; accumulation and reduction
+       are delegated to the Lightning framework.
+    2. Everything else is done at the end of a training or validation epoch and
+       mostly concerns runtime metrics such as memory and CPU/GPU utilisation.
 
 
     Parameters
@@ -33,13 +37,6 @@ class LoggingCallback(lightning.pytorch.Callback):
     def __init__(self, resource_monitor: ResourceMonitor):
         super().__init__()
 
-        # lists of number of samples/batch and average losses
-        # - we use this later to compute overall epoch losses
-        self._training_epoch_loss: tuple[list[int], list[float]] = ([], [])
-        self._validation_epoch_loss: dict[
-            int, tuple[list[int], list[float]]
-        ] = {}
-
         # timers
         self._start_training_time = 0.0
         self._start_training_epoch_time = 0.0
@@ -101,7 +98,6 @@ class LoggingCallback(lightning.pytorch.Callback):
             The lightning module that is being trained
         """
         self._start_training_epoch_time = time.time()
-        self._training_epoch_loss = ([], [])
 
     def on_train_epoch_end(
         self,
@@ -132,17 +128,8 @@ class LoggingCallback(lightning.pytorch.Callback):
         # evaluates this training epoch total time, and log it
         epoch_time = time.time() - self._start_training_epoch_time
 
-        # Compute overall training loss considering batches and sizes
-        # We disconsider accumulate_grad_batches and assume they were all of
-        # the same size.  This way, the average of averages is the overall
-        # average.
-        self._to_log["loss/train"] = torch.mean(
-            torch.tensor(self._training_epoch_loss[0])
-            * torch.tensor(self._training_epoch_loss[1])
-        ).item()
-
         self._to_log["epoch-duration-seconds/train"] = epoch_time
-        self._to_log["learning-rate"] = pl_module.optimizers().defaults["lr"]
+        self._to_log["learning-rate"] = pl_module.optimizers().defaults["lr"]  # type: ignore
 
         metrics = self._resource_monitor.data
         if metrics is not None:
@@ -155,9 +142,23 @@ class LoggingCallback(lightning.pytorch.Callback):
                 "missing."
             )
 
-        # if no validation dataloaders, complete cycle by the end of the
-        # training epoch, by logging all values to the logger
-        self.on_cycle_end(trainer, pl_module)
+        overall_cycle_time = time.time() - self._start_training_epoch_time
+        self._to_log["cycle-time-seconds/train"] = overall_cycle_time
+        self._to_log["total-execution-time-seconds"] = (
+            time.time() - self._start_training_time
+        )
+        self._to_log["eta-seconds"] = overall_cycle_time * (
+            trainer.max_epochs - trainer.current_epoch  # type: ignore
+        )
+        # the "step" is the tensorboard jargon for "epoch" or "batch",
+        # depending on how we are logging - in a more general way, it simply
+        # means the relative time step.
+        self._to_log["step"] = float(trainer.current_epoch)
+
+        # Do not log during sanity check as results are not relevant
+        if not trainer.sanity_checking:
+            pl_module.log_dict(self._to_log)
+            self._to_log = {}
 
     def on_train_batch_end(
         self,
@@ -198,8 +199,14 @@ class LoggingCallback(lightning.pytorch.Callback):
         batch_idx
             The relative number of the batch
         """
-        self._training_epoch_loss[0].append(batch[0].shape[0])
-        self._training_epoch_loss[1].append(outputs["loss"].item())
+        pl_module.log(
+            "loss/train",
+            outputs["loss"].item(),
+            prog_bar=True,
+            on_step=False,
+            on_epoch=True,
+            batch_size=batch[0].shape[0],
+        )
 
     def on_validation_epoch_start(
         self,
@@ -229,7 +236,6 @@ class LoggingCallback(lightning.pytorch.Callback):
             The lightning module that is being trained
         """
         self._start_validation_epoch_time = time.time()
-        self._validation_epoch_loss = {}
 
     def on_validation_epoch_end(
         self,
@@ -271,20 +277,12 @@ class LoggingCallback(lightning.pytorch.Callback):
                 "missing."
             )
 
-        # Compute overall validation losses considering batches and sizes
-        # We disconsider accumulate_grad_batches and assume they were all
-        # of the same size.  This way, the average of averages is the
-        # overall average.
-        for key in sorted(self._validation_epoch_loss.keys()):
-            if key == 0:
-                name = "loss/validation"
-            else:
-                name = f"loss/validation-{key}"
-
-            self._to_log[name] = torch.mean(
-                torch.tensor(self._validation_epoch_loss[key][0])
-                * torch.tensor(self._validation_epoch_loss[key][1])
-            ).item()
+        self._to_log["step"] = float(trainer.current_epoch)
+
+        # Do not log during sanity check as results are not relevant
+        if not trainer.sanity_checking:
+            pl_module.log_dict(self._to_log)
+            self._to_log = {}
 
     def on_validation_batch_end(
         self,
@@ -330,50 +328,17 @@ class LoggingCallback(lightning.pytorch.Callback):
             Index of the dataloader used during validation.  Use this to figure
             out which dataset was used for this validation epoch.
         """
-        size, value = self._validation_epoch_loss.setdefault(
-            dataloader_idx, ([], [])
-        )
-        size.append(batch[0].shape[0])
-        value.append(outputs.item())
-
-    def on_cycle_end(
-        self,
-        trainer: lightning.pytorch.Trainer,
-        pl_module: lightning.pytorch.LightningModule,
-    ) -> None:
-        """Called when the training/validation cycle has ended.
-
-        This function will log all relevant values to the various loggers.  It
-        is supposed to be called by the end of the training cycle (consisting
-        of a training and validation step).
-
-
-        Parameters
-        ----------
-
-        trainer
-            The Lightning trainer object
-
-        pl_module
-            The lightning module that is being trained
-        """
 
-        # collect some final time for the whole training cycle
-        # Note: logging should happen at on_validation_end(), but
-        # apparently you can't log from there
-        overall_cycle_time = time.time() - self._start_training_epoch_time
-        self._to_log["cycle-time-seconds/train"] = overall_cycle_time
-        self._to_log["total-execution-time-seconds"] = (
-            time.time() - self._start_training_time
-        )
-        self._to_log["eta-seconds"] = overall_cycle_time * (
-            trainer.max_epochs - trainer.current_epoch  # type: ignore
+        if dataloader_idx == 0:
+            key = "loss/validation"
+        else:
+            key = f"loss/validation-{dataloader_idx}"
+
+        pl_module.log(
+            key,
+            outputs.item(),
+            prog_bar=False,
+            on_step=False,
+            on_epoch=True,
+            batch_size=batch[0].shape[0],
         )
-
-        # Do not log during sanity check as results are not relevant
-        if not trainer.sanity_checking:
-            for k in sorted(self._to_log.keys()):
-                pl_module.log_dict(
-                    {k: self._to_log[k], "step": float(trainer.current_epoch)}
-                )
-            self._to_log = {}
diff --git a/src/ptbench/engine/trainer.py b/src/ptbench/engine/trainer.py
index b4f878fe..dee1ad98 100644
--- a/src/ptbench/engine/trainer.py
+++ b/src/ptbench/engine/trainer.py
@@ -193,7 +193,7 @@ def run(
         save_last=True,  # will (re)create the last trained model, at every iteration
         monitor="loss/validation",
         mode="min",
-        save_on_train_epoch_end=True,  # run checks at the end of validation
+        save_on_train_epoch_end=True,
         every_n_epochs=validation_period,  # frequency at which it would check the "monitor"
         enable_version_counter=False,  # no versioning of aliased checkpoints
     )
-- 
GitLab
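
For reference, the logging pattern this patch adopts can be summarised as a
minimal, self-contained sketch outside the diff context.  The class name
EpochLossLogger is hypothetical; as in the patch, it assumes the training step
returns a dict with a "loss" entry and that batch[0] holds the input tensor.

    import lightning.pytorch


    class EpochLossLogger(lightning.pytorch.Callback):
        """Logs a batch-size-weighted epoch average of the training loss."""

        def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
            # With on_step=False and on_epoch=True, Lightning accumulates the
            # per-batch values and logs their mean at the end of the epoch,
            # weighting each contribution by the batch_size given here.
            pl_module.log(
                "loss/train",
                outputs["loss"].item(),
                prog_bar=True,
                on_step=False,
                on_epoch=True,
                batch_size=batch[0].shape[0],
            )

Passing batch_size lets Lightning compute a sample-weighted epoch mean, which
replaces the manual size-weighted averaging the callback performed before this
change.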