Commit 73f436fd authored by André Anjos, committed by Daniel CARRON

[engine.callbacks] Refactor callbacks to delegate most work to lightning

parent 50bc1099
1 merge request: !12 Adds grad-cam support on classifiers
@@ -18,8 +18,12 @@ logger = logging.getLogger(__name__)
 class LoggingCallback(lightning.pytorch.Callback):
     """Callback to log various training metrics and device information.
 
-    It ensures CSVLogger logs training and evaluation metrics on the same line
-    Note that a CSVLogger only accepts numerical values, and not strings.
+    Rationale:
+
+    1. Losses are logged at the end of every batch, accumulated and handled by
+       the lightning framework
+    2. Everything else is done at the end of a training or validation epoch and
+       mostly concerns runtime metrics such as memory and cpu/gpu utilisation.
 
     Parameters
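
Note on the rationale above: values logged from a batch-level hook with on_epoch=True are accumulated by Lightning and reduced to a single epoch-level number, weighted by the batch_size passed alongside each value. A minimal sketch of what that reduction amounts to, using made-up numbers (this illustrates Lightning's documented default "mean" reduction; it is not code from this repository):

# Illustrative only: per-batch values handed to log(..., on_epoch=True,
# batch_size=...) and the batch-size-weighted mean Lightning reports at the
# end of the epoch.
batch_losses = [0.9, 0.7, 0.6]   # value passed to log() for each batch
batch_sizes = [32, 32, 16]       # batch_size passed alongside each value

epoch_loss = sum(v * n for v, n in zip(batch_losses, batch_sizes)) / sum(batch_sizes)
print(f"loss/train at epoch end: {epoch_loss:.2f}")  # 0.76
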
@@ -33,13 +37,6 @@ class LoggingCallback(lightning.pytorch.Callback):
     def __init__(self, resource_monitor: ResourceMonitor):
         super().__init__()
 
-        # lists of number of samples/batch and average losses
-        # - we use this later to compute overall epoch losses
-        self._training_epoch_loss: tuple[list[int], list[float]] = ([], [])
-        self._validation_epoch_loss: dict[
-            int, tuple[list[int], list[float]]
-        ] = {}
-
         # timers
         self._start_training_time = 0.0
         self._start_training_epoch_time = 0.0
@@ -101,7 +98,6 @@ class LoggingCallback(lightning.pytorch.Callback):
             The lightning module that is being trained
         """
         self._start_training_epoch_time = time.time()
-        self._training_epoch_loss = ([], [])
 
     def on_train_epoch_end(
         self,
@@ -132,17 +128,8 @@ class LoggingCallback(lightning.pytorch.Callback):
         # evaluates this training epoch total time, and log it
         epoch_time = time.time() - self._start_training_epoch_time
 
-        # Compute overall training loss considering batches and sizes
-        # We disconsider accumulate_grad_batches and assume they were all of
-        # the same size. This way, the average of averages is the overall
-        # average.
-        self._to_log["loss/train"] = torch.mean(
-            torch.tensor(self._training_epoch_loss[0])
-            * torch.tensor(self._training_epoch_loss[1])
-        ).item()
-
         self._to_log["epoch-duration-seconds/train"] = epoch_time
-        self._to_log["learning-rate"] = pl_module.optimizers().defaults["lr"]
+        self._to_log["learning-rate"] = pl_module.optimizers().defaults["lr"]  # type: ignore
 
         metrics = self._resource_monitor.data
         if metrics is not None:
@@ -155,9 +142,23 @@ class LoggingCallback(lightning.pytorch.Callback):
                 "missing."
             )
 
-        # if no validation dataloaders, complete cycle by the end of the
-        # training epoch, by logging all values to the logger
-        self.on_cycle_end(trainer, pl_module)
+        overall_cycle_time = time.time() - self._start_training_epoch_time
+        self._to_log["cycle-time-seconds/train"] = overall_cycle_time
+        self._to_log["total-execution-time-seconds"] = (
+            time.time() - self._start_training_time
+        )
+        self._to_log["eta-seconds"] = overall_cycle_time * (
+            trainer.max_epochs - trainer.current_epoch  # type: ignore
+        )
+        # the "step" is the tensorboard jargon for "epoch" or "batch",
+        # depending on how we are logging - in a more general way, it simply
+        # means the relative time step.
+        self._to_log["step"] = float(trainer.current_epoch)
+
+        # Do not log during sanity check as results are not relevant
+        if not trainer.sanity_checking:
+            pl_module.log_dict(self._to_log)
+            self._to_log = {}
 
     def on_train_batch_end(
         self,
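
The eta-seconds entry added above is a simple linear extrapolation: the duration of the cycle that just finished multiplied by the number of epochs still to run. A worked example with hypothetical numbers (120 s per cycle, epoch 3 of 10; none of these values come from the repository):

# Hypothetical values illustrating the ETA arithmetic used above.
overall_cycle_time = 120.0          # seconds taken by the last train+validation cycle
max_epochs, current_epoch = 10, 3   # stand-ins for trainer.max_epochs, trainer.current_epoch

eta_seconds = overall_cycle_time * (max_epochs - current_epoch)
print(eta_seconds / 60)  # 14.0 -> roughly 14 minutes remaining
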
@@ -198,8 +199,14 @@ class LoggingCallback(lightning.pytorch.Callback):
         batch_idx
             The relative number of the batch
         """
-        self._training_epoch_loss[0].append(batch[0].shape[0])
-        self._training_epoch_loss[1].append(outputs["loss"].item())
+        pl_module.log(
+            "loss/train",
+            outputs["loss"].item(),
+            prog_bar=True,
+            on_step=False,
+            on_epoch=True,
+            batch_size=batch[0].shape[0],
+        )
 
     def on_validation_epoch_start(
         self,
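
With per-batch accumulation now delegated to Lightning, the epoch-level training loss no longer lives in a private attribute of the callback; it is exposed by the trainer once the epoch closes. A sketch of how to read it back (assuming an already-fitted lightning.pytorch.Trainer named trainer; nothing here is defined by this commit):

# Sketch: after the training epoch ends, Lightning exposes the reduced value
# under the key passed to log(); the CSVLogger row for that epoch carries the
# same number.
epoch_loss = trainer.callback_metrics["loss/train"]  # a 0-d torch.Tensor
print(float(epoch_loss))
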
@@ -229,7 +236,6 @@ class LoggingCallback(lightning.pytorch.Callback):
             The lightning module that is being trained
         """
         self._start_validation_epoch_time = time.time()
-        self._validation_epoch_loss = {}
 
     def on_validation_epoch_end(
         self,
@@ -271,20 +277,12 @@ class LoggingCallback(lightning.pytorch.Callback):
                 "missing."
            )
 
-        # Compute overall validation losses considering batches and sizes
-        # We disconsider accumulate_grad_batches and assume they were all
-        # of the same size. This way, the average of averages is the
-        # overall average.
-        for key in sorted(self._validation_epoch_loss.keys()):
-            if key == 0:
-                name = "loss/validation"
-            else:
-                name = f"loss/validation-{key}"
-            self._to_log[name] = torch.mean(
-                torch.tensor(self._validation_epoch_loss[key][0])
-                * torch.tensor(self._validation_epoch_loss[key][1])
-            ).item()
+        self._to_log["step"] = float(trainer.current_epoch)
+
+        # Do not log during sanity check as results are not relevant
+        if not trainer.sanity_checking:
+            pl_module.log_dict(self._to_log)
+            self._to_log = {}
 
     def on_validation_batch_end(
         self,
@@ -330,50 +328,17 @@ class LoggingCallback(lightning.pytorch.Callback):
             Index of the dataloader used during validation. Use this to figure
             out which dataset was used for this validation epoch.
         """
-        size, value = self._validation_epoch_loss.setdefault(
-            dataloader_idx, ([], [])
-        )
-        size.append(batch[0].shape[0])
-        value.append(outputs.item())
-
-    def on_cycle_end(
-        self,
-        trainer: lightning.pytorch.Trainer,
-        pl_module: lightning.pytorch.LightningModule,
-    ) -> None:
-        """Called when the training/validation cycle has ended.
-
-        This function will log all relevant values to the various loggers. It
-        is supposed to be called by the end of the training cycle (consisting
-        of a training and validation step).
-
-        Parameters
-        ----------
-        trainer
-            The Lightning trainer object
-        pl_module
-            The lightning module that is being trained
-        """
-        # collect some final time for the whole training cycle
-        # Note: logging should happen at on_validation_end(), but
-        # apparently you can't log from there
-        overall_cycle_time = time.time() - self._start_training_epoch_time
-        self._to_log["cycle-time-seconds/train"] = overall_cycle_time
-        self._to_log["total-execution-time-seconds"] = (
-            time.time() - self._start_training_time
-        )
-        self._to_log["eta-seconds"] = overall_cycle_time * (
-            trainer.max_epochs - trainer.current_epoch  # type: ignore
-        )
-
-        # Do not log during sanity check as results are not relevant
-        if not trainer.sanity_checking:
-            for k in sorted(self._to_log.keys()):
-                pl_module.log_dict(
-                    {k: self._to_log[k], "step": float(trainer.current_epoch)}
-                )
-            self._to_log = {}
+        if dataloader_idx == 0:
+            key = "loss/validation"
+        else:
+            key = f"loss/validation-{dataloader_idx}"
+
+        pl_module.log(
+            key,
+            outputs.item(),
+            prog_bar=False,
+            on_step=False,
+            on_epoch=True,
+            batch_size=batch[0].shape[0],
+        )
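
The key naming above keeps the first validation dataloader under the plain "loss/validation" name, matching the metric monitored by the checkpoint callback below, while additional dataloaders get an index suffix. A quick sketch of the resulting names, assuming three validation dataloaders (the count is hypothetical):

# Hypothetical: metric names produced for validation dataloaders 0, 1 and 2.
for dataloader_idx in range(3):
    key = "loss/validation" if dataloader_idx == 0 else f"loss/validation-{dataloader_idx}"
    print(dataloader_idx, key)
# 0 loss/validation
# 1 loss/validation-1
# 2 loss/validation-2
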
@@ -193,7 +193,7 @@ def run(
         save_last=True,  # will (re)create the last trained model, at every iteration
         monitor="loss/validation",
         mode="min",
-        save_on_train_epoch_end=True,  # run checks at the end of validation
+        save_on_train_epoch_end=True,
         every_n_epochs=validation_period,  # frequency at which it would check the "monitor"
         enable_version_counter=False,  # no versioning of aliased checkpoints
     )
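
For reference, the hunk above sits inside a ModelCheckpoint construction; a sketch of the full call with the visible arguments filled in (dirpath, filename and validation_period are placeholders, not values taken from this repository):

from lightning.pytorch.callbacks import ModelCheckpoint

validation_period = 1  # placeholder: how often validation (and the monitor check) runs

checkpoint_callback = ModelCheckpoint(
    dirpath="results/checkpoints",       # placeholder output directory
    filename="model-lowest-valid-loss",  # placeholder checkpoint alias
    save_last=True,  # will (re)create the last trained model, at every iteration
    monitor="loss/validation",  # the key logged in on_validation_batch_end above
    mode="min",
    save_on_train_epoch_end=True,
    every_n_epochs=validation_period,  # frequency at which it would check the "monitor"
    enable_version_counter=False,  # no versioning of aliased checkpoints
)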