Commit 73f436fd authored by André Anjos, committed by Daniel CARRON

[engine.callbacks] Refactor callbacks to delegate most work to lightning

parent 50bc1099
1 merge request: !12 Adds grad-cam support on classifiers
@@ -18,8 +18,12 @@ logger = logging.getLogger(__name__)
class LoggingCallback(lightning.pytorch.Callback):
"""Callback to log various training metrics and device information.
It ensures that the CSVLogger logs training and evaluation metrics on the same line.
Note that a CSVLogger only accepts numerical values, and not strings.
Rationale:
1. Losses are logged at the end of every batch, accumulated and handled by
the lightning framework.
2. Everything else is done at the end of a training or validation epoch and
mostly concerns runtime metrics such as memory and cpu/gpu utilisation.
Parameters
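
For orientation, a minimal usage sketch of the callback documented above. It is not part of this commit; the import paths and the ResourceMonitor arguments are assumptions, not this project's actual API.

# Illustrative only: wiring the callback into a Trainer.
import lightning.pytorch

from mednet.engine.callbacks import LoggingCallback  # assumed module path
from mednet.utils.resources import ResourceMonitor  # assumed module path

monitor = ResourceMonitor(interval=5.0)  # hypothetical constructor arguments
trainer = lightning.pytorch.Trainer(
    max_epochs=10,
    callbacks=[LoggingCallback(monitor)],
)
# trainer.fit(model, datamodule=datamodule)  # model and datamodule not shown here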
@@ -33,13 +37,6 @@ class LoggingCallback(lightning.pytorch.Callback):
def __init__(self, resource_monitor: ResourceMonitor):
super().__init__()
# lists of number of samples/batch and average losses
# - we use this later to compute overall epoch losses
self._training_epoch_loss: tuple[list[int], list[float]] = ([], [])
self._validation_epoch_loss: dict[
int, tuple[list[int], list[float]]
] = {}
# timers
self._start_training_time = 0.0
self._start_training_epoch_time = 0.0
@@ -101,7 +98,6 @@ class LoggingCallback(lightning.pytorch.Callback):
The lightning module that is being trained
"""
self._start_training_epoch_time = time.time()
self._training_epoch_loss = ([], [])
def on_train_epoch_end(
self,
@@ -132,17 +128,8 @@ class LoggingCallback(lightning.pytorch.Callback):
# evaluates this training epoch's total time, and logs it
epoch_time = time.time() - self._start_training_epoch_time
# Compute overall training loss considering batches and sizes
# We disregard accumulate_grad_batches and assume they were all of
# the same size. This way, the average of averages is the overall
# average.
self._to_log["loss/train"] = torch.mean(
torch.tensor(self._training_epoch_loss[0])
* torch.tensor(self._training_epoch_loss[1])
).item()
self._to_log["epoch-duration-seconds/train"] = epoch_time
self._to_log["learning-rate"] = pl_module.optimizers().defaults["lr"]
self._to_log["learning-rate"] = pl_module.optimizers().defaults["lr"] # type: ignore
metrics = self._resource_monitor.data
if metrics is not None:
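
The "average of averages" reasoning in the removed comment above only holds when all batches have the same size. The snippet below is an illustrative check (not part of the commit) and also shows the batch-size-weighted mean that lightning computes when batch_size= is passed to log().

import torch

per_sample = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])

# equal-sized batches: the mean of per-batch means equals the overall mean
equal = per_sample.split(2)  # three batches of two samples each
assert torch.isclose(torch.stack([b.mean() for b in equal]).mean(), per_sample.mean())

# unequal batches: a plain mean of means drifts, a size-weighted mean does not
unequal = per_sample.split([1, 2, 3])
means = torch.stack([b.mean() for b in unequal])
sizes = torch.tensor([1.0, 2.0, 3.0])
assert not torch.isclose(means.mean(), per_sample.mean())
assert torch.isclose((means * sizes).sum() / sizes.sum(), per_sample.mean())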
@@ -155,9 +142,23 @@ class LoggingCallback(lightning.pytorch.Callback):
"missing."
)
# if no validation dataloaders, complete cycle by the end of the
# training epoch, by logging all values to the logger
self.on_cycle_end(trainer, pl_module)
overall_cycle_time = time.time() - self._start_training_epoch_time
self._to_log["cycle-time-seconds/train"] = overall_cycle_time
self._to_log["total-execution-time-seconds"] = (
time.time() - self._start_training_time
)
self._to_log["eta-seconds"] = overall_cycle_time * (
trainer.max_epochs - trainer.current_epoch # type: ignore
)
# the "step" is the tensorboard jargon for "epoch" or "batch",
# depending on how we are logging - in a more general way, it simply
# means the relative time step.
self._to_log["step"] = float(trainer.current_epoch)
# Do not log during sanity check as results are not relevant
if not trainer.sanity_checking:
pl_module.log_dict(self._to_log)
self._to_log = {}
def on_train_batch_end(
self,
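
A worked example, with illustrative numbers only, of the "eta-seconds" estimate computed above:

overall_cycle_time = 42.0  # seconds measured for the last train/validation cycle
max_epochs, current_epoch = 10, 3  # hypothetical trainer state
eta_seconds = overall_cycle_time * (max_epochs - current_epoch)
assert eta_seconds == 294.0  # 42 s/cycle x 7 cycles, following the formula above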
@@ -198,8 +199,14 @@ class LoggingCallback(lightning.pytorch.Callback):
batch_idx
The relative number of the batch
"""
self._training_epoch_loss[0].append(batch[0].shape[0])
self._training_epoch_loss[1].append(outputs["loss"].item())
pl_module.log(
"loss/train",
outputs["loss"].item(),
prog_bar=True,
on_step=False,
on_epoch=True,
batch_size=batch[0].shape[0],
)
def on_validation_epoch_start(
self,
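
The core of the refactor, per the docstring rationale and the on_train_batch_end() change above, is to let lightning perform the epoch-level loss reduction. Below is a minimal, self-contained sketch of that idiom using a toy module that is not part of this commit:

import lightning.pytorch
import torch


class _ToyModule(lightning.pytorch.LightningModule):
    """Toy module, only to illustrate epoch-level aggregation via self.log()."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        # same pattern as LoggingCallback.on_train_batch_end(): lightning keeps a
        # batch-size-weighted running mean and emits a single value per epoch
        self.log("loss/train", loss.item(), prog_bar=True, on_step=False,
                 on_epoch=True, batch_size=x.shape[0])
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)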
@@ -229,7 +236,6 @@ class LoggingCallback(lightning.pytorch.Callback):
The lightning module that is being trained
"""
self._start_validation_epoch_time = time.time()
self._validation_epoch_loss = {}
def on_validation_epoch_end(
self,
@@ -271,20 +277,12 @@ class LoggingCallback(lightning.pytorch.Callback):
"missing."
)
# Compute overall validation losses considering batches and sizes
# We disregard accumulate_grad_batches and assume they were all
# of the same size. This way, the average of averages is the
# overall average.
for key in sorted(self._validation_epoch_loss.keys()):
if key == 0:
name = "loss/validation"
else:
name = f"loss/validation-{key}"
self._to_log[name] = torch.mean(
torch.tensor(self._validation_epoch_loss[key][0])
* torch.tensor(self._validation_epoch_loss[key][1])
).item()
self._to_log["step"] = float(trainer.current_epoch)
# Do not log during sanity check as results are not relevant
if not trainer.sanity_checking:
pl_module.log_dict(self._to_log)
self._to_log = {}
def on_validation_batch_end(
self,
@@ -330,50 +328,17 @@ class LoggingCallback(lightning.pytorch.Callback):
Index of the dataloader used during validation. Use this to figure
out which dataset was used for this validation epoch.
"""
size, value = self._validation_epoch_loss.setdefault(
dataloader_idx, ([], [])
)
size.append(batch[0].shape[0])
value.append(outputs.item())
def on_cycle_end(
self,
trainer: lightning.pytorch.Trainer,
pl_module: lightning.pytorch.LightningModule,
) -> None:
"""Called when the training/validation cycle has ended.
This function will log all relevant values to the various loggers. It
is supposed to be called at the end of the training cycle (consisting
of a training and validation step).
Parameters
----------
trainer
The Lightning trainer object
pl_module
The lightning module that is being trained
"""
# collect some final time for the whole training cycle
# Note: logging should happen at on_validation_end(), but
# apparently you can't log from there
overall_cycle_time = time.time() - self._start_training_epoch_time
self._to_log["cycle-time-seconds/train"] = overall_cycle_time
self._to_log["total-execution-time-seconds"] = (
time.time() - self._start_training_time
)
self._to_log["eta-seconds"] = overall_cycle_time * (
trainer.max_epochs - trainer.current_epoch # type: ignore
if dataloader_idx == 0:
key = "loss/validation"
else:
key = f"loss/validation-{dataloader_idx}"
pl_module.log(
key,
outputs.item(),
prog_bar=False,
on_step=False,
on_epoch=True,
batch_size=batch[0].shape[0],
)
# Do not log during sanity check as results are not relevant
if not trainer.sanity_checking:
for k in sorted(self._to_log.keys()):
pl_module.log_dict(
{k: self._to_log[k], "step": float(trainer.current_epoch)}
)
self._to_log = {}
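
For clarity, the metric-name rule introduced in on_validation_batch_end() above, written out as a stand-alone (illustrative) helper:

def validation_loss_key(dataloader_idx: int) -> str:
    """Mirrors the validation metric-name convention used above."""
    if dataloader_idx == 0:
        return "loss/validation"
    return f"loss/validation-{dataloader_idx}"


assert validation_loss_key(0) == "loss/validation"
assert validation_loss_key(2) == "loss/validation-2"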
@@ -193,7 +193,7 @@ def run(
save_last=True, # will (re)create the last trained model, at every iteration
monitor="loss/validation",
mode="min",
save_on_train_epoch_end=True, # run checks at the end of validation
save_on_train_epoch_end=True,
every_n_epochs=validation_period, # frequency at which it would check the "monitor"
enable_version_counter=False, # no versioning of aliased checkpoints
)
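
A self-contained sketch of the checkpoint configuration touched by this hunk; dirpath, filename and every_n_epochs carry illustrative values, while the remaining flags mirror the diff. Since save_on_train_epoch_end=True makes the check run at the end of each training epoch rather than after validation, the stale inline comment above was presumably dropped for accuracy.

from lightning.pytorch.callbacks import ModelCheckpoint

checkpoint = ModelCheckpoint(
    dirpath="results/",  # assumed output directory
    filename="model-lowest-valid-loss",  # assumed filename pattern
    save_last=True,  # will (re)create the last trained model at every iteration
    monitor="loss/validation",
    mode="min",
    save_on_train_epoch_end=True,  # check runs at train-epoch end
    every_n_epochs=1,  # stands in for validation_period
    enable_version_counter=False,  # no versioning of aliased checkpoints
)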