# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

import csv
import logging
import os
import pathlib
import time
import typing

import lightning.pytorch
import lightning.pytorch.callbacks
import torch

from ..utils.resources import ResourceMonitor

logger = logging.getLogger(__name__)


class LoggingCallback(lightning.pytorch.Callback):
    """Callback to log various training metrics and device information.

    It ensures the ``CSVLogger`` logs training and validation metrics on the
    same line.  Note that the ``CSVLogger`` only accepts numerical values, and
    not strings.


    Parameters
    ----------

    resource_monitor
        A monitor that watches resource usage (CPU/GPU) in a separate process,
        fully asynchronously from the code being executed.
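
    Example
    -------

    A minimal usage sketch (``model`` and ``datamodule`` are assumed to be
    defined elsewhere, and the ``ResourceMonitor`` constructor arguments are
    elided, as they depend on ``..utils.resources``):

    .. code-block:: python

       monitor = ResourceMonitor(...)
       trainer = lightning.pytorch.Trainer(
           callbacks=[LoggingCallback(monitor)]
       )
       trainer.fit(model, datamodule=datamodule)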
    """

    def __init__(self, resource_monitor: ResourceMonitor):
        super().__init__()

        # lists of number of samples/batch and average losses
        # - we use this later to compute overall epoch losses
        self._training_epoch_loss: tuple[list[int], list[float]] = ([], [])
        self._validation_epoch_loss: dict[
            int, tuple[list[int], list[float]]
        ] = {}

        # timers
        self._start_training_time = 0.0
        self._start_training_epoch_time = 0.0
        self._start_validation_epoch_time = 0.0

        # log accumulators for a single flush at each training cycle
        self._to_log: dict[str, float] = {}

        # helpers for CPU and GPU utilisation
        self._resource_monitor = resource_monitor
        self._max_queue_retries = 2

    def on_train_start(
        self,
        trainer: lightning.pytorch.Trainer,
        pl_module: lightning.pytorch.LightningModule,
    ):
        """Callback to be executed **before** the whole training starts.

        This method is executed whenever you *start* training a module.


        Parameters
        ----------

        trainer
            The Lightning trainer object

        pl_module
            The lightning module that is being trained
        """
        self._start_training_time = time.time()

    def on_train_epoch_start(
        self,
        trainer: lightning.pytorch.Trainer,
        pl_module: lightning.pytorch.LightningModule,
    ) -> None:
        """Callback to be executed **before** every training batch starts.

        This method is executed whenever a training batch starts.  Presumably,
        batches happen as often as possible.  You want to make this code very
        fast.  Do not log things to the terminal or the such, or do complicated
        (lengthy) calculations.

        .. warning::

           This is executed **while** you are training.  Be very succint or
           face the consequences of slow training!


        Parameters
        ----------

        trainer
            The Lightning trainer object

        pl_module
            The lightning module that is being trained
        """
        self._start_training_epoch_time = time.time()
        self._training_epoch_loss = ([], [])

    def on_train_epoch_end(
        self,
        trainer: lightning.pytorch.Trainer,
        pl_module: lightning.pytorch.LightningModule,
    ):
        """Callback to be executed **after** every training epoch ends.

        This method is executed whenever a training epoch ends.  Presumably,
        epochs happen as often as possible.  You want to make this code
        relatively fast to avoid significant runtime slow-downs.


        Parameters
        ----------

        trainer
            The Lightning trainer object

        pl_module
            The lightning module that is being trained
        """

        # summarizes resource usage since the last checkpoint
        # clears internal buffers and starts accumulating again.
        self._resource_monitor.checkpoint()

        # evaluates this training epoch total time, and log it
        epoch_time = time.time() - self._start_training_epoch_time

        # Compute the overall training loss as the sample-weighted average of
        # the per-batch average losses.  We disregard accumulate_grad_batches:
        # weighting each batch loss by its batch size already yields the
        # correct overall average.
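        # For example (illustrative numbers only): batch sizes [32, 32, 16]
        # with average losses [0.5, 0.4, 0.6] give
        # (32*0.5 + 32*0.4 + 16*0.6) / (32+32+16) = 38.4/80 = 0.48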
        self._to_log["loss/train"] = torch.mean(
            torch.tensor(self._training_epoch_loss[0])
            * torch.tensor(self._training_epoch_loss[1])
        ).item()

        self._to_log["epoch-duration-seconds/train"] = epoch_time
        self._to_log["learning-rate"] = pl_module.optimizers().defaults["lr"]

        metrics = self._resource_monitor.data
        if metrics is not None:
            for metric_name, metric_value in metrics.items():
                self._to_log[f"{metric_name}/train"] = float(metric_value)
        else:
            logger.warning(
                "Unable to fetch monitoring information from "
                "resource monitor. CPU/GPU utilisation will be "
                "missing."
            )

        # completes the cycle at the end of the training epoch by flushing all
        # accumulated values (training and, if available, validation metrics)
        # to the loggers
        self.on_cycle_end(trainer, pl_module)

    def on_train_batch_end(
        self,
        trainer: lightning.pytorch.Trainer,
        pl_module: lightning.pytorch.LightningModule,
        outputs: typing.Mapping[str, torch.Tensor],
        batch: tuple[torch.Tensor, typing.Mapping[str, torch.Tensor]],
        batch_idx: int,
    ) -> None:
        """Callback to be executed **after** every training batch ends.

        This method is executed whenever a training batch ends.  Presumably,
        batches happen as often as possible.  You want to make this code very
        fast.  Do not log things to the terminal or the like, or do complicated
        (lengthy) calculations.

        .. warning::

           This is executed **while** you are training.  Be very succinct or
           face the consequences of slow training!


        Parameters
        ----------

        trainer
            The Lightning trainer object

        pl_module
            The lightning module that is being trained

        outputs
            The outputs of the module's ``training_step``

        batch
            The data that the training step received

        batch_idx
            The relative number of the batch
        """
        self._training_epoch_loss[0].append(batch[0].shape[0])
        self._training_epoch_loss[1].append(outputs["loss"].item())

    def on_validation_epoch_start(
        self,
        trainer: lightning.pytorch.Trainer,
        pl_module: lightning.pytorch.LightningModule,
    ) -> None:
        """Callback to be executed **before** every validation batch starts.

        This method is executed whenever a validation batch starts.  Presumably,
        batches happen as often as possible.  You want to make this code very
        fast.  Do not log things to the terminal or the such, or do complicated
        (lengthy) calculations.

        .. warning::

           This is executed **while** you are training.  Be very succint or
           face the consequences of slow training!


        Parameters
        ----------

        trainer
            The Lightning trainer object

        pl_module
            The lightning module that is being trained
        """
        self._start_validation_epoch_time = time.time()
        self._validation_epoch_loss = {}

    def on_validation_epoch_end(
        self,
        trainer: lightning.pytorch.Trainer,
        pl_module: lightning.pytorch.LightningModule,
    ) -> None:
        """Callback to be executed **after** every validation epoch ends.

        This method is executed whenever a validation epoch ends.  Presumably,
        epochs happen as often as possible.  You want to make this code
        relatively fast to avoid significant runtime slow-downs.


        Parameters
        ----------

        trainer
            The Lightning trainer object

        pl_module
            The lightning module that is being trained
        """

        # summarizes resource usage since the last checkpoint
        # clears internal buffers and starts accumulating again.
        self._resource_monitor.checkpoint()

        epoch_time = time.time() - self._start_validation_epoch_time
        self._to_log["epoch-duration-seconds/validation"] = epoch_time

        metrics = self._resource_monitor.data
        if metrics is not None:
            for metric_name, metric_value in metrics.items():
                self._to_log[f"{metric_name}/validation"] = float(metric_value)
        else:
            logger.warning(
                "Unable to fetch monitoring information from "
                "resource monitor. CPU/GPU utilisation will be "
                "missing."
            )

        # Compute the overall validation loss (one value per validation
        # dataloader) as the sample-weighted average of the per-batch average
        # losses.
        for key in sorted(self._validation_epoch_loss.keys()):
            if key == 0:
                name = "loss/validation"
            else:
                name = f"loss/validation-{key}"

            sizes = torch.tensor(self._validation_epoch_loss[key][0])
            losses = torch.tensor(self._validation_epoch_loss[key][1])
            self._to_log[name] = (
                torch.sum(sizes * losses) / torch.sum(sizes)
            ).item()

    def on_validation_batch_end(
        self,
        trainer: lightning.pytorch.Trainer,
        pl_module: lightning.pytorch.LightningModule,
        outputs: torch.Tensor,
        batch: tuple[torch.Tensor, typing.Mapping[str, torch.Tensor]],
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Callback to be executed **after** every validation batch ends.

        This method is executed whenever a validation batch ends.  Presumably,
        batches happen as often as possible.  You want to make this code very
        fast.  Do not log things to the terminal or the like, or do complicated
        (lengthy) calculations.

        .. warning::

           This is executed **while** you are training.  Be very succinct or
           face the consequences of slow training!


        Parameters
        ----------

        trainer
            The Lightning trainer object

        pl_module
            The lightning module that is being trained

        outputs
            The outputs of the module's ``validation_step``

        batch
            The data that the validation step received

        batch_idx
            The relative number of the batch

        dataloader_idx
            Index of the dataloader used during validation.  Use this to figure
            out which dataset was used for this validation epoch.
        """
        size, value = self._validation_epoch_loss.setdefault(
            dataloader_idx, ([], [])
        )
        size.append(batch[0].shape[0])
        value.append(outputs.item())

    def on_cycle_end(
        self,
        trainer: lightning.pytorch.Trainer,
        pl_module: lightning.pytorch.LightningModule,
    ) -> None:
        """Called when the training/validation cycle has ended.

        This function will log all relevant values to the various loggers.  It
        is supposed to be called at the end of the training cycle (consisting
        of a training and a validation epoch).


        Parameters
        ----------

        trainer
            The Lightning trainer object

        pl_module
            The lightning module that is being trained
        """

        # collect some final time for the whole training cycle
        # Note: logging should happen at on_validation_end(), but
        # apparently you can't log from there
        overall_cycle_time = time.time() - self._start_training_epoch_time
        self._to_log["cycle-time-seconds/train"] = overall_cycle_time
        self._to_log["total-execution-time-seconds"] = (
            time.time() - self._start_training_time
        )
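        # rough ETA: assume every remaining epoch will take as long as the
        # cycle that just finished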
        self._to_log["eta-seconds"] = overall_cycle_time * (
            trainer.max_epochs - trainer.current_epoch  # type: ignore
        )

        # Do not log during sanity check as results are not relevant
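        # the current epoch is logged under the "step" key together with each
        # metric, so that metrics from the same epoch can be aligned by the
        # loggers (see the class docstring)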
        if not trainer.sanity_checking:
            for k in sorted(self._to_log.keys()):
                pl_module.log_dict(
                    {k: self._to_log[k], "step": float(trainer.current_epoch)}
                )
            self._to_log = {}


class PredictionsWriter(lightning.pytorch.callbacks.BasePredictionWriter):
    """Lightning callback to write predictions to a file."""

    def __init__(
        self,
        output_dir: str | pathlib.Path,
        logfile_fields: typing.Sequence[str],
        write_interval: typing.Literal["batch", "epoch", "batch_and_epoch"],
    ):
        super().__init__(write_interval)
        self.output_dir = output_dir
        self.logfile_fields = logfile_fields

    def write_on_epoch_end(
        self,
        trainer: lightning.pytorch.Trainer,
        pl_module: lightning.pytorch.LightningModule,
        predictions: typing.Sequence[typing.Any],
        batch_indices: typing.Sequence[typing.Any] | None,
    ) -> None:
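        # predictions are grouped per dataloader; the datamodule is expected
        # to provide its predict dataloaders as a dictionary keyed by name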
        dataloader_names = list(trainer.datamodule.predict_dataloader().keys())

        for dataloader_idx, dataloader_name in enumerate(dataloader_names):
            logfile = os.path.join(
                self.output_dir,
                f"predictions_{dataloader_name}",
                "predictions.csv",
            )
            os.makedirs(os.path.dirname(logfile), exist_ok=True)

            logger.info(f"Saving predictions in {logfile}.")

            with open(logfile, "w", newline="") as l_f:
                logwriter = csv.DictWriter(l_f, fieldnames=self.logfile_fields)
                logwriter.writeheader()

                for prediction in predictions[dataloader_idx]:
                    logwriter.writerow(
                        {
                            "filename": prediction[0],
                            "likelihood": prediction[1].numpy(),
                            "ground_truth": prediction[2].numpy(),
                        }
                    )
                l_f.flush()