Commit 427c1cf1 authored by André Anjos

[ptbench.engine] Simplified, documented and created type hints for the ``callbacks`` module; Added type hints to the resource module
parent d576423f
1 merge request: !6 Making use of LightningDataModule and simplification of data loading
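
A minimal wiring sketch (not part of the commit; the call sites and import paths are assumptions inferred from the modules touched here) of how the simplified callback and the resource monitor fit together:

    import logging
    import os

    import lightning.pytorch

    from ptbench.engine.callbacks import LoggingCallback
    from ptbench.utils.resources import ResourceMonitor

    monitor = ResourceMonitor(
        interval=5.0,
        has_gpu=False,
        main_pid=os.getpid(),
        logging_level=logging.WARNING,
    )
    with monitor:
        trainer = lightning.pytorch.Trainer(callbacks=[LoggingCallback(monitor)])
        # trainer.fit(model, datamodule=...)  # model and datamodule supplied by the caller
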
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

import csv
import logging
import os
import pathlib
import time
import typing

import lightning.pytorch
import lightning.pytorch.callbacks
import numpy
import torch

from ..utils.resources import ResourceMonitor

logger = logging.getLogger(__name__)


class LoggingCallback(lightning.pytorch.Callback):
    """Callback to log various training metrics and device information.

    It ensures CSVLogger logs training and evaluation metrics on the same
    line.  Note that a CSVLogger only accepts numerical values, and not
    strings.

    Parameters
    ----------

    resource_monitor
        A monitor that watches resource usage (CPU/GPU) in a separate process
        and totally asynchronously with the code execution.
    """

    def __init__(self, resource_monitor: ResourceMonitor):
        super().__init__()

        # lists of number of samples/batch and average losses
        # - we use this later to compute overall epoch losses
        self._training_epoch_loss: tuple[list[int], list[float]] = ([], [])
        self._validation_epoch_loss: dict[
            int, tuple[list[int], list[float]]
        ] = {}

        # timers
        self._start_training_time = 0.0
        self._start_training_epoch_time = 0.0
        self._start_validation_epoch_time = 0.0

        # log accumulators for a single flush at each training cycle
        self._to_log: dict[str, float] = {}

        # helpers for CPU and GPU utilisation
        self._resource_monitor = resource_monitor
        self._max_queue_retries = 2

    def on_train_start(
        self,
        trainer: lightning.pytorch.Trainer,
        pl_module: lightning.pytorch.LightningModule,
    ):
        """Callback to be executed **before** the whole training starts.

        This method is executed whenever you *start* training a module.

        Parameters
        ----------

        trainer
            The Lightning trainer object

        pl_module
            The lightning module that is being trained
        """
        self._start_training_time = time.time()

    def on_train_epoch_start(
        self,
        trainer: lightning.pytorch.Trainer,
        pl_module: lightning.pytorch.LightningModule,
    ) -> None:
        """Callback to be executed **before** every training epoch starts.

        This method is executed whenever a training epoch starts. Presumably,
        epochs happen as often as possible. You want to make this code very
        fast. Do not log things to the terminal or the such, or do complicated
        (lengthy) calculations.

        .. warning::

           This is executed **while** you are training. Be very succinct or
           face the consequences of slow training!

        Parameters
        ----------

        trainer
            The Lightning trainer object

        pl_module
            The lightning module that is being trained
        """
        self._start_training_epoch_time = time.time()
        self._training_epoch_loss = ([], [])

    def on_train_epoch_end(
        self,
        trainer: lightning.pytorch.Trainer,
        pl_module: lightning.pytorch.LightningModule,
    ):
        """Callback to be executed **after** every training epoch ends.

        This method is executed whenever a training epoch ends. Presumably,
        epochs happen as often as possible. You want to make this code
        relatively fast to avoid significant runtime slow-downs.

        Parameters
        ----------

        trainer
            The Lightning trainer object

        pl_module
            The lightning module that is being trained
        """
        # summarizes resource usage since the last checkpoint
        # clears internal buffers and starts accumulating again.
        self._resource_monitor.checkpoint()

        # evaluates this training epoch total time, and log it
        epoch_time = time.time() - self._start_training_epoch_time

        # Computes the epoch training loss as the average of the per-batch
        # average losses, weighted by the number of samples in each batch.
        self._to_log["train_loss"] = (
            torch.sum(
                torch.tensor(self._training_epoch_loss[0])
                * torch.tensor(self._training_epoch_loss[1])
            )
            / sum(self._training_epoch_loss[0])
        ).item()
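        # (Worked example: two batches of 32 and 8 samples with average
        # losses 0.5 and 0.9 give (32*0.5 + 8*0.9) / (32+8) = 0.58.)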
        self._to_log["train_epoch_time"] = epoch_time
        self._to_log["learning_rate"] = pl_module.optimizers().defaults["lr"]

        metrics = self._resource_monitor.data
        if metrics is not None:
            for metric_name, metric_value in metrics.items():
                self._to_log[f"train_{metric_name}"] = float(metric_value)
        else:
            logger.warning(
                "Unable to fetch monitoring information from "
                "resource monitor. CPU/GPU utilisation will be "
                "missing."
            )

        # if no validation dataloaders, complete cycle by the end of the
        # training epoch, by logging all values to the logger
        self.on_cycle_end(trainer, pl_module)

    def on_train_batch_end(
        self,
        trainer: lightning.pytorch.Trainer,
        pl_module: lightning.pytorch.LightningModule,
        outputs: typing.Mapping[str, torch.Tensor],
        batch: tuple[torch.Tensor, typing.Mapping[str, torch.Tensor]],
        batch_idx: int,
    ) -> None:
        """Callback to be executed **after** every training batch ends.

        This method is executed whenever a training batch ends. Presumably,
        batches happen as often as possible. You want to make this code very
        fast. Do not log things to the terminal or the such, or do complicated
        (lengthy) calculations.

        .. warning::

           This is executed **while** you are training. Be very succinct or
           face the consequences of slow training!

        Parameters
        ----------

        trainer
            The Lightning trainer object

        pl_module
            The lightning module that is being trained

        outputs
            The outputs of the module's ``training_step``

        batch
            The data that the training step received

        batch_idx
            The relative number of the batch
        """
        self._training_epoch_loss[0].append(batch[0].shape[0])
        self._training_epoch_loss[1].append(outputs["loss"].item())

    def on_validation_epoch_start(
        self,
        trainer: lightning.pytorch.Trainer,
        pl_module: lightning.pytorch.LightningModule,
    ) -> None:
        """Callback to be executed **before** every validation epoch starts.

        This method is executed whenever a validation epoch starts.
        Presumably, validation epochs happen as often as possible. You want
        to make this code very fast. Do not log things to the terminal or the
        such, or do complicated (lengthy) calculations.

        .. warning::

           This is executed **while** you are training. Be very succinct or
           face the consequences of slow training!

        Parameters
        ----------

        trainer
            The Lightning trainer object

        pl_module
            The lightning module that is being trained
        """
        self._start_validation_epoch_time = time.time()
        self._validation_epoch_loss = {}

    def on_validation_epoch_end(
        self,
        trainer: lightning.pytorch.Trainer,
        pl_module: lightning.pytorch.LightningModule,
    ) -> None:
        """Callback to be executed **after** every validation epoch ends.

        This method is executed whenever a validation epoch ends. Presumably,
        epochs happen as often as possible. You want to make this code
        relatively fast to avoid significant runtime slow-downs.

        Parameters
        ----------

        trainer
            The Lightning trainer object

        pl_module
            The lightning module that is being trained
        """
        # summarizes resource usage since the last checkpoint
        # clears internal buffers and starts accumulating again.
        self._resource_monitor.checkpoint()

        epoch_time = time.time() - self._start_validation_epoch_time
        self._to_log["validation_epoch_time"] = epoch_time

        metrics = self._resource_monitor.data
        if metrics is not None:
            for metric_name, metric_value in metrics.items():
                self._to_log[f"validation_{metric_name}"] = float(metric_value)
        else:
            logger.warning(
                "Unable to fetch monitoring information from "
                "resource monitor. CPU/GPU utilisation will be "
                "missing."
            )

        # Computes the epoch validation loss for each validation dataloader
        # as the average of the per-batch average losses, weighted by the
        # number of samples in each batch.
        for key in sorted(self._validation_epoch_loss.keys()):
            if key == 0:
                name = "validation_loss"
            else:
                name = f"validation_loss_{key}"

            self._to_log[name] = (
                torch.sum(
                    torch.tensor(self._validation_epoch_loss[key][0])
                    * torch.tensor(self._validation_epoch_loss[key][1])
                )
                / sum(self._validation_epoch_loss[key][0])
            ).item()

    def on_validation_batch_end(
        self,
        trainer: lightning.pytorch.Trainer,
        pl_module: lightning.pytorch.LightningModule,
        outputs: torch.Tensor,
        batch: tuple[torch.Tensor, typing.Mapping[str, torch.Tensor]],
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Callback to be executed **after** every validation batch ends.

        This method is executed whenever a validation batch ends. Presumably,
        batches happen as often as possible. You want to make this code very
        fast. Do not log things to the terminal or the such, or do complicated
        (lengthy) calculations.

        .. warning::

           This is executed **while** you are training. Be very succinct or
           face the consequences of slow training!

        Parameters
        ----------

        trainer
            The Lightning trainer object

        pl_module
            The lightning module that is being trained

        outputs
            The outputs of the module's ``validation_step``

        batch
            The data that the validation step received

        batch_idx
            The relative number of the batch

        dataloader_idx
            Index of the dataloader used during validation. Use this to figure
            out which dataset was used for this validation epoch.
        """
        size, value = self._validation_epoch_loss.setdefault(
            dataloader_idx, ([], [])
        )
        size.append(batch[0].shape[0])
        value.append(outputs.item())
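
        # Note: losses recorded for ``dataloader_idx == 0`` are later logged
        # as ``validation_loss``; extra validation dataloaders are logged as
        # ``validation_loss_<N>`` (see ``on_validation_epoch_end`` above).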

    def on_cycle_end(
        self,
        trainer: lightning.pytorch.Trainer,
        pl_module: lightning.pytorch.LightningModule,
    ) -> None:
        """Called when the training/validation cycle has ended.

        This function will log all relevant values to the various loggers. It
        is supposed to be called by the end of the training cycle (consisting
        of a training and validation step).

        Parameters
        ----------

        trainer
            The Lightning trainer object

        pl_module
            The lightning module that is being trained
        """
        # collect some final time for the whole training cycle
        # Note: logging should happen at on_validation_end(), but
        # apparently you can't log from there
        overall_cycle_time = time.time() - self._start_training_epoch_time
        self._to_log["train_cycle_time"] = overall_cycle_time
        self._to_log["total_time"] = time.time() - self._start_training_time
        self._to_log["eta"] = overall_cycle_time * (
            trainer.max_epochs - trainer.current_epoch  # type: ignore
        )
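        # (For example: with a 100 s cycle time, 50 maximum epochs and the
        # current epoch at 10, the estimated time of arrival is
        # 100 * (50 - 10) = 4000 s.)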

        # Do not log during sanity check as results are not relevant
        if not trainer.sanity_checking:
            for k in sorted(self._to_log.keys()):
                pl_module.log(k, self._to_log[k])
            self._to_log = {}


class PredictionsWriter(lightning.pytorch.callbacks.BasePredictionWriter):
    """Lightning callback to write predictions to a file."""

    def __init__(
        self,
        output_dir: str | pathlib.Path,
        logfile_fields: typing.Sequence[str],
        write_interval: typing.Literal["batch", "epoch", "batch_and_epoch"],
    ):
        super().__init__(write_interval)
        self.output_dir = output_dir
        self.logfile_fields = logfile_fields
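
    # A minimal usage sketch (assumed; field names below are illustrative and
    # not part of this module): the writer is registered as a trainer callback
    # and fires when ``Trainer.predict()`` finishes an epoch, e.g.:
    #
    #   writer = PredictionsWriter(output_dir="predictions",
    #                              logfile_fields=("filename", "likelihood"),
    #                              write_interval="epoch")
    #   trainer = lightning.pytorch.Trainer(callbacks=[writer])
    #   trainer.predict(model, datamodule=datamodule)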

    def write_on_epoch_end(
        self,
        trainer: lightning.pytorch.Trainer,
        pl_module: lightning.pytorch.LightningModule,
        predictions: typing.Sequence[typing.Any],
        batch_indices: typing.Sequence[typing.Any] | None,
    ) -> None:
        for dataloader_idx, dataloader_results in enumerate(predictions):
            dataloader_name = list(
                trainer.datamodule.predict_dataloader().keys()
......
@@ -14,28 +14,11 @@ import torch.nn

from ..utils.accelerator import AcceleratorProcessor
from ..utils.resources import ResourceMonitor, cpu_constants, gpu_constants
from ..utils.save_sh_command import save_sh_command

from .callbacks import LoggingCallback

logger = logging.getLogger(__name__)


def save_model_summary(
    output_folder: str, model: torch.nn.Module
) -> tuple[lightning.pytorch.callbacks.ModelSummary, int]:
@@ -89,73 +72,16 @@ def static_information_to_csv(static_logfile_name, device, n):
            os.unlink(backup)
        shutil.move(static_logfile_name, backup)

    with open(static_logfile_name, "w", newline="") as f:
        logdata: dict[str, int | float | str] = {}
        logdata.update(cpu_constants())
        if device == "cuda":
            results = gpu_constants()
            if results is not None:
                logdata.update(results)
        logdata["model_size"] = n
        logwriter = csv.DictWriter(f, fieldnames=logdata.keys())
        logwriter.writeheader()
        logwriter.writerow(logdata)
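
    # (Illustrative only: the resulting ``constants.csv`` contains a single
    # data row whose columns are the keys collected above, e.g.
    # ``cpu_memory_total_GB``, ``cpu_count``, ``gpu_name``,
    # ``gpu_driver_version``, ``gpu_memory_total_GB`` and ``model_size``.)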

@@ -200,7 +126,7 @@ def run(
    accelerator : str
        A string indicating the accelerator to use (e.g. "cpu" or "gpu"). The
        device can also be specified (gpu:0).

    arguments : dict
        Start and end epochs:
@@ -227,11 +153,7 @@
    os.makedirs(output_folder, exist_ok=True)

    save_sh_command(output_folder)
    # save_sh_command(os.path.join(output_folder, "cmd_line_config.txt"))

    # Save model summary
    _, no_of_parameters = save_model_summary(output_folder, model)

    csv_logger = lightning.pytorch.loggers.CSVLogger(output_folder, "logs_csv")
    tensorboard_logger = lightning.pytorch.loggers.TensorBoardLogger(
@@ -251,7 +173,7 @@
        save_last=True,
        monitor="validation_loss",
        mode="min",
        save_on_train_epoch_end=True,
        every_n_epochs=checkpoint_period,
    )
@@ -260,18 +182,15 @@
    # write static information to a CSV file
    static_logfile_name = os.path.join(output_folder, "constants.csv")
    static_information_to_csv(
        static_logfile_name,
        accelerator_processor.to_torch(),
        no_of_parameters,
    )

    with resource_monitor:
        trainer = lightning.pytorch.Trainer(
            accelerator=accelerator_processor.accelerator,
            devices=(accelerator_processor.device or "auto"),
            max_epochs=max_epoch,
            accumulate_grad_batches=batch_chunk_count,
            logger=[csv_logger, tensorboard_logger],
......
@@ -6,11 +6,13 @@

import logging
import multiprocessing
import multiprocessing.synchronize
import os
import queue
import shutil
import subprocess
import time
import typing

import numpy
import psutil

@@ -25,7 +27,9 @@ GB = float(2**30)
"""The number of bytes in a gigabyte."""


def run_nvidia_smi(
    query: typing.Sequence[str],
) -> dict[str, str | float] | None:
    """Returns GPU information from query.

    For a comprehensive list of options and help, execute ``nvidia-smi
@@ -35,52 +39,43 @@

    Parameters
    ----------

    query
        A list of query strings as defined by ``nvidia-smi --help-query-gpu``

    Returns
    -------

    data
        A dictionary containing the queried parameters. If ``nvidia-smi`` is
        not available, returns ``None``. Percentage information is left
        alone, memory information is transformed to gigabytes
        (floating-point).
    """
    if _nvidia_smi is None:
        return None

    # Gets GPU information, based on a GPU device if that is set. Returns
    # ordered results.
    query_str = (
        f"{_nvidia_smi} --query-gpu={','.join(query)} --format=csv,noheader"
    )
    visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible_devices:
        query_str += f" --id={visible_devices}"
    values = subprocess.getoutput(query_str)

    retval: dict[str, str | float] = {}
    for i, k in enumerate([k.strip() for k in values.split(",")]):
        retval[query[i]] = k
        if k.endswith("%"):
            retval[query[i]] = float(k[:-1].strip())
        elif k.endswith("MiB"):
            retval[query[i]] = float(k[:-3].strip()) / 1024
    return retval


def gpu_constants() -> dict[str, str | int | float] | None:
    """Returns GPU (static) information using nvidia-smi.

    See :py:func:`run_nvidia_smi` for operational details.
@@ -90,21 +85,25 @@
    data
        If ``nvidia-smi`` is not available, returns ``None``, otherwise, we
        return a dictionary containing the following ``nvidia-smi`` query
        information, in this order:

        * ``gpu_name``, as ``gpu_name`` (:py:class:`str`)
        * ``driver_version``, as ``gpu_driver_version`` (:py:class:`str`)
        * ``memory.total``, as ``gpu_memory_total`` (transformed to gigabytes,
          :py:class:`float`)
    """
    retval = run_nvidia_smi(("gpu_name", "driver_version", "memory.total"))
    if retval is None:
        return retval

    # else, just update with more generic names
    retval["gpu_driver_version"] = retval.pop("driver_version")
    retval["gpu_memory_total_GB"] = retval.pop("memory.total")
    return retval


def gpu_log() -> dict[str, float] | None:
    """Returns GPU information about current non-static status using
    nvidia-smi.
@@ -113,10 +112,10 @@
    Returns
    -------

    data
        If ``nvidia-smi`` is not available, returns ``None``, otherwise, we
        return a dictionary containing the following ``nvidia-smi`` query
        information, in this order:

        * ``memory.used``, as ``gpu_memory_used`` (transformed to gigabytes,
          :py:class:`float`)
@@ -127,47 +126,41 @@
        * ``utilization.gpu``, as ``gpu_percent``,
          (:py:class:`float`, in percent)
    """
    result = run_nvidia_smi(
        ("memory.total", "memory.used", "memory.free", "utilization.gpu")
    )
    if result is None:
        return result

    return {
        "gpu_memory_used_GB": float(result["memory.used"]),
        "gpu_memory_free_GB": float(result["memory.free"]),
        "gpu_memory_percent": 100
        * float(result["memory.used"])
        / float(result["memory.total"]),
        "gpu_percent": float(result["utilization.gpu"]),
    }


def cpu_constants() -> dict[str, int | float]:
    """Returns static CPU information about the current system.

    Returns
    -------

    data
        A dictionary containing these entries:

        0. ``cpu_memory_total`` (:py:class:`float`): total memory available,
           in gigabytes
        1. ``cpu_count`` (:py:class:`int`): number of logical CPUs available
    """
    return {
        "cpu_memory_total_GB": psutil.virtual_memory().total / GB,
        "cpu_count": psutil.cpu_count(logical=True),
    }


class CPULogger:
@@ -176,24 +169,24 @@
    Parameters
    ----------

    pid
        Process identifier of the main process (parent process) to observe
    """

    def __init__(self, pid: int | None = None):
        this = psutil.Process(pid=pid)
        self.cluster = [this] + this.children(recursive=True)
        # touch cpu_percent() at least once for all processes in the cluster
        [k.cpu_percent(interval=None) for k in self.cluster]

    def log(self) -> dict[str, int | float]:
        """Returns current process cluster information.

        Returns
        -------

        data
            An ordered dictionary containing these entries:

            0. ``cpu_memory_used`` (:py:class:`float`): total memory used from
               the system, in gigabytes
@@ -244,14 +237,14 @@
                # it is too late to update any intermediate list
                # at this point, but ensures to update counts later on
                gone.add(k)

        return {
            "cpu_memory_used_GB": psutil.virtual_memory().used / GB,
            "cpu_rss_GB": sum([k.rss for k in memory_info]) / GB,
            "cpu_vms_GB": sum([k.vms for k in memory_info]) / GB,
            "cpu_percent": sum(cpu_percent),
            "cpu_processes": len(self.cluster) - len(gone),
            "cpu_open_files": sum(open_files),
        }


class _InformationGatherer:
@@ -260,73 +253,85 @@
    Parameters
    ----------

    has_gpu
        A flag indicating if we have a GPU installed on the platform or not

    main_pid
        The main process identifier to monitor

    logger
        A logger to be used for logging messages
    """

    def __init__(
        self, has_gpu: bool, main_pid: int | None, logger: logging.Logger
    ):
        self.logger: logging.Logger = logger
        self.cpu_logger = CPULogger(main_pid)
        keys: list[str] = list(self.cpu_logger.log().keys())
        self.has_gpu: bool = has_gpu
        if self.has_gpu:
            example = gpu_log()
            if example is not None:
                keys += list(example.keys())
        self.data: dict[str, list[int | float]] = {k: [] for k in keys}
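
        # ``self.data`` maps each metric name to the list of samples collected
        # so far, e.g. ``{"cpu_percent": [12.0, 15.5], "cpu_count": [8, 8]}``
        # (values here are illustrative only).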

    def acc(self) -> None:
        """Accumulates another measurement."""
        for k, v in self.cpu_logger.log().items():
            self.data[k].append(v)
        if self.has_gpu:
            sample = gpu_log()
            if sample is not None:
                for k, v in sample.items():
                    self.data[k].append(v)

    def clear(self) -> None:
        """Clears accumulated data."""
        for k in self.data.keys():
            self.data[k] = []

    def summary(self) -> dict[str, list[int | float]]:
        """Returns the current data."""
        if len(next(iter(self.data.values()))) == 0:
            self.logger.error("CPU/GPU logger was not able to collect any data")
        return self.data


def _monitor_worker(
    interval: int | float,
    has_gpu: bool,
    main_pid: int,
    stop: multiprocessing.synchronize.Event,
    summary_event: multiprocessing.synchronize.Event,
    queue: queue.Queue,
    logging_level: int,
):
    """A monitoring worker that measures resources and returns lists.

    Parameters
    ==========

    interval
        Number of seconds to wait between each measurement (maybe a floating
        point number as accepted by :py:func:`time.sleep`)

    has_gpu
        A flag indicating if we have a GPU installed on the platform or not

    main_pid
        The main process identifier to monitor

    stop
        Event that indicates if we should continue running or stop

    summary_event
        Event that indicates if we should produce a summary

    queue
        A queue, to send monitoring information back to the spawner

    logging_level
        The logging level to use for logging from launched processes
    """
    logger = multiprocessing.log_to_stderr(level=logging_level)
@@ -343,9 +348,9 @@
            time.sleep(interval)
        except Exception:
            logger.exception(
                "Iterative CPU/GPU logging did not work properly."
                " Exception follows. Retrying..."
            )
            time.sleep(0.5)  # wait half a second, and try again!
@@ -356,27 +361,35 @@ class ResourceMonitor:
    Parameters
    ----------

    interval
        Number of seconds to wait between each measurement (maybe a floating
        point number as accepted by :py:func:`time.sleep`)

    has_gpu
        A flag indicating if we have a GPU installed on the platform or not

    main_pid
        The main process identifier to monitor

    logging_level
        The logging level to use for logging from launched processes
    """

    def __init__(
        self,
        interval: int | float,
        has_gpu: bool,
        main_pid: int,
        logging_level: int,
    ):
        self.interval = interval
        self.has_gpu = has_gpu
        self.main_pid = main_pid
        self.stop_event = multiprocessing.Event()
        self.summary_event = multiprocessing.Event()
        self.q: multiprocessing.Queue[
            dict[str, list[int | float]]
        ] = multiprocessing.Queue()
        self.logging_level = logging_level

        self.monitor = multiprocessing.Process(
@@ -393,23 +406,23 @@
            ),
        )

        self.data: dict[str, int | float] | None = None

    def __enter__(self) -> None:
        """Starts the monitoring process."""
        self.monitor.start()
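
    # A minimal usage sketch (assumed; the trainer in this package drives the
    # monitor in a similar way): run the monitor as a context manager and call
    # ``checkpoint()`` at epoch boundaries to obtain averaged readings:
    #
    #   monitor = ResourceMonitor(interval=5.0, has_gpu=False,
    #                             main_pid=os.getpid(),
    #                             logging_level=logging.WARNING)
    #   with monitor:
    #       ...  # do some work, e.g. train for one epoch
    #       monitor.checkpoint()
    #       print(monitor.data)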

    def checkpoint(self) -> None:
        """Forces the monitoring process to yield data and clear the internal
        accumulator."""
        self.summary_event.set()
        try:
            data: dict[str, list[int | float]] = self.q.get(
                timeout=2 * self.interval
            )
        except queue.Empty:
            logger.warning(
                f"CPU/GPU resource monitor did not provide anything when "
                f"joined (even after a {2*self.interval}-second timeout - "
                f"this is normally due to exceptions on the monitoring process. "
@@ -417,19 +430,18 @@
            )
            self.data = None
        else:
            # summarize the returned data by creating averages
            self.data = {}
            for k, values in data.items():
                if values:
                    if k in ("cpu_processes", "cpu_open_files"):
                        self.data[k] = numpy.max(values)
                    else:
                        self.data[k] = float(numpy.mean(values))
                else:
                    self.data[k] = 0.0

    def __exit__(self, *_) -> None:
        """Stops the monitoring process and returns the summary of
        observations."""
......