From dccc1da3d3f402449aeac02f9321d05599936305 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Mon, 10 Jul 2023 11:17:03 +0200
Subject: [PATCH] [ptbench.engine] Simplified, documented and created type
 hints for the ``callbacks`` module; Added type hints to the resource module

---
 src/ptbench/engine/callbacks.py | 463 ++++++++++++++++++++++++--------
 src/ptbench/engine/trainer.py   | 111 ++------
 src/ptbench/utils/resources.py  | 306 +++++++++++----------
 3 files changed, 530 insertions(+), 350 deletions(-)

diff --git a/src/ptbench/engine/callbacks.py b/src/ptbench/engine/callbacks.py
index 95288152..350140a8 100644
--- a/src/ptbench/engine/callbacks.py
+++ b/src/ptbench/engine/callbacks.py
@@ -1,154 +1,403 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
 import csv
+import logging
 import os
+import pathlib
 import time
+import typing
 
-from collections import defaultdict
+import lightning.pytorch
+import lightning.pytorch.callbacks
+import torch
 
-import numpy
+from ..utils.resources import ResourceMonitor
 
-from lightning.pytorch import Callback
-from lightning.pytorch.callbacks import BasePredictionWriter
+logger = logging.getLogger(__name__)
 
 
-# This ensures CSVLogger logs training and evaluation metrics on the same line
-# CSVLogger only accepts numerical values, not strings
-class LoggingCallback(Callback):
-    """Lightning callback to log various training metrics and device
-    information."""
+class LoggingCallback(lightning.pytorch.Callback):
+    """Callback to log various training metrics and device information.
 
-    def __init__(self, resource_monitor):
-        super().__init__()
-        self.training_loss = []
-        self.validation_loss = []
-        self.extra_validation_loss = defaultdict(list)
-        self.start_training_time = 0
-        self.start_epoch_time = 0
+    It ensures the CSVLogger logs training and evaluation metrics on the same
+    line.  Note that the CSVLogger only accepts numerical values, not strings.
 
-        self.resource_monitor = resource_monitor
-        self.max_queue_retries = 2
 
-    def on_train_start(self, trainer, pl_module):
-        self.start_training_time = time.time()
+    Parameters
+    ----------
 
-    def on_train_epoch_start(self, trainer, pl_module):
-        self.start_epoch_time = time.time()
+    resource_monitor
+        A monitor that watches resource usage (CPU/GPU) in a separate process,
+        asynchronously with respect to the main code execution.
+    """
 
-    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
-        self.training_loss.append(outputs["loss"].item())
+    def __init__(self, resource_monitor: ResourceMonitor):
+        super().__init__()
 
-    def on_validation_batch_end(
-        self, trainer, pl_module, outputs, batch, batch_idx
+        # lists of number of samples/batch and average losses
+        # - we use this later to compute overall epoch losses
+        self._training_epoch_loss: tuple[list[int], list[float]] = ([], [])
+        self._validation_epoch_loss: dict[
+            int, tuple[list[int], list[float]]
+        ] = {}
+
+        # timers
+        self._start_training_time = 0.0
+        self._start_training_epoch_time = 0.0
+        self._start_validation_epoch_time = 0.0
+
+        # log accumulators for a single flush at each training cycle
+        self._to_log: dict[str, float] = {}
+
+        # helpers for CPU and GPU utilisation
+        self._resource_monitor = resource_monitor
+        self._max_queue_retries = 2
+
+    def on_train_start(
+        self,
+        trainer: lightning.pytorch.Trainer,
+        pl_module: lightning.pytorch.LightningModule,
     ):
-        self.validation_loss.append(outputs["validation_loss"].item())
+        """Callback to be executed **before** the whole training starts.
 
-        if len(outputs) > 1:
-            extra_validation_keys = outputs.keys().remove("validation_loss")
-            for extra_validation_loss_key in extra_validation_keys:
-                self.extra_validation_loss[extra_validation_loss_key].append(
-                    outputs[extra_validation_loss_key]
-                )
+        This method is executed whenever you *start* training a module.
 
-    def on_validation_epoch_end(self, trainer, pl_module):
-        self.resource_monitor.trigger_summary()
 
-        self.epoch_time = time.time() - self.start_epoch_time
-        eta_seconds = self.epoch_time * (
-            trainer.max_epochs - trainer.current_epoch
-        )
-        current_time = time.time() - self.start_training_time
+        Parameters
+        ----------
 
-        def _compute_batch_loss(losses, num_chunks):
-            # When accumulating gradients, partial losses need to be summed per batch before averaging
-            if num_chunks != 1:
-                # The loss we get is scaled by the number of accumulation steps
-                losses = numpy.multiply(losses, num_chunks)
+        trainer
+            The Lightning trainer object
 
-                if len(losses) % num_chunks > 0:
-                    num_splits = (len(losses) // num_chunks) + 1
-                else:
-                    num_splits = len(losses) // num_chunks
+        pl_module
+            The lightning module that is being trained
+        """
+        self._start_training_time = time.time()
 
-                batched_losses = numpy.array_split(losses, num_splits)
+    def on_train_epoch_start(
+        self,
+        trainer: lightning.pytorch.Trainer,
+        pl_module: lightning.pytorch.LightningModule,
+    ) -> None:
+        """Callback to be executed **before** every training batch starts.
 
-                summed_batch_losses = []
+        This method is executed whenever a training epoch starts.  Keep this
+        code fast: do not log things to the terminal and avoid complicated
+        (lengthy) calculations here.
 
-                for b in batched_losses:
-                    summed_batch_losses.append(numpy.average(b))
+        .. warning::
 
-                return summed_batch_losses
+           This is executed **while** you are training.  Be very succinct or
+           face the consequences of slow training!
 
-            # No gradient accumulation, we already have the batch losses
-            else:
-                return losses
 
-        # Do not log during sanity check as results are not relevant
-        if not trainer.sanity_checking:
-            # We get partial loses when using gradient accumulation
-            self.training_loss = _compute_batch_loss(
-                self.training_loss, trainer.accumulate_grad_batches
-            )
-            self.validation_loss = _compute_batch_loss(
-                self.validation_loss, trainer.accumulate_grad_batches
-            )
+        Parameters
+        ----------
 
-            self.log("total_time", current_time)
-            self.log("eta", eta_seconds)
-            self.log("loss", numpy.average(self.training_loss))
-            self.log("learning_rate", pl_module.optimizer_configs["lr"])
-            self.log("validation_loss", numpy.average(self.validation_loss))
-
-            if len(self.extra_validation_loss) > 0:
-                for (
-                    extra_valid_loss_key,
-                    extra_valid_loss_values,
-                ) in self.extra_validation_loss.items:
-                    self.log(
-                        extra_valid_loss_key,
-                        numpy.average(extra_valid_loss_values),
-                    )
+        trainer
+            The Lightning trainer object
+
+        pl_module
+            The lightning module that is being trained
+        """
+        self._start_training_epoch_time = time.time()
+        self._training_epoch_loss = ([], [])
 
-        queue_retries = 0
-        # In case the resource monitor takes longer to fetch data from the queue, we wait
-        # Give up after self.resource_monitor.interval * self.max_queue_retries if cannot retrieve metrics from queue
-        while (
-            self.resource_monitor.data is None
-            and queue_retries < self.max_queue_retries
-        ):
-            queue_retries = queue_retries + 1
-            print(
-                f"Monitor queue is empty, retrying in {self.resource_monitor.interval}s"
+    def on_train_epoch_end(
+        self,
+        trainer: lightning.pytorch.Trainer,
+        pl_module: lightning.pytorch.LightningModule,
+    ):
+        """Callback to be executed **after** every training epoch ends.
+
+        This method is executed whenever a training epoch ends.  You want to
+        keep this code relatively fast to avoid significant runtime
+        slow-downs.
+
+
+        Parameters
+        ----------
+
+        trainer
+            The Lightning trainer object
+
+        pl_module
+            The lightning module that is being trained
+        """
+
+        # summarizes resource usage since the last checkpoint, clears the
+        # internal buffers and starts accumulating again.
+        self._resource_monitor.checkpoint()
+
+        # computes the total time of this training epoch and logs it
+        epoch_time = time.time() - self._start_training_epoch_time
+
+        # Computes the overall training loss as a weighted average of the
+        # per-batch average losses, using the per-batch sample counts as
+        # weights.  This also copes with a final batch of a different size.
+        sizes = torch.tensor(self._training_epoch_loss[0])
+        losses = torch.tensor(self._training_epoch_loss[1])
+        self._to_log["train_loss"] = (
+            torch.sum(sizes * losses) / torch.sum(sizes)
+        ).item()
+
+        self._to_log["train_epoch_time"] = epoch_time
+        self._to_log["learning_rate"] = pl_module.optimizers().defaults["lr"]
+
+        metrics = self._resource_monitor.data
+        if metrics is not None:
+            for metric_name, metric_value in metrics.items():
+                self._to_log[f"train_{metric_name}"] = float(metric_value)
+        else:
+            logger.warning(
+                "Unable to fetch monitoring information from "
+                "resource monitor. CPU/GPU utilisation will be "
+                "missing."
             )
-            time.sleep(self.resource_monitor.interval)
 
-        if queue_retries >= self.max_queue_retries:
-            print(
-                f"Unable to fetch monitoring information from queue after {queue_retries} retries"
+        # if there are no validation dataloaders, this completes the cycle at
+        # the end of the training epoch by logging all values to the logger
+        self.on_cycle_end(trainer, pl_module)
+
+    def on_train_batch_end(
+        self,
+        trainer: lightning.pytorch.Trainer,
+        pl_module: lightning.pytorch.LightningModule,
+        outputs: typing.Mapping[str, torch.Tensor],
+        batch: tuple[torch.Tensor, typing.Mapping[str, torch.Tensor]],
+        batch_idx: int,
+    ) -> None:
+        """Callback to be executed **after** every training batch ends.
+
+        This method is executed whenever a training batch ends.  Batches
+        happen as often as possible, so you want to make this code very fast.
+        Do not log things to the terminal or the like, or do complicated
+        (lengthy) calculations.
+
+        .. warning::
+
+           This is executed **while** you are training.  Be very succinct or
+           face the consequences of slow training!
+
+
+        Parameters
+        ----------
+
+        trainer
+            The Lightning trainer object
+
+        pl_module
+            The lightning module that is being trained
+
+        outputs
+            The outputs of the module's ``training_step``
+
+        batch
+            The data that the training step received
+
+        batch_idx
+            The relative number of the batch
+        """
+        self._training_epoch_loss[0].append(batch[0].shape[0])
+        self._training_epoch_loss[1].append(outputs["loss"].item())
+
+    def on_validation_epoch_start(
+        self,
+        trainer: lightning.pytorch.Trainer,
+        pl_module: lightning.pytorch.LightningModule,
+    ) -> None:
+        """Callback to be executed **before** every validation batch starts.
+
+        This method is executed whenever a validation epoch starts.  Keep this
+        code fast: do not log things to the terminal and avoid complicated
+        (lengthy) calculations here.
+
+        .. warning::
+
+           This is executed **while** you are training.  Be very succinct or
+           face the consequences of slow training!
+
+
+        Parameters
+        ----------
+
+        trainer
+            The Lightning trainer object
+
+        pl_module
+            The lightning module that is being trained
+        """
+        self._start_validation_epoch_time = time.time()
+        self._validation_epoch_loss = {}
+
+    def on_validation_epoch_end(
+        self,
+        trainer: lightning.pytorch.Trainer,
+        pl_module: lightning.pytorch.LightningModule,
+    ) -> None:
+        """Callback to be executed **after** every validation epoch ends.
+
+        This method is executed whenever a validation epoch ends.  You want to
+        keep this code relatively fast to avoid significant runtime
+        slow-downs.
+
+
+        Parameters
+        ----------
+
+        trainer
+            The Lightning trainer object
+
+        pl_module
+            The lightning module that is being trained
+        """
+
+        # summarizes resource usage since the last checkpoint, clears the
+        # internal buffers and starts accumulating again.
+        self._resource_monitor.checkpoint()
+
+        epoch_time = time.time() - self._start_validation_epoch_time
+        self._to_log["validation_epoch_time"] = epoch_time
+
+        metrics = self._resource_monitor.data
+        if metrics is not None:
+            for metric_name, metric_value in metrics.items():
+                self._to_log[f"validation_{metric_name}"] = float(metric_value)
+        else:
+            logger.warning(
+                "Unable to fetch monitoring information from "
+                "resource monitor. CPU/GPU utilisation will be "
+                "missing."
             )
 
-        assert self.resource_monitor.q.empty()
+        # Computes the overall validation losses as weighted averages of the
+        # per-batch average losses, using the per-batch sample counts as
+        # weights (one entry per validation dataloader).
+        for key in sorted(self._validation_epoch_loss.keys()):
+            if key == 0:
+                name = "validation_loss"
+            else:
+                name = f"validation_loss_{key}"
 
-        # Do not log during sanity check as results are not relevant
-        if not trainer.sanity_checking:
-            for metric_name, metric_value in self.resource_monitor.data:
-                self.log(metric_name, float(metric_value))
+            sizes = torch.tensor(self._validation_epoch_loss[key][0])
+            losses = torch.tensor(self._validation_epoch_loss[key][1])
+            self._to_log[name] = (
+                torch.sum(sizes * losses) / torch.sum(sizes)
+            ).item()
+
+    def on_validation_batch_end(
+        self,
+        trainer: lightning.pytorch.Trainer,
+        pl_module: lightning.pytorch.LightningModule,
+        outputs: torch.Tensor,
+        batch: tuple[torch.Tensor, typing.Mapping[str, torch.Tensor]],
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> None:
+        """Callback to be executed **after** every validation batch ends.
+
+        This method is executed whenever a validation batch ends.  Batches
+        happen as often as possible, so you want to make this code very fast.
+        Do not log things to the terminal or the like, or do complicated
+        (lengthy) calculations.
 
-        self.resource_monitor.data = None
+        .. warning::
 
-        self.training_loss = []
-        self.validation_loss = []
+           This is executed **while** you are training.  Be very succinct or
+           face the consequences of slow training!
 
 
-class PredictionsWriter(BasePredictionWriter):
+        Parameters
+        ----------
+
+        trainer
+            The Lightning trainer object
+
+        pl_module
+            The lightning module that is being trained
+
+        outputs
+            The outputs of the module's ``validation_step``
+
+        batch
+            The data that the validation step received
+
+        batch_idx
+            The relative number of the batch
+
+        dataloader_idx
+            Index of the dataloader used during validation.  Use this to figure
+            out which dataset was used for this validation epoch.
+        """
+        size, value = self._validation_epoch_loss.setdefault(
+            dataloader_idx, ([], [])
+        )
+        size.append(batch[0].shape[0])
+        value.append(outputs.item())
+
+    def on_cycle_end(
+        self,
+        trainer: lightning.pytorch.Trainer,
+        pl_module: lightning.pytorch.LightningModule,
+    ) -> None:
+        """Called when the training/validation cycle has ended.
+
+        This function will log all relevant values to the various loggers.  It
+        is supposed to be called at the end of the training cycle (consisting
+        of a training and validation step).
+
+
+        Parameters
+        ----------
+
+        trainer
+            The Lightning trainer object
+
+        pl_module
+            The lightning module that is being trained
+        """
+
+        # collects the final timings for the whole training cycle
+        # Note: logging should happen at on_validation_end(), but
+        # apparently you can't log from there
+        overall_cycle_time = time.time() - self._start_training_epoch_time
+        self._to_log["train_cycle_time"] = overall_cycle_time
+        self._to_log["total_time"] = time.time() - self._start_training_time
+        self._to_log["eta"] = overall_cycle_time * (
+            trainer.max_epochs - trainer.current_epoch  # type: ignore
+        )
+
+        # Do not log during sanity check as results are not relevant
+        if not trainer.sanity_checking:
+            for k in sorted(self._to_log.keys()):
+                pl_module.log(k, self._to_log[k])
+            self._to_log = {}
+
+
+class PredictionsWriter(lightning.pytorch.callbacks.BasePredictionWriter):
     """Lightning callback to write predictions to a file."""
 
-    def __init__(self, output_dir, logfile_fields, write_interval):
+    def __init__(
+        self,
+        output_dir: str | pathlib.Path,
+        logfile_fields: typing.Sequence[str],
+        write_interval: typing.Literal["batch", "epoch", "batch_and_epoch"],
+    ):
         super().__init__(write_interval)
         self.output_dir = output_dir
         self.logfile_fields = logfile_fields
 
     def write_on_epoch_end(
-        self, trainer, pl_module, predictions, batch_indices
-    ):
+        self,
+        trainer: lightning.pytorch.Trainer,
+        pl_module: lightning.pytorch.LightningModule,
+        predictions: typing.Sequence[typing.Any],
+        batch_indices: typing.Sequence[typing.Any] | None,
+    ) -> None:
         for dataloader_idx, dataloader_results in enumerate(predictions):
             dataloader_name = list(
                 trainer.datamodule.predict_dataloader().keys()
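
For reference, here is a minimal standalone sketch (outside this patch) of the
size-weighted loss averaging that ``LoggingCallback`` performs at the end of an
epoch; the sample counts and loss values below are illustrative only:

    import torch

    # per-batch sample counts and per-batch average losses, as accumulated by
    # on_train_batch_end() -- the numbers are made up for illustration
    sizes = torch.tensor([32, 32, 32, 17])
    losses = torch.tensor([0.71, 0.64, 0.58, 0.55])

    # overall epoch loss: average of the per-batch averages, weighted by the
    # number of samples in each batch (copes with a smaller final batch)
    epoch_loss = (torch.sum(sizes * losses) / torch.sum(sizes)).item()
    print(f"train_loss = {epoch_loss:.4f}")
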
diff --git a/src/ptbench/engine/trainer.py b/src/ptbench/engine/trainer.py
index 7643d4ae..ac92d3c1 100644
--- a/src/ptbench/engine/trainer.py
+++ b/src/ptbench/engine/trainer.py
@@ -14,28 +14,11 @@ import torch.nn
 
 from ..utils.accelerator import AcceleratorProcessor
 from ..utils.resources import ResourceMonitor, cpu_constants, gpu_constants
-from ..utils.save_sh_command import save_sh_command
 from .callbacks import LoggingCallback
 
 logger = logging.getLogger(__name__)
 
 
-def check_gpu(device: str) -> None:
-    """Check the device type and the availability of GPU.
-
-    Parameters
-    ----------
-
-    device : :py:class:`torch.device`
-        device to use
-    """
-    if device == "cuda":
-        # asserts we do have a GPU
-        assert bool(
-            gpu_constants()
-        ), f"Device set to '{device}', but nvidia-smi is not installed"
-
-
 def save_model_summary(
     output_folder: str, model: torch.nn.Module
 ) -> tuple[lightning.pytorch.callbacks.ModelSummary, int]:
@@ -89,73 +72,16 @@ def static_information_to_csv(static_logfile_name, device, n):
             os.unlink(backup)
         shutil.move(static_logfile_name, backup)
     with open(static_logfile_name, "w", newline="") as f:
-        logdata = cpu_constants()
+        logdata: dict[str, int | float | str] = {}
+        logdata.update(cpu_constants())
         if device == "cuda":
-            logdata += gpu_constants()
-        logdata += (("model_size", n),)
-        logwriter = csv.DictWriter(f, fieldnames=[k[0] for k in logdata])
+            results = gpu_constants()
+            if results is not None:
+                logdata.update(results)
+        logdata["model_size"] = n
+        logwriter = csv.DictWriter(f, fieldnames=logdata.keys())
         logwriter.writeheader()
-        logwriter.writerow(dict(k for k in logdata))
-
-
-def check_exist_logfile(logfile_name, arguments):
-    """Check existance of logfile (trainlog.csv), If the logfile exist the and
-    the epochs number are still 0, The logfile will be replaced.
-
-    Parameters
-    ----------
-
-    logfile_name : str
-        The logfile_name which is a join between the output_folder and trainlog.csv
-
-    arguments : dict
-        start and end epochs
-    """
-    if arguments["epoch"] == 0 and os.path.exists(logfile_name):
-        backup = logfile_name + "~"
-        if os.path.exists(backup):
-            os.unlink(backup)
-        shutil.move(logfile_name, backup)
-
-
-def create_logfile_fields(valid_loader, extra_valid_loaders, device):
-    """Creation of the logfile fields that will appear in the logfile.
-
-    Parameters
-    ----------
-
-    valid_loader : :py:class:`torch.utils.data.DataLoader`
-        To be used to validate the model and enable automatic checkpointing.
-        If set to ``None``, then do not validate it.
-
-    extra_valid_loaders : :py:class:`list` of :py:class:`torch.utils.data.DataLoader`
-        To be used to validate the model, however **does not affect** automatic
-        checkpointing. If set to ``None``, or empty, then does not log anything
-        else.  Otherwise, an extra column with the loss of every dataset in
-        this list is kept on the final training log.
-
-    device : :py:class:`torch.device`
-        device to use
-
-    Returns
-    -------
-
-    logfile_fields: tuple
-        The fields that will appear in trainlog.csv
-    """
-    logfile_fields = (
-        "epoch",
-        "total_time",
-        "eta",
-        "loss",
-        "learning_rate",
-    )
-    if valid_loader is not None:
-        logfile_fields += ("validation_loss",)
-    if extra_valid_loaders:
-        logfile_fields += ("extra_validation_losses",)
-    logfile_fields += tuple(ResourceMonitor.monitored_keys(device == "cuda"))
-    return logfile_fields
+        logwriter.writerow(logdata)
 
 
 def run(
@@ -200,7 +126,7 @@ def run(
 
     accelerator : str
         A string indicating the accelerator to use (e.g. "cpu" or "gpu"). The
-        device can also be specified (gpu:0)
+        device can also be specified (gpu:0).
 
     arguments : dict
         Start and end epochs:
@@ -227,11 +153,7 @@ def run(
     os.makedirs(output_folder, exist_ok=True)
 
     # Save model summary
-    _, n = save_model_summary(output_folder, model)
-
-    save_sh_command(output_folder)
-
-    # save_sh_command(os.path.join(output_folder, "cmd_line_config.txt"))
+    _, no_of_parameters = save_model_summary(output_folder, model)
 
     csv_logger = lightning.pytorch.loggers.CSVLogger(output_folder, "logs_csv")
     tensorboard_logger = lightning.pytorch.loggers.TensorBoardLogger(
@@ -251,7 +173,7 @@ def run(
         save_last=True,
         monitor="validation_loss",
         mode="min",
-        save_on_train_epoch_end=False,
+        save_on_train_epoch_end=True,
         every_n_epochs=checkpoint_period,
     )
 
@@ -260,18 +182,15 @@ def run(
     # write static information to a CSV file
     static_logfile_name = os.path.join(output_folder, "constants.csv")
     static_information_to_csv(
-        static_logfile_name, accelerator_processor.to_torch(), n
+        static_logfile_name,
+        accelerator_processor.to_torch(),
+        no_of_parameters,
     )
 
-    if accelerator_processor.device is None:
-        devices = "auto"
-    else:
-        devices = accelerator_processor.device
-
     with resource_monitor:
         trainer = lightning.pytorch.Trainer(
             accelerator=accelerator_processor.accelerator,
-            devices=devices,
+            devices=(accelerator_processor.device or "auto"),
             max_epochs=max_epoch,
             accumulate_grad_batches=batch_chunk_count,
             logger=[csv_logger, tensorboard_logger],
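
As a side note on the ``static_information_to_csv`` change above: the constants
are now collected into a plain dictionary and written as a single CSV row.  A
minimal sketch of that pattern, with illustrative values only (the real keys
come from ``cpu_constants()``, ``gpu_constants()`` and the model summary):

    import csv

    # illustrative constants; the trainer fills these from cpu_constants(),
    # gpu_constants() (when a GPU is present) and the model parameter count
    logdata: dict[str, int | float | str] = {
        "cpu_memory_total_GB": 62.8,
        "cpu_count": 16,
        "model_size": 11177538,
    }

    with open("constants.csv", "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=list(logdata.keys()))
        writer.writeheader()
        writer.writerow(logdata)
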
diff --git a/src/ptbench/utils/resources.py b/src/ptbench/utils/resources.py
index ebad7794..f7c7f6b9 100644
--- a/src/ptbench/utils/resources.py
+++ b/src/ptbench/utils/resources.py
@@ -6,11 +6,13 @@
 
 import logging
 import multiprocessing
+import multiprocessing.synchronize
 import os
 import queue
 import shutil
 import subprocess
 import time
+import typing
 
 import numpy
 import psutil
@@ -25,7 +27,9 @@ GB = float(2**30)
 """The number of bytes in a gigabyte."""
 
 
-def run_nvidia_smi(query, rename=None):
+def run_nvidia_smi(
+    query: typing.Sequence[str],
+) -> dict[str, str | float] | None:
     """Returns GPU information from query.
 
     For a comprehensive list of options and help, execute ``nvidia-smi
@@ -35,52 +39,43 @@ def run_nvidia_smi(query, rename=None):
     Parameters
     ----------
 
-    query : list
+    query
         A list of query strings as defined by ``nvidia-smi --help-query-gpu``
 
-    rename : :py:class:`list`, Optional
-        A list of keys to yield in the return value for each entry above.  It
-        gives you the opportunity to rewrite some key names for convenience.
-        This list, if provided, must be of the same length as ``query``.
-
 
     Returns
     -------
 
-    data : :py:class:`tuple`, None
-        An ordered dictionary (organized as 2-tuples) containing the queried
-        parameters (``rename`` versions).  If ``nvidia-smi`` is not available,
-        returns ``None``.  Percentage information is left alone,
-        memory information is transformed to gigabytes (floating-point).
+    data
+        A dictionary containing the queried parameters, keyed by the query
+        strings.  If ``nvidia-smi`` is not available, returns ``None``.
+        Percentage information is left alone, memory information is
+        transformed to gigabytes (floating-point).
     """
-    if _nvidia_smi is not None:
-        if rename is None:
-            rename = query
-        else:
-            assert len(rename) == len(query)
-
-        # Get GPU information based on GPU ID.
-        values = subprocess.getoutput(
-            "%s --query-gpu=%s --format=csv,noheader --id=%s"
-            % (
-                _nvidia_smi,
-                ",".join(query),
-                os.environ.get("CUDA_VISIBLE_DEVICES"),
-            )
-        )
-        values = [k.strip() for k in values.split(",")]
-        t_values = []
-        for k in values:
-            if k.endswith("%"):
-                t_values.append(float(k[:-1].strip()))
-            elif k.endswith("MiB"):
-                t_values.append(float(k[:-3].strip()) / 1024)
-            else:
-                t_values.append(k)  # unchanged
-        return tuple(zip(rename, t_values))
-
-
-def gpu_constants():
+    if _nvidia_smi is None:
+        return None
+
+    # Gets GPU information, restricted to the devices listed in
+    # CUDA_VISIBLE_DEVICES when that variable is set.  Results follow the
+    # query order.
+    query_str = (
+        f"{_nvidia_smi} --query-gpu={','.join(query)} --format=csv,noheader"
+    )
+    visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
+    if visible_devices:
+        query_str += f" --id={visible_devices}"
+    values = subprocess.getoutput(query_str)
+
+    retval: dict[str, str | float] = {}
+    for i, k in enumerate([k.strip() for k in values.split(",")]):
+        retval[query[i]] = k
+        if k.endswith("%"):
+            retval[query[i]] = float(k[:-1].strip())
+        elif k.endswith("MiB"):
+            retval[query[i]] = float(k[:-3].strip()) / 1024
+    return retval
+
+
+def gpu_constants() -> dict[str, str | int | float] | None:
     """Returns GPU (static) information using nvidia-smi.
 
     See :py:func:`run_nvidia_smi` for operational details.
@@ -90,21 +85,25 @@ def gpu_constants():
 
     data : :py:class:`tuple`, None
         If ``nvidia-smi`` is not available, returns ``None``, otherwise, we
-        return an ordered dictionary (organized as 2-tuples) containing the
-        following ``nvidia-smi`` query information:
+        return a dictionary containing the following ``nvidia-smi`` query
+        information, in this order:
 
         * ``gpu_name``, as ``gpu_name`` (:py:class:`str`)
         * ``driver_version``, as ``gpu_driver_version`` (:py:class:`str`)
         * ``memory.total``, as ``gpu_memory_total`` (transformed to gigabytes,
           :py:class:`float`)
     """
-    return run_nvidia_smi(
-        ("gpu_name", "driver_version", "memory.total"),
-        ("gpu_name", "gpu_driver_version", "gpu_memory_total_GB"),
-    )
+    retval = run_nvidia_smi(("gpu_name", "driver_version", "memory.total"))
+    if retval is None:
+        return retval
+
+    # else, just update with more generic names
+    retval["gpu_driver_version"] = retval.pop("driver_version")
+    retval["gpu_memory_used_GB"] = retval.pop("memory.total")
+    return retval
 
 
-def gpu_log():
+def gpu_log() -> dict[str, float] | None:
     """Returns GPU information about current non-static status using nvidia-
     smi.
 
@@ -113,10 +112,10 @@ def gpu_log():
     Returns
     -------
 
-    data : :py:class:`tuple`, None
+    data
         If ``nvidia-smi`` is not available, returns ``None``, otherwise, we
-        return an ordered dictionary (organized as 2-tuples) containing the
-        following ``nvidia-smi`` query information:
+        return a dictionary containing the following ``nvidia-smi`` query
+        information, in this order:
 
         * ``memory.used``, as ``gpu_memory_used`` (transformed to gigabytes,
           :py:class:`float`)
@@ -127,47 +126,41 @@ def gpu_log():
         * ``utilization.gpu``, as ``gpu_percent``,
           (:py:class:`float`, in percent)
     """
-    retval = run_nvidia_smi(
-        (
-            "memory.total",
-            "memory.used",
-            "memory.free",
-            "utilization.gpu",
-        ),
-        (
-            "gpu_memory_total_GB",
-            "gpu_memory_used_GB",
-            "gpu_memory_free_percent",
-            "gpu_usage_percent",
-        ),
-    )
 
-    # re-compose the output to generate expected values
-    return (
-        retval[1],  # gpu_memory_used
-        retval[2],  # gpu_memory_free
-        ("gpu_memory_percent", 100 * (retval[1][1] / retval[0][1])),
-        retval[3],  # gpu_percent
+    result = run_nvidia_smi(
+        ("memory.total", "memory.used", "memory.free", "utilization.gpu")
     )
 
+    if result is None:
+        return result
 
-def cpu_constants():
+    return {
+        "gpu_memory_used_GB": float(result["memory.used"]),
+        "gpu_memory_free_GB": float(result["memory.free"]),
+        "gpu_memory_percent": 100
+        * float(result["memory.used"])
+        / float(result["memory.total"]),
+        "gpu_percent": float(result["utilization.gpu"]),
+    }
+
+
+def cpu_constants() -> dict[str, int | float]:
     """Returns static CPU information about the current system.
 
     Returns
     -------
 
-    data : tuple
+    data
         An ordered dictionary (organized as 2-tuples) containing these entries:
 
         0. ``cpu_memory_total`` (:py:class:`float`): total memory available,
            in gigabytes
         1. ``cpu_count`` (:py:class:`int`): number of logical CPUs available
     """
-    return (
-        ("cpu_memory_total_GB", psutil.virtual_memory().total / GB),
-        ("cpu_count", psutil.cpu_count(logical=True)),
-    )
+    return {
+        "cpu_memory_total_GB": psutil.virtual_memory().total / GB,
+        "cpu_count": psutil.cpu_count(logical=True),
+    }
 
 
 class CPULogger:
@@ -176,24 +169,24 @@ class CPULogger:
     Parameters
     ----------
 
-    pid : :py:class:`int`, Optional
+    pid
         Process identifier of the main process (parent process) to observe
     """
 
-    def __init__(self, pid=None):
+    def __init__(self, pid: int | None = None):
         this = psutil.Process(pid=pid)
         self.cluster = [this] + this.children(recursive=True)
         # touch cpu_percent() at least once for all processes in the cluster
         [k.cpu_percent(interval=None) for k in self.cluster]
 
-    def log(self):
-        """Returns current process cluster information.
+    def log(self) -> dict[str, int | float]:
+        """Returns current process cluster iformation.
 
         Returns
         -------
 
-        data : tuple
-            An ordered dictionary (organized as 2-tuples) containing these entries:
+        data
+            An ordered dictionary containing these entries:
 
             0. ``cpu_memory_used`` (:py:class:`float`): total memory used from
                the system, in gigabytes
@@ -244,14 +237,14 @@ class CPULogger:
                 # it is too late to update any intermediate list
                 # at this point, but ensures to update counts later on
                 gone.add(k)
-        return (
-            ("cpu_memory_used_GB", psutil.virtual_memory().used / GB),
-            ("cpu_rss_GB", sum([k.rss for k in memory_info]) / GB),
-            ("cpu_vms_GB", sum([k.vms for k in memory_info]) / GB),
-            ("cpu_percent", sum(cpu_percent)),
-            ("cpu_processes", len(self.cluster) - len(gone)),
-            ("cpu_open_files", sum(open_files)),
-        )
+        return {
+            "cpu_memory_used_GB": psutil.virtual_memory().used / GB,
+            "cpu_rss_GB": sum([k.rss for k in memory_info]) / GB,
+            "cpu_vms_GB": sum([k.vms for k in memory_info]) / GB,
+            "cpu_percent": sum(cpu_percent),
+            "cpu_processes": len(self.cluster) - len(gone),
+            "cpu_open_files": sum(open_files),
+        }
 
 
 class _InformationGatherer:
@@ -260,73 +253,85 @@ class _InformationGatherer:
     Parameters
     ----------
 
-    has_gpu : bool
+    has_gpu
         A flag indicating if we have a GPU installed on the platform or not
 
-    main_pid : int
+    main_pid
         The main process identifier to monitor
 
-    logger : logging.Logger
+    logger
         A logger to be used for logging messages
     """
 
-    def __init__(self, has_gpu, main_pid, logger):
+    def __init__(
+        self, has_gpu: bool, main_pid: int | None, logger: logging.Logger
+    ):
+        self.logger: logging.Logger = logger
         self.cpu_logger = CPULogger(main_pid)
-        self.keys = [k[0] for k in self.cpu_logger.log()]
-        self.cpu_keys_len = len(self.keys)
-        self.has_gpu = has_gpu
-        self.logger = logger
+        keys: list[str] = list(self.cpu_logger.log().keys())
+        self.has_gpu: bool = has_gpu
         if self.has_gpu:
-            self.keys += [k[0] for k in gpu_log()]
-        self.data = [[] for _ in self.keys]
+            example = gpu_log()
+            if example is not None:
+                keys += list(example.keys())
+        self.data: dict[str, list[int | float]] = {k: [] for k in keys}
 
-    def acc(self):
+    def acc(self) -> None:
         """Accumulates another measurement."""
-        for i, k in enumerate(self.cpu_logger.log()):
-            self.data[i].append(k[1])
+        for k, v in self.cpu_logger.log().items():
+            self.data[k].append(v)
         if self.has_gpu:
-            for i, k in enumerate(gpu_log()):
-                self.data[i + self.cpu_keys_len].append(k[1])
+            sample = gpu_log()
+            if sample is not None:
+                for k, v in sample.items():
+                    self.data[k].append(v)
 
-    def clear(self):
+    def clear(self) -> None:
         """Clears accumulated data."""
-        self.data = [[] for _ in self.keys]
+        for k in self.data.keys():
+            self.data[k] = []
 
-    def summary(self):
+    def summary(self) -> dict[str, list[int | float]]:
         """Returns the current data."""
-        if len(self.data[0]) == 0:
+        if len(next(iter(self.data.values()))) == 0:
             self.logger.error("CPU/GPU logger was not able to collect any data")
-        retval = []
-        for k, values in zip(self.keys, self.data):
-            retval.append((k, values))
-        return tuple(retval)
+        return self.data
 
 
 def _monitor_worker(
-    interval, has_gpu, main_pid, stop, summary_event, queue, logging_level
+    interval: int | float,
+    has_gpu: bool,
+    main_pid: int,
+    stop: multiprocessing.synchronize.Event,
+    summary_event: multiprocessing.synchronize.Event,
+    queue: queue.Queue,
+    logging_level: int,
 ):
     """A monitoring worker that measures resources and returns lists.
 
     Parameters
     ==========
 
-    interval : int, float
+    interval
         Number of seconds to wait between each measurement (maybe a floating
         point number as accepted by :py:func:`time.sleep`)
 
-    has_gpu : bool
+    has_gpu
         A flag indicating if we have a GPU installed on the platform or not
 
-    main_pid : int
+    main_pid
         The main process identifier to monitor
 
-    stop : :py:class:`multiprocessing.Event`
-        Indicates if we should continue running or stop
+    stop
+        Event that indicates if we should continue running or stop
 
-    queue : :py:class:`queue.Queue`
+    summary_event
+        Event that indicates if we should produce a summary
+
+    queue
         A queue, to send monitoring information back to the spawner
 
-    logging_level: int
+    logging_level
         The logging level to use for logging from launched processes
     """
     logger = multiprocessing.log_to_stderr(level=logging_level)
@@ -343,9 +348,9 @@ def _monitor_worker(
 
             time.sleep(interval)
         except Exception:
-            logger.warning(
-                "Iterative CPU/GPU logging did not work properly " "this once",
-                exc_info=True,
+            logger.exception(
+                "Iterative CPU/GPU logging did not work properly."
+                " Exception follows.  Retrying..."
             )
             time.sleep(0.5)  # wait half a second, and try again!
 
@@ -356,27 +361,35 @@ class ResourceMonitor:
     Parameters
     ----------
 
-    interval : int, float
+    interval
         Number of seconds to wait between each measurement (maybe a floating
         point number as accepted by :py:func:`time.sleep`)
 
-    has_gpu : bool
+    has_gpu
         A flag indicating if we have a GPU installed on the platform or not
 
-    main_pid : int
+    main_pid
         The main process identifier to monitor
 
-    logging_level: int
+    logging_level
         The logging level to use for logging from launched processes
     """
 
-    def __init__(self, interval, has_gpu, main_pid, logging_level):
+    def __init__(
+        self,
+        interval: int | float,
+        has_gpu: bool,
+        main_pid: int,
+        logging_level: int,
+    ):
         self.interval = interval
         self.has_gpu = has_gpu
         self.main_pid = main_pid
         self.stop_event = multiprocessing.Event()
         self.summary_event = multiprocessing.Event()
-        self.q = multiprocessing.Queue()
+        self.q: multiprocessing.Queue[
+            dict[str, list[int | float]]
+        ] = multiprocessing.Queue()
         self.logging_level = logging_level
 
         self.monitor = multiprocessing.Process(
@@ -393,23 +406,23 @@ class ResourceMonitor:
             ),
         )
 
-        self.data = None
-
-    @staticmethod
-    def monitored_keys(has_gpu):
-        return _InformationGatherer(has_gpu, None, logger).keys
+        self.data: dict[str, int | float] | None = None
 
-    def __enter__(self):
+    def __enter__(self) -> None:
         """Starts the monitoring process."""
         self.monitor.start()
 
-    def trigger_summary(self):
+    def checkpoint(self) -> None:
+        """Forces the monitoring process to yield data and clear the internal
+        accumulator."""
         self.summary_event.set()
 
         try:
-            data = self.q.get(timeout=2 * self.interval)
+            data: dict[str, list[int | float]] = self.q.get(
+                timeout=2 * self.interval
+            )
         except queue.Empty:
-            logger.warn(
+            logger.warning(
                 f"CPU/GPU resource monitor did not provide anything when "
                 f"joined (even after a {2*self.interval}-second timeout - "
                 f"this is normally due to exceptions on the monitoring process. "
@@ -417,19 +430,18 @@ class ResourceMonitor:
             )
             self.data = None
         else:
-            # summarize the returned data by creating means
-            summary = []
-            for k, values in data:
+            # summarize the returned data by creating averages
+            self.data = {}
+            for k, values in data.items():
                 if values:
                     if k in ("cpu_processes", "cpu_open_files"):
-                        summary.append((k, numpy.max(values)))
+                        self.data[k] = numpy.max(values)
                     else:
-                        summary.append((k, numpy.mean(values)))
+                        self.data[k] = float(numpy.mean(values))
                 else:
-                    summary.append((k, 0.0))
-            self.data = tuple(summary)
+                    self.data[k] = 0.0
 
-    def __exit__(self, *exc):
+    def __exit__(self, *_) -> None:
         """Stops the monitoring process and returns the summary of
         observations."""
 
-- 
GitLab