Commit b0b77c29 authored by André Anjos

[engine.resources] Implement GPU logger for MPS backend (closes #7)

parent c203a4f7
1 merge request: !14 Implements GPU monitoring when using the MPS (macOS) compute backend
Pipeline #83491 passed
@@ -29,12 +29,20 @@ class LoggingCallback(lightning.pytorch.Callback):
Parameters
----------
resource_monitor
train_resource_monitor
A monitor that watches resource usage (CPU/GPU) during training, in a
separate process and asynchronously with respect to the code execution.
validation_resource_monitor
A monitor that watches resource usage (CPU/GPU) during validation, in a
separate process and asynchronously with respect to the code execution.
"""
def __init__(self, resource_monitor: ResourceMonitor):
def __init__(
self,
train_resource_monitor: ResourceMonitor,
validation_resource_monitor: ResourceMonitor,
):
super().__init__()
# timers
@@ -46,7 +54,8 @@ class LoggingCallback(lightning.pytorch.Callback):
self._to_log: dict[str, float] = {}
# helpers for CPU and GPU utilisation
self._resource_monitor = resource_monitor
self._train_resource_monitor = train_resource_monitor
self._validation_resource_monitor = validation_resource_monitor
self._max_queue_retries = 2
def on_train_start(
@@ -97,6 +106,10 @@ class LoggingCallback(lightning.pytorch.Callback):
pl_module
The lightning module that is being trained
"""
# summarizes resource usage since the last checkpoint
# clears internal buffers and starts accumulating again.
self._train_resource_monitor.checkpoint(remove_last_n=-1)
self._start_training_epoch_time = time.time()
def on_train_epoch_end(
@@ -121,17 +134,13 @@ class LoggingCallback(lightning.pytorch.Callback):
The lightning module that is being trained
"""
# summarizes resource usage since the last checkpoint
# clears internal buffers and starts accumulating again.
self._resource_monitor.checkpoint()
# evaluates this training epoch total time, and log it
epoch_time = time.time() - self._start_training_epoch_time
self._to_log["epoch-duration-seconds/train"] = epoch_time
self._to_log["learning-rate"] = pl_module.optimizers().defaults["lr"] # type: ignore
metrics = self._resource_monitor.data
metrics = self._train_resource_monitor.data
if metrics is not None:
for metric_name, metric_value in metrics.items():
self._to_log[f"{metric_name}/train"] = float(metric_value)
@@ -235,6 +244,14 @@ class LoggingCallback(lightning.pytorch.Callback):
pl_module
The lightning module that is being trained
"""
# required because the validation epoch is started **within** the
# training epoch START/END.
#
# summarizes resource usage since the last checkpoint
# clears internal buffers and starts accumulating again.
self._train_resource_monitor.checkpoint(remove_last_n=-1)
self._validation_resource_monitor.checkpoint(remove_last_n=-1)
self._start_validation_epoch_time = time.time()
def on_validation_epoch_end(
@@ -261,12 +278,12 @@ class LoggingCallback(lightning.pytorch.Callback):
# summarizes resource usage since the last checkpoint
# clears internal buffers and starts accumulating again.
self._resource_monitor.checkpoint()
self._validation_resource_monitor.checkpoint(remove_last_n=-1)
epoch_time = time.time() - self._start_validation_epoch_time
self._to_log["epoch-duration-seconds/validation"] = epoch_time
metrics = self._resource_monitor.data
metrics = self._validation_resource_monitor.data
if metrics is not None:
for metric_name, metric_value in metrics.items():
self._to_log[f"{metric_name}/validation"] = float(metric_value)
@@ -4,6 +4,7 @@
import logging
import os
import typing
import torch
import torch.backends
@@ -11,6 +12,14 @@ import torch.backends
logger = logging.getLogger(__name__)
SupportedPytorchDevice: typing.TypeAlias = typing.Literal[
"cpu",
"cuda",
"mps",
]
"""List of supported pytorch devices by this library."""
def _split_int_list(s: str) -> list[int]:
"""Splits a list of integers encoded in a string (e.g. "1,2,3") into a
Python list of integers (e.g. ``[1, 2, 3]``)."""
@@ -43,11 +52,16 @@ class DeviceManager:
current process.
"""
SUPPORTED = ("cpu", "cuda", "mps")
def __init__(self, name: str):
def __init__(self, name: SupportedPytorchDevice):
parts = name.split(":", 1)
self.device_type = parts[0]
# make device type of the right Python type
if parts[0] not in typing.get_args(SupportedPytorchDevice):
raise ValueError(f"Unsupported device-type `{parts[0]}`")
self.device_type: SupportedPytorchDevice = typing.cast(
SupportedPytorchDevice, parts[0]
)
self.device_ids: list[int] = []
if len(parts) > 1:
self.device_ids = _split_int_list(parts[1])
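
For illustration, here is a minimal standalone sketch of the device-string parsing implemented above. The hypothetical parse_device helper mirrors (rather than imports) DeviceManager.__init__, and the alias plus _split_int_list are re-declared locally so the snippet runs on its own:

import typing

SupportedPytorchDevice = typing.Literal["cpu", "cuda", "mps"]

def _split_int_list(s: str) -> list[int]:
    # "0,1" -> [0, 1]
    return [int(k) for k in s.split(",")]

def parse_device(name: str) -> tuple[str, list[int]]:
    # accepts "<type>" or "<type>:<id>[,<id>...]", as DeviceManager does
    parts = name.split(":", 1)
    if parts[0] not in typing.get_args(SupportedPytorchDevice):
        raise ValueError(f"Unsupported device-type `{parts[0]}`")
    device_ids = _split_int_list(parts[1]) if len(parts) > 1 else []
    return parts[0], device_ids

assert parse_device("cuda:0,1") == ("cuda", [0, 1])
assert parse_device("mps") == ("mps", [])
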
@@ -70,10 +84,11 @@ class DeviceManager:
[str(k) for k in self.device_ids]
)
if self.device_type not in DeviceManager.SUPPORTED:
if self.device_type not in typing.get_args(SupportedPytorchDevice):
raise RuntimeError(
f"Unsupported device type `{self.device_type}`. "
f"Supported devices types are `{', '.join(DeviceManager.SUPPORTED)}`"
f"Supported devices types are "
f"`{', '.join(typing.get_args(SupportedPytorchDevice))}`"
)
if self.device_ids and self.device_type in ("cpu", "mps"):
@@ -14,9 +14,9 @@ import lightning.pytorch.loggers
import torch.nn
from ..utils.checkpointer import CHECKPOINT_ALIASES
from ..utils.resources import ResourceMonitor, cpu_constants, gpu_constants
from ..utils.resources import ResourceMonitor, cpu_constants, cuda_constants
from .callbacks import LoggingCallback
from .device import DeviceManager
from .device import DeviceManager, SupportedPytorchDevice
logger = logging.getLogger(__name__)
@@ -62,7 +62,7 @@ def save_model_summary(
def static_information_to_csv(
static_logfile_name: pathlib.Path,
device_type: str,
device_type: SupportedPytorchDevice,
model_size: int,
) -> None:
"""Saves the static information in a CSV file.
@@ -87,10 +87,19 @@ def static_information_to_csv(
with static_logfile_name.open("w", newline="") as f:
logdata: dict[str, int | float | str] = {}
logdata.update(cpu_constants())
if device_type == "cuda":
results = gpu_constants()
if results is not None:
logdata.update(results)
match device_type:
case "cpu":
pass
case "cuda":
results = cuda_constants()
if results is not None:
logdata.update(results)
case "mps":
pass
case _:
pass
logdata["model_size"] = model_size
logwriter = csv.DictWriter(f, fieldnames=logdata.keys())
logwriter.writeheader()
@@ -177,9 +186,16 @@ def run(
f"Then, open a browser on the printed address."
)
resource_monitor = ResourceMonitor(
train_resource_monitor = ResourceMonitor(
interval=monitoring_interval,
device_type=device_manager.device_type,
main_pid=os.getpid(),
logging_level=logging.ERROR,
)
validation_resource_monitor = ResourceMonitor(
interval=monitoring_interval,
has_gpu=device_manager.device_type == "cuda",
device_type=device_manager.device_type,
main_pid=os.getpid(),
logging_level=logging.ERROR,
)
@@ -208,7 +224,7 @@ def run(
no_of_parameters,
)
with resource_monitor:
with train_resource_monitor, validation_resource_monitor:
accelerator, devices = device_manager.lightning_accelerator()
trainer = lightning.pytorch.Trainer(
accelerator=accelerator,
@@ -219,7 +235,9 @@ def run(
check_val_every_n_epoch=validation_period,
log_every_n_steps=len(datamodule.train_dataloader()),
callbacks=[
LoggingCallback(resource_monitor),
LoggingCallback(
train_resource_monitor, validation_resource_monitor
),
checkpoint_minvalloss_callback,
],
)
@@ -7,6 +7,7 @@ import logging
import multiprocessing
import multiprocessing.synchronize
import os
import plistlib
import queue
import shutil
import subprocess
@@ -16,6 +17,8 @@ import typing
import numpy
import psutil
from ..engine.device import SupportedPytorchDevice
logger = logging.getLogger(__name__)
_nvidia_smi = shutil.which("nvidia-smi")
@@ -74,7 +77,7 @@ def run_nvidia_smi(
return retval
def gpu_constants() -> dict[str, str | int | float] | None:
def cuda_constants() -> dict[str, str | int | float] | None:
"""Returns GPU (static) information using nvidia-smi.
See :py:func:`run_nvidia_smi` for operational details.
@@ -102,7 +105,7 @@ def gpu_constants() -> dict[str, str | int | float] | None:
return retval
def gpu_log() -> dict[str, float] | None:
def cuda_log() -> dict[str, float] | None:
"""Returns GPU information about current non-static status using nvidia-
smi.
@@ -116,13 +119,13 @@ def gpu_log() -> dict[str, float] | None:
return a dictionary containing the following ``nvidia-smi`` query
information, in this order:
* ``memory.used``, as ``gpu_memory_used`` (transformed to gigabytes,
* ``memory.used``, as ``memory-used-GB/gpu`` (transformed to gigabytes,
:py:class:`float`)
* ``memory.free``, as ``gpu_memory_free`` (transformed to gigabytes,
* ``memory.free``, as ``memory-free-GB/gpu`` (transformed to gigabytes,
:py:class:`float`)
* ``100*memory.used/memory.total``, as ``gpu_memory_percent``,
* ``100*memory.used/memory.total``, as ``memory-percent/gpu``,
(:py:class:`float`, in percent)
* ``utilization.gpu``, as ``gpu_percent``,
* ``utilization.gpu``, as ``percent-usage/gpu``,
(:py:class:`float`, in percent)
"""
@@ -143,6 +146,66 @@
}
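
To make the key renaming concrete: the metric names returned by ``cuda_log()`` are later suffixed with the split name by ``LoggingCallback`` (see the callbacks diff above), so a reading such as the illustrative one below ends up logged as ``memory-used-GB/gpu/train`` and so on. The numbers are made up; real values come from ``nvidia-smi``:

# illustrative sample only; real values come from nvidia-smi
sample = {
    "memory-used-GB/gpu": 3.2,
    "memory-free-GB/gpu": 12.8,
    "memory-percent/gpu": 20.0,
    "percent-usage/gpu": 35.0,
}

# what LoggingCallback does with each entry at the end of a training epoch
to_log = {f"{name}/train": float(value) for name, value in sample.items()}
# -> {"memory-used-GB/gpu/train": 3.2, "memory-free-GB/gpu/train": 12.8, ...}
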
def mps_log() -> dict[str, float] | None:
"""Returns GPU information about current non-static status using ``sudo
powermetrics``.
Returns
-------
data
If ``sudo powermetrics`` is not executable (or is not configured for
passwordless execution), returns ``None``; otherwise, returns a
dictionary containing the following query information, in this order:
* ``freq_hz`` as ``frequency-MHz/gpu``
* 100 * (1 - ``idle_ratio``), as ``percent-usage/gpu``,
(:py:class:`float`, in percent)
"""
cmd = [
"sudo",
"-n",
"/usr/bin/powermetrics",
"--samplers",
"gpu_power",
"-i500",
"-n1",
"-fplist",
]
try:
raw_bytes = subprocess.check_output(cmd)
data = plistlib.loads(raw_bytes)
return {
"frequency-MHz/gpu": float(data["gpu"]["freq_hz"]),
"percent-usage/gpu": 100 * (1 - data["gpu"]["idle_ratio"]),
}
except subprocess.CalledProcessError:
import inspect
import warnings
warnings.warn(
inspect.cleandoc(
f"""Cannot run `sudo powermetrics` without a password. Probably,
you did not setup sudo to execute the macOS CLI application
`powermetrics` passwordlessly and therefore this warning is
being issued. This does not affect the running of your model
training, only the ability of the resource monitor of gathering
GPU usage information on macOS while using the MPS compute
backend. To configure this properly and get rid of this
warning, execute `sudo visudo` and add the following line where
suitable: `yourusername ALL=(ALL) NOPASSWD:SETENV:
/usr/bin/powermetrics`. Replace `yourusername` by your actual
username on the machine. Test the setup running the command
`{' '.join(cmd)}` by hand."""
)
)
return None
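
To verify the passwordless ``sudo`` setup described in the warning above without launching a training session, one can run the same command that ``mps_log()`` issues. A short sketch, assuming macOS with ``/usr/bin/powermetrics`` present:

import plistlib
import subprocess

# same command line as used by mps_log() above
cmd = [
    "sudo", "-n", "/usr/bin/powermetrics",
    "--samplers", "gpu_power", "-i500", "-n1", "-fplist",
]

try:
    data = plistlib.loads(subprocess.check_output(cmd))
    print("OK - GPU idle ratio:", data["gpu"]["idle_ratio"])
except subprocess.CalledProcessError:
    print("sudo is not configured for passwordless powermetrics")
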
def cpu_constants() -> dict[str, int | float]:
"""Returns static CPU information about the current system.
@@ -179,7 +242,7 @@ class CPULogger:
[k.cpu_percent(interval=None) for k in self.cluster]
def log(self) -> dict[str, int | float]:
"""Returns current process cluster iformation.
"""Returns current process cluster information.
Returns
-------
@@ -252,8 +315,9 @@ class _InformationGatherer:
Parameters
----------
has_gpu
A flag indicating if we have a GPU installed on the platform or not
device_type
String representation of one of the supported pytorch device types
triggering the correct readout of resource usage.
main_pid
The main process identifier to monitor
@@ -263,27 +327,58 @@
"""
def __init__(
self, has_gpu: bool, main_pid: int | None, logger: logging.Logger
self,
device_type: SupportedPytorchDevice,
main_pid: int | None,
logger: logging.Logger,
):
self.logger: logging.Logger = logger
self.cpu_logger = CPULogger(main_pid)
keys: list[str] = list(self.cpu_logger.log().keys())
self.has_gpu: bool = has_gpu
if self.has_gpu:
example = gpu_log()
if example is not None:
keys += list(example.keys())
self.device_type = device_type
match self.device_type:
case "cpu":
logger.info(
f"Pytorch device-type `{device_type}`: "
f"no GPU logging will be performed "
)
case "cuda":
example = cuda_log()
if example is not None:
keys += list(example.keys())
case "mps":
example = mps_log()
if example is not None:
keys += list(example.keys())
case _:
logger.warning(
f"Unsupported device-type `{device_type}`: "
f"no GPU logging will be performed "
)
self.data: dict[str, list[int | float]] = {k: [] for k in keys}
def acc(self) -> None:
"""Accumulates another measurement."""
for k, v in self.cpu_logger.log().items():
self.data[k].append(v)
if self.has_gpu:
sample = gpu_log()
if sample is not None:
for k, v in sample.items():
self.data[k].append(v)
match self.device_type:
case "cpu":
pass
case "cuda":
sample = cuda_log()
if sample is not None:
for k, v in sample.items():
self.data[k].append(v)
case "mps":
sample = mps_log()
if sample is not None:
for k, v in sample.items():
self.data[k].append(v)
case _:
pass
def clear(self) -> None:
"""Clears accumulated data."""
@@ -299,7 +394,7 @@ class _InformationGatherer:
def _monitor_worker(
interval: int | float,
has_gpu: bool,
device_type: SupportedPytorchDevice,
main_pid: int,
stop: multiprocessing.synchronize.Event,
summary_event: multiprocessing.synchronize.Event,
@@ -315,8 +410,9 @@ def _monitor_worker(
Number of seconds to wait between each measurement (may be a floating
point number, as accepted by :py:func:`time.sleep`)
has_gpu
A flag indicating if we have a GPU installed on the platform or not
device_type
String representation of one of the supported pytorch device types
triggering the correct readout of resource usage.
main_pid
The main process identifier to monitor
@@ -334,7 +430,7 @@ def _monitor_worker(
The logging level to use for logging from launched processes
"""
logger = multiprocessing.log_to_stderr(level=logging_level)
ra = _InformationGatherer(has_gpu, main_pid, logger)
ra = _InformationGatherer(device_type, main_pid, logger)
while not stop.is_set():
try:
@@ -364,8 +460,9 @@ class ResourceMonitor:
Number of seconds to wait between each measurement (may be a floating
point number, as accepted by :py:func:`time.sleep`)
has_gpu
A flag indicating if we have a GPU installed on the platform or not
device_type
String representation of one of the supported pytorch device types
triggering the correct readout of resource usage.
main_pid
The main process identifier to monitor
@@ -377,12 +474,12 @@ class ResourceMonitor:
def __init__(
self,
interval: int | float,
has_gpu: bool,
device_type: SupportedPytorchDevice,
main_pid: int,
logging_level: int,
):
self.interval = interval
self.has_gpu = has_gpu
self.device_type = device_type
self.main_pid = main_pid
self.stop_event = multiprocessing.Event()
self.summary_event = multiprocessing.Event()
@@ -396,7 +493,7 @@ class ResourceMonitor:
name="ResourceMonitorProcess",
args=(
self.interval,
self.has_gpu,
self.device_type,
self.main_pid,
self.stop_event,
self.summary_event,
@@ -411,9 +508,17 @@ class ResourceMonitor:
"""Starts the monitoring process."""
self.monitor.start()
def checkpoint(self) -> None:
def checkpoint(self, remove_last_n: int | None = None) -> None:
"""Forces the monitoring process to yield data and clear the internal
accumlator."""
accumulator.
Parameters
----------
remove_last_n
If set, remove the last observations from every entry, as long as at
least one observation is kept. Useful to discard spurious observations
taken at the end of a monitoring period.
"""
self.summary_event.set()
try:
data: dict[str, list[int | float]] = self.q.get(
@@ -432,7 +537,9 @@ class ResourceMonitor:
self.data = {}
for k, values in data.items():
if values:
if k in ("cpu_processes", "cpu_open_files"):
if values[:remove_last_n]:
values = values[:remove_last_n]
if k in ("num-processes/cpu", "num-open-files/cpu"):
self.data[k] = numpy.max(values)
else:
self.data[k] = float(numpy.mean(values))
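
A worked example of the aggregation above, as a standalone sketch (key names and the ``remove_last_n`` slicing follow the diff): with ``remove_last_n=-1`` the last observation of each series is dropped before summarizing, process/file counts are reduced with ``max`` and everything else with ``mean``:

import numpy

data = {
    "percent-usage/gpu": [10.0, 20.0, 90.0],  # last sample is spurious
    "num-processes/cpu": [4, 5, 4],
}
remove_last_n = -1

summary: dict[str, int | float] = {}
for k, values in data.items():
    if values[:remove_last_n]:  # only truncate if something is left
        values = values[:remove_last_n]
    if k in ("num-processes/cpu", "num-open-files/cpu"):
        summary[k] = numpy.max(values)
    else:
        summary[k] = float(numpy.mean(values))

# summary == {"percent-usage/gpu": 15.0, "num-processes/cpu": 5}
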