
Making use of LightningDataModule and simplification of data loading

Merged Daniel CARRON requested to merge add-datamodule into main
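The refactor assumes a LightningDataModule whose ``predict_dataloader()`` returns a mapping from split name to dataloader; the ``PredictionsWriter`` in the file below uses ``trainer.datamodule.predict_dataloader().keys()`` to recover those names. A minimal sketch of such a datamodule follows; the dataset contents, split names and batch size are made up for illustration and are not part of this MR:

import lightning.pytorch
import torch
from torch.utils.data import DataLoader, TensorDataset


class ExampleDataModule(lightning.pytorch.LightningDataModule):
    """Hypothetical datamodule exposing named prediction splits."""

    def setup(self, stage: str) -> None:
        # toy tensors standing in for the real datasets
        x = torch.randn(8, 3, 224, 224)
        y = torch.randint(0, 2, (8, 1))
        self._splits = {
            "train": TensorDataset(x, y),
            "validation": TensorDataset(x, y),
            "test": TensorDataset(x, y),
        }

    def train_dataloader(self) -> DataLoader:
        return DataLoader(self._splits["train"], batch_size=4)

    def val_dataloader(self) -> DataLoader:
        return DataLoader(self._splits["validation"], batch_size=4)

    def predict_dataloader(self) -> dict[str, DataLoader]:
        # keyed by split name, mirroring what PredictionsWriter expects
        return {
            name: DataLoader(dataset, batch_size=4)
            for name, dataset in self._splits.items()
        }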
Showing one of 3 changed files (+530 −350 overall; +356 −107 in this file).
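For reference, a minimal sketch of how the two callbacks defined in this file could be wired into a trainer. The package paths, the model class, the ResourceMonitor constructor arguments and the logfile field names are assumptions for illustration, not part of this MR:

import lightning.pytorch

from mypackage.utils.resources import ResourceMonitor            # hypothetical package path
from mypackage.engine.callbacks import LoggingCallback, PredictionsWriter  # hypothetical package path

monitor = ResourceMonitor(interval=5.0)   # assumed constructor signature
datamodule = ExampleDataModule()          # hypothetical datamodule (see sketch above)
model = MyLightningModule()               # hypothetical LightningModule

trainer = lightning.pytorch.Trainer(
    max_epochs=10,
    callbacks=[
        LoggingCallback(monitor),
        PredictionsWriter(
            output_dir="predictions",
            logfile_fields=("filename", "likelihood", "ground_truth"),  # assumed field names
            write_interval="epoch",
        ),
    ],
)
trainer.fit(model, datamodule=datamodule)
trainer.predict(model, datamodule=datamodule, return_predictions=False)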
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

import csv
import logging
import os
import pathlib
import time
import typing

import lightning.pytorch
import lightning.pytorch.callbacks
import torch

from ..utils.resources import ResourceMonitor

logger = logging.getLogger(__name__)


class LoggingCallback(lightning.pytorch.Callback):
    """Callback to log various training metrics and device information.

    It ensures the CSVLogger logs training and evaluation metrics on the same
    line.  Note that a CSVLogger only accepts numerical values, and not
    strings.

    Parameters
    ----------

    resource_monitor
        A monitor that watches resource usage (CPU/GPU) in a separate process,
        completely asynchronously from the code execution.
    """

    def __init__(self, resource_monitor: ResourceMonitor):
        super().__init__()

        # lists of number of samples/batch and average losses
        # - we use this later to compute overall epoch losses
        self._training_epoch_loss: tuple[list[int], list[float]] = ([], [])
        self._validation_epoch_loss: dict[
            int, tuple[list[int], list[float]]
        ] = {}

        # timers
        self._start_training_time = 0.0
        self._start_training_epoch_time = 0.0
        self._start_validation_epoch_time = 0.0

        # log accumulators for a single flush at each training cycle
        self._to_log: dict[str, float] = {}

        # helpers for CPU and GPU utilisation
        self._resource_monitor = resource_monitor
        self._max_queue_retries = 2

    def on_train_start(
        self,
        trainer: lightning.pytorch.Trainer,
        pl_module: lightning.pytorch.LightningModule,
    ) -> None:
        """Callback to be executed **before** the whole training starts.

        This method is executed whenever you *start* training a module.

        Parameters
        ----------

        trainer
            The Lightning trainer object.

        pl_module
            The lightning module that is being trained.
        """
        self._start_training_time = time.time()

    def on_train_epoch_start(
        self,
        trainer: lightning.pytorch.Trainer,
        pl_module: lightning.pytorch.LightningModule,
    ) -> None:
        """Callback to be executed **before** every training epoch starts.

        This method is executed whenever a training epoch starts.  Presumably,
        epochs happen as often as possible.  You want to make this code
        relatively fast to avoid significant runtime slow-downs.

        Parameters
        ----------

        trainer
            The Lightning trainer object.

        pl_module
            The lightning module that is being trained.
        """
        self._start_training_epoch_time = time.time()
        self._training_epoch_loss = ([], [])
    def on_train_epoch_end(
        self,
        trainer: lightning.pytorch.Trainer,
        pl_module: lightning.pytorch.LightningModule,
    ) -> None:
        """Callback to be executed **after** every training epoch ends.

        This method is executed whenever a training epoch ends.  Presumably,
        epochs happen as often as possible.  You want to make this code
        relatively fast to avoid significant runtime slow-downs.

        Parameters
        ----------

        trainer
            The Lightning trainer object.

        pl_module
            The lightning module that is being trained.
        """
        # summarizes resource usage since the last checkpoint
        # clears internal buffers and starts accumulating again.
        self._resource_monitor.checkpoint()

        # evaluates this training epoch's total time, and logs it
        epoch_time = time.time() - self._start_training_epoch_time

        # Compute the overall training loss, weighting each recorded batch
        # loss by the number of samples in that batch.  We disregard
        # accumulate_grad_batches here: the weighted average of the per-batch
        # averages is the overall average.
        sizes = torch.tensor(self._training_epoch_loss[0])
        losses = torch.tensor(self._training_epoch_loss[1])
        self._to_log["train_loss"] = (
            (sizes * losses).sum() / sizes.sum()
        ).item()

        self._to_log["train_epoch_time"] = epoch_time
        self._to_log["learning_rate"] = pl_module.optimizers().defaults["lr"]

        metrics = self._resource_monitor.data
        if metrics is not None:
            for metric_name, metric_value in metrics.items():
                self._to_log[f"train_{metric_name}"] = float(metric_value)
        else:
            logger.warning(
                "Unable to fetch monitoring information from "
                "resource monitor. CPU/GPU utilisation will be "
                "missing."
            )

        # if there are no validation dataloaders, complete the cycle by the
        # end of the training epoch, by logging all values to the logger
        self.on_cycle_end(trainer, pl_module)
    def on_train_batch_end(
        self,
        trainer: lightning.pytorch.Trainer,
        pl_module: lightning.pytorch.LightningModule,
        outputs: typing.Mapping[str, torch.Tensor],
        batch: tuple[torch.Tensor, typing.Mapping[str, torch.Tensor]],
        batch_idx: int,
    ) -> None:
        """Callback to be executed **after** every training batch ends.

        This method is executed whenever a training batch ends.  Presumably,
        batches happen as often as possible.  You want to make this code very
        fast.  Do not log things to the terminal or the like, or do
        complicated (lengthy) calculations.

        .. warning::

           This is executed **while** you are training.  Be very succinct or
           face the consequences of slow training!

        Parameters
        ----------

        trainer
            The Lightning trainer object.

        pl_module
            The lightning module that is being trained.

        outputs
            The outputs of the module's ``training_step``.

        batch
            The data that the training step received.

        batch_idx
            The relative number of the batch.
        """
        # record the batch size and the (mean) loss of this training batch
        self._training_epoch_loss[0].append(batch[0].shape[0])
        self._training_epoch_loss[1].append(outputs["loss"].item())
    def on_validation_epoch_start(
        self,
        trainer: lightning.pytorch.Trainer,
        pl_module: lightning.pytorch.LightningModule,
    ) -> None:
        """Callback to be executed **before** every validation epoch starts.

        This method is executed whenever a validation epoch starts.
        Presumably, epochs happen as often as possible.  You want to make this
        code relatively fast to avoid significant runtime slow-downs.

        Parameters
        ----------

        trainer
            The Lightning trainer object.

        pl_module
            The lightning module that is being trained.
        """
        self._start_validation_epoch_time = time.time()
        self._validation_epoch_loss = {}
    def on_validation_epoch_end(
        self,
        trainer: lightning.pytorch.Trainer,
        pl_module: lightning.pytorch.LightningModule,
    ) -> None:
        """Callback to be executed **after** every validation epoch ends.

        This method is executed whenever a validation epoch ends.  Presumably,
        epochs happen as often as possible.  You want to make this code
        relatively fast to avoid significant runtime slow-downs.

        Parameters
        ----------

        trainer
            The Lightning trainer object.

        pl_module
            The lightning module that is being trained.
        """
        # summarizes resource usage since the last checkpoint
        # clears internal buffers and starts accumulating again.
        self._resource_monitor.checkpoint()

        epoch_time = time.time() - self._start_validation_epoch_time
        self._to_log["validation_epoch_time"] = epoch_time

        metrics = self._resource_monitor.data
        if metrics is not None:
            for metric_name, metric_value in metrics.items():
                self._to_log[f"validation_{metric_name}"] = float(metric_value)
        else:
            logger.warning(
                "Unable to fetch monitoring information from "
                "resource monitor. CPU/GPU utilisation will be "
                "missing."
            )

        # Compute the overall validation losses, weighting each recorded
        # batch loss by the number of samples in that batch.  We disregard
        # accumulate_grad_batches here: the weighted average of the per-batch
        # averages is the overall average.
        for key in sorted(self._validation_epoch_loss.keys()):
            if key == 0:
                name = "validation_loss"
            else:
                name = f"validation_loss_{key}"

            sizes = torch.tensor(self._validation_epoch_loss[key][0])
            losses = torch.tensor(self._validation_epoch_loss[key][1])
            self._to_log[name] = ((sizes * losses).sum() / sizes.sum()).item()
    def on_validation_batch_end(
        self,
        trainer: lightning.pytorch.Trainer,
        pl_module: lightning.pytorch.LightningModule,
        outputs: torch.Tensor,
        batch: tuple[torch.Tensor, typing.Mapping[str, torch.Tensor]],
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Callback to be executed **after** every validation batch ends.

        This method is executed whenever a validation batch ends.  Presumably,
        batches happen as often as possible.  You want to make this code very
        fast.  Do not log things to the terminal or the like, or do
        complicated (lengthy) calculations.

        .. warning::

           This is executed **while** you are training.  Be very succinct or
           face the consequences of slow training!

        Parameters
        ----------

        trainer
            The Lightning trainer object.

        pl_module
            The lightning module that is being trained.

        outputs
            The outputs of the module's ``validation_step``.

        batch
            The data that the validation step received.

        batch_idx
            The relative number of the batch.

        dataloader_idx
            Index of the dataloader used during validation.  Use this to
            figure out which dataset was used for this validation epoch.
        """
        # accumulate the batch size and (mean) loss per validation dataloader
        size, value = self._validation_epoch_loss.setdefault(
            dataloader_idx, ([], [])
        )
        size.append(batch[0].shape[0])
        value.append(outputs.item())
    def on_cycle_end(
        self,
        trainer: lightning.pytorch.Trainer,
        pl_module: lightning.pytorch.LightningModule,
    ) -> None:
        """Called when the training/validation cycle has ended.

        This function will log all relevant values to the various loggers.  It
        is supposed to be called by the end of the training cycle (consisting
        of a training and a validation step).

        Parameters
        ----------

        trainer
            The Lightning trainer object.

        pl_module
            The lightning module that is being trained.
        """
        # collect some final time for the whole training cycle
        # Note: logging should happen at on_validation_end(), but
        # apparently you can't log from there
        overall_cycle_time = time.time() - self._start_training_epoch_time
        self._to_log["train_cycle_time"] = overall_cycle_time
        self._to_log["total_time"] = time.time() - self._start_training_time
        self._to_log["eta"] = overall_cycle_time * (
            trainer.max_epochs - trainer.current_epoch  # type: ignore
        )

        # Do not log during sanity check as results are not relevant
        if not trainer.sanity_checking:
            for k in sorted(self._to_log.keys()):
                pl_module.log(k, self._to_log[k])
            self._to_log = {}


class PredictionsWriter(lightning.pytorch.callbacks.BasePredictionWriter):
    """Lightning callback to write predictions to a file."""

    def __init__(
        self,
        output_dir: str | pathlib.Path,
        logfile_fields: typing.Sequence[str],
        write_interval: typing.Literal["batch", "epoch", "batch_and_epoch"],
    ):
        super().__init__(write_interval)
        self.output_dir = output_dir
        self.logfile_fields = logfile_fields

    def write_on_epoch_end(
        self,
        trainer: lightning.pytorch.Trainer,
        pl_module: lightning.pytorch.LightningModule,
        predictions: typing.Sequence[typing.Any],
        batch_indices: typing.Sequence[typing.Any] | None,
    ) -> None:
        for dataloader_idx, dataloader_results in enumerate(predictions):
            dataloader_name = list(
                trainer.datamodule.predict_dataloader().keys()