From 73f436fd33efddc257ea83b007e5f495108aeae2 Mon Sep 17 00:00:00 2001 From: Andre Anjos <andre.dos.anjos@gmail.com> Date: Fri, 15 Dec 2023 13:35:07 +0100 Subject: [PATCH] [engine.callbacks] Refactor callbacks to delegate most work to lightning --- src/ptbench/engine/callbacks.py | 135 ++++++++++++-------------------- src/ptbench/engine/trainer.py | 2 +- 2 files changed, 51 insertions(+), 86 deletions(-) diff --git a/src/ptbench/engine/callbacks.py b/src/ptbench/engine/callbacks.py index 8669718e..6966f6fe 100644 --- a/src/ptbench/engine/callbacks.py +++ b/src/ptbench/engine/callbacks.py @@ -18,8 +18,12 @@ logger = logging.getLogger(__name__) class LoggingCallback(lightning.pytorch.Callback): """Callback to log various training metrics and device information. - It ensures CSVLogger logs training and evaluation metrics on the same line - Note that a CSVLogger only accepts numerical values, and not strings. + Rationale: + + 1. Losses are logged at the end of every batch, accumulated and handled by + the lightning framework + 2. Everything else is done at the end of a training or validation epoch and + mostly concerns runtime metrics such as memory and cpu/gpu utilisation. Parameters @@ -33,13 +37,6 @@ class LoggingCallback(lightning.pytorch.Callback): def __init__(self, resource_monitor: ResourceMonitor): super().__init__() - # lists of number of samples/batch and average losses - # - we use this later to compute overall epoch losses - self._training_epoch_loss: tuple[list[int], list[float]] = ([], []) - self._validation_epoch_loss: dict[ - int, tuple[list[int], list[float]] - ] = {} - # timers self._start_training_time = 0.0 self._start_training_epoch_time = 0.0 @@ -101,7 +98,6 @@ class LoggingCallback(lightning.pytorch.Callback): The lightning module that is being trained """ self._start_training_epoch_time = time.time() - self._training_epoch_loss = ([], []) def on_train_epoch_end( self, @@ -132,17 +128,8 @@ class LoggingCallback(lightning.pytorch.Callback): # evaluates this training epoch total time, and log it epoch_time = time.time() - self._start_training_epoch_time - # Compute overall training loss considering batches and sizes - # We disconsider accumulate_grad_batches and assume they were all of - # the same size. This way, the average of averages is the overall - # average. - self._to_log["loss/train"] = torch.mean( - torch.tensor(self._training_epoch_loss[0]) - * torch.tensor(self._training_epoch_loss[1]) - ).item() - self._to_log["epoch-duration-seconds/train"] = epoch_time - self._to_log["learning-rate"] = pl_module.optimizers().defaults["lr"] + self._to_log["learning-rate"] = pl_module.optimizers().defaults["lr"] # type: ignore metrics = self._resource_monitor.data if metrics is not None: @@ -155,9 +142,23 @@ class LoggingCallback(lightning.pytorch.Callback): "missing." ) - # if no validation dataloaders, complete cycle by the end of the - # training epoch, by logging all values to the logger - self.on_cycle_end(trainer, pl_module) + overall_cycle_time = time.time() - self._start_training_epoch_time + self._to_log["cycle-time-seconds/train"] = overall_cycle_time + self._to_log["total-execution-time-seconds"] = ( + time.time() - self._start_training_time + ) + self._to_log["eta-seconds"] = overall_cycle_time * ( + trainer.max_epochs - trainer.current_epoch # type: ignore + ) + # the "step" is the tensorboard jargon for "epoch" or "batch", + # depending on how we are logging - in a more general way, it simply + # means the relative time step. 
+ self._to_log["step"] = float(trainer.current_epoch) + + # Do not log during sanity check as results are not relevant + if not trainer.sanity_checking: + pl_module.log_dict(self._to_log) + self._to_log = {} def on_train_batch_end( self, @@ -198,8 +199,14 @@ class LoggingCallback(lightning.pytorch.Callback): batch_idx The relative number of the batch """ - self._training_epoch_loss[0].append(batch[0].shape[0]) - self._training_epoch_loss[1].append(outputs["loss"].item()) + pl_module.log( + "loss/train", + outputs["loss"].item(), + prog_bar=True, + on_step=False, + on_epoch=True, + batch_size=batch[0].shape[0], + ) def on_validation_epoch_start( self, @@ -229,7 +236,6 @@ class LoggingCallback(lightning.pytorch.Callback): The lightning module that is being trained """ self._start_validation_epoch_time = time.time() - self._validation_epoch_loss = {} def on_validation_epoch_end( self, @@ -271,20 +277,12 @@ class LoggingCallback(lightning.pytorch.Callback): "missing." ) - # Compute overall validation losses considering batches and sizes - # We disconsider accumulate_grad_batches and assume they were all - # of the same size. This way, the average of averages is the - # overall average. - for key in sorted(self._validation_epoch_loss.keys()): - if key == 0: - name = "loss/validation" - else: - name = f"loss/validation-{key}" - - self._to_log[name] = torch.mean( - torch.tensor(self._validation_epoch_loss[key][0]) - * torch.tensor(self._validation_epoch_loss[key][1]) - ).item() + self._to_log["step"] = float(trainer.current_epoch) + + # Do not log during sanity check as results are not relevant + if not trainer.sanity_checking: + pl_module.log_dict(self._to_log) + self._to_log = {} def on_validation_batch_end( self, @@ -330,50 +328,17 @@ class LoggingCallback(lightning.pytorch.Callback): Index of the dataloader used during validation. Use this to figure out which dataset was used for this validation epoch. """ - size, value = self._validation_epoch_loss.setdefault( - dataloader_idx, ([], []) - ) - size.append(batch[0].shape[0]) - value.append(outputs.item()) - - def on_cycle_end( - self, - trainer: lightning.pytorch.Trainer, - pl_module: lightning.pytorch.LightningModule, - ) -> None: - """Called when the training/validation cycle has ended. - - This function will log all relevant values to the various loggers. It - is supposed to be called by the end of the training cycle (consisting - of a training and validation step). 
- - - Parameters - ---------- - - trainer - The Lightning trainer object - - pl_module - The lightning module that is being trained - """ - # collect some final time for the whole training cycle - # Note: logging should happen at on_validation_end(), but - # apparently you can't log from there - overall_cycle_time = time.time() - self._start_training_epoch_time - self._to_log["cycle-time-seconds/train"] = overall_cycle_time - self._to_log["total-execution-time-seconds"] = ( - time.time() - self._start_training_time - ) - self._to_log["eta-seconds"] = overall_cycle_time * ( - trainer.max_epochs - trainer.current_epoch # type: ignore + if dataloader_idx == 0: + key = "loss/validation" + else: + key = f"loss/validation-{dataloader_idx}" + + pl_module.log( + key, + outputs.item(), + prog_bar=False, + on_step=False, + on_epoch=True, + batch_size=batch[0].shape[0], ) - - # Do not log during sanity check as results are not relevant - if not trainer.sanity_checking: - for k in sorted(self._to_log.keys()): - pl_module.log_dict( - {k: self._to_log[k], "step": float(trainer.current_epoch)} - ) - self._to_log = {} diff --git a/src/ptbench/engine/trainer.py b/src/ptbench/engine/trainer.py index b4f878fe..dee1ad98 100644 --- a/src/ptbench/engine/trainer.py +++ b/src/ptbench/engine/trainer.py @@ -193,7 +193,7 @@ def run( save_last=True, # will (re)create the last trained model, at every iteration monitor="loss/validation", mode="min", - save_on_train_epoch_end=True, # run checks at the end of validation + save_on_train_epoch_end=True, every_n_epochs=validation_period, # frequency at which it would check the "monitor" enable_version_counter=False, # no versioning of aliased checkpoints ) -- GitLab
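
For reference, below is a minimal standalone sketch of the logging pattern this patch adopts: per-batch losses are handed to Lightning with on_step=False / on_epoch=True so the framework computes the epoch average itself, and epoch-level runtime values are flushed through log_dict() outside the sanity-check phase. Only the metric names ("loss/train", "loss/validation", "epoch-duration-seconds/train", "step") and the dataloader_idx naming rule come from the patch; the class name and the rest are illustrative, not the ptbench implementation. It assumes training_step returns a dict containing "loss" and validation_step returns the loss tensor, as the patch implies.

    # Sketch only, assuming lightning >= 2.0 is installed.
    import time

    import lightning.pytorch


    class EpochLoggingSketch(lightning.pytorch.Callback):
        """Delegates loss averaging to Lightning and logs epoch runtime."""

        def __init__(self) -> None:
            super().__init__()
            self._epoch_start = 0.0

        def on_train_epoch_start(self, trainer, pl_module) -> None:
            self._epoch_start = time.time()

        def on_train_batch_end(
            self, trainer, pl_module, outputs, batch, batch_idx
        ) -> None:
            # on_step=False / on_epoch=True: Lightning accumulates a
            # batch-size-weighted mean and emits one "loss/train" value
            # per epoch.
            pl_module.log(
                "loss/train",
                outputs["loss"].item(),
                prog_bar=True,
                on_step=False,
                on_epoch=True,
                batch_size=batch[0].shape[0],
            )

        def on_validation_batch_end(
            self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0
        ) -> None:
            # Extra validation dataloaders get their own metric name,
            # following the patch's naming rule.
            if dataloader_idx == 0:
                key = "loss/validation"
            else:
                key = f"loss/validation-{dataloader_idx}"
            pl_module.log(
                key,
                outputs.item(),
                on_step=False,
                on_epoch=True,
                batch_size=batch[0].shape[0],
            )

        def on_train_epoch_end(self, trainer, pl_module) -> None:
            # Epoch-level values are flushed in one go; skipped during the
            # sanity check since those results are not meaningful.
            if not trainer.sanity_checking:
                pl_module.log_dict(
                    {
                        "epoch-duration-seconds/train": time.time()
                        - self._epoch_start,
                        "step": float(trainer.current_epoch),
                    }
                )

The design point of the refactor is visible here: the manual "average of averages" (which, as the removed comment notes, assumed all batches had the same size) is replaced by Lightning's own weighted accumulation, so the callback no longer needs to keep per-batch sample counts and loss lists.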