Skip to content
Snippets Groups Projects

Making use of LightningDataModule and simplification of data loading

Merged Daniel CARRON requested to merge add-datamodule into main
4 files
+ 179
153
Compare changes
  • Side-by-side
  • Inline
Files
4
+ 150
0
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import logging
import os
import torch
import torch.backends
logger = logging.getLogger(__name__)
def _split_int_list(s: str) -> list[int]:
"""Splits a list of integers encoded in a string (e.g. "1,2,3") into a
Python list of integers (e.g. ``[1, 2, 3]``)."""
return [int(k.strip()) for k in s.split(",")]
class DeviceManager:
"""This class is used to manage the Lightning Accelerator and Pytorch
Devices.
It takes the user input, in the form of a string defined by
``[\\S+][:\\d[,\\d]?]?`` (e.g.: ``cpu``, ``mps``, or ``cuda:3``), and can
translate to the right incarnation of Pytorch devices or Lightning
Accelerators to interface with the various frameworks.
Instances of this class also manage the environment variable
``$CUDA_VISIBLE_DEVICES`` if necessary.
Parameters
----------
name
The name of the device to use, in the form of a string defined by
``[\\S+][:\\d[,\\d]?]?`` (e.g.: ``cpu``, ``mps``, or ``cuda:3``). In
the specific case of ``cuda``, one can also specify a device to use
either by adding ``:N``, where N is the zero-indexed board number on
the computer, or by setting the environment variable
``$CUDA_VISIBLE_DEVICES`` with the devices that are usable by the
current process.
"""
SUPPORTED = ("cpu", "cuda", "mps")
def __init__(self, name: str):
parts = name.split(":", 1)
self.device_type = parts[0]
self.device_ids: list[int] = []
if len(parts) > 1:
self.device_ids = _split_int_list(parts[1])
if self.device_type == "cuda":
visible_env = os.environ.get("CUDA_VISIBLE_DEVICES")
if visible_env:
visible = _split_int_list(visible_env)
if self.device_ids and visible != self.device_ids:
logger.warning(
f"${{CUDA_VISIBLE_DEVICES}}={visible} and name={name} "
f"- overriding environment with value set on `name`"
)
else:
self.device_ids = visible
# make sure that it is consistent with the environment
if self.device_ids:
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
[str(k) for k in self.device_ids]
)
if self.device_type not in DeviceManager.SUPPORTED:
raise RuntimeError(
f"Unsupported device type `{self.device_type}`. "
f"Supported devices types are `{', '.join(DeviceManager.SUPPORTED)}`"
)
if self.device_ids and self.device_type in ("cpu", "mps"):
logger.warning(
f"Cannot pin device ids if using cpu or mps backend. "
f"Setting `name` to {name} is non-sensical. Ignoring..."
)
# check if the device_type that was set has support compiled in
if self.device_type == "cuda":
assert hasattr(torch, "cuda") and torch.cuda.is_available(), (
f"User asked for device = `{name}`, but CUDA support is "
f"not compiled into pytorch!"
)
if self.device_type == "mps":
assert (
hasattr(torch.backends, "mps")
and torch.backends.mps.is_available() # type:ignore
), (
f"User asked for device = `{name}`, but MPS support is "
f"not compiled into pytorch!"
)
def torch_device(self) -> torch.device:
"""Returns a representation of the torch device to use by default.
.. warning::
If a list of devices is set, then this method only returns the first
device. This may impact Nvidia GPU logging in the case multiple
GPU cards are used.
Returns
-------
device
The **first** torch device (if a list of ids is set).
"""
if self.device_type in ("cpu", "mps"):
return torch.device(self.device_type)
elif self.device_type == "cuda":
if not self.device_ids:
return torch.device(self.device_type)
else:
return torch.device(self.device_type, self.device_ids[0])
# if you get to this point, this is an unexpected RuntimeError
raise RuntimeError(
f"Unexpected device type {self.device_type} lacks support"
)
def lightning_accelerator(self) -> tuple[str, int | list[int] | str | None]:
"""Returns the lightning accelerator setup.
Returns
-------
accelerator
The lightning accelerator to use
devices
The lightning devices to use
"""
devices: int | list[int] | str = self.device_ids
if not devices:
devices = "auto"
elif self.device_type == "mps":
devices = 1
return self.device_type, devices
Loading