Skip to content
Snippets Groups Projects
Commit 8852fa7f authored by Daniel CARRON's avatar Daniel CARRON :b:
Browse files

Re-added device/accelerator selection

parent c2c4286d
No related branches found
No related tags found
1 merge request!4Moved code to lightning
......@@ -7,6 +7,8 @@ import logging
import os
import shutil
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger
......@@ -18,6 +20,61 @@ from .callbacks import LoggingCallback
logger = logging.getLogger(__name__)
class AcceleratorProcessor:
"""This class is used to convert torch devices into lightning accelerators
and vice versa, as they do not use the same conventions."""
def __init__(self):
# Note: "auto" is a valid accelerator in lightning, but there doesn't seem to be a way to check which accelerator it will actually use so we don't take it into account for now.
self.torch_to_lightning = {"cpu": "cpu", "cuda": "gpu"}
self.lightning_to_torch = {
v: k for k, v in self.torch_to_lightning.items()
}
self.valid_accelerators = set(
list(self.torch_to_lightning.keys())
+ list(self.lightning_to_torch.keys())
)
def _split_accelerator_name(self, accelerator_name):
split_accelerator = accelerator_name.split(":")
accelerator = split_accelerator[0]
if len(split_accelerator) > 1:
devices = split_accelerator[1:]
else:
devices = "auto"
return accelerator, devices
def to_torch(self, accelerator_name):
accelerator_name, devices = self._split_accelerator_name(
accelerator_name
)
assert accelerator_name in self.valid_accelerators
if accelerator_name in self.lightning_to_torch:
return self.lightning_to_torch[accelerator_name], devices
elif accelerator_name in self.torch_to_lightning:
return accelerator_name, devices
else:
raise ValueError("Unknown accelerator.")
def to_lightning(self, accelerator_name):
accelerator_name, devices = self._split_accelerator_name(
accelerator_name
)
assert accelerator_name in self.valid_accelerators
if accelerator_name in self.torch_to_lightning:
return self.lightning_to_torch[accelerator_name], devices
elif accelerator_name in self.lightning_to_torch:
return accelerator_name, devices
else:
raise ValueError("Unknown accelerator.")
def check_gpu(device):
"""Check the device type and the availability of GPU.
......@@ -27,7 +84,7 @@ def check_gpu(device):
device : :py:class:`torch.device`
device to use
"""
if device.type == "cuda":
if device == "cuda":
# asserts we do have a GPU
assert bool(
gpu_constants()
......@@ -78,7 +135,7 @@ def static_information_to_csv(static_logfile_name, device, n):
shutil.move(static_logfile_name, backup)
with open(static_logfile_name, "w", newline="") as f:
logdata = cpu_constants()
if device.type == "cuda":
if device == "cuda":
logdata += gpu_constants()
logdata += (("model_size", n),)
logwriter = csv.DictWriter(f, fieldnames=[k[0] for k in logdata])
......@@ -142,9 +199,7 @@ def create_logfile_fields(valid_loader, extra_valid_loaders, device):
logfile_fields += ("validation_loss",)
if extra_valid_loaders:
logfile_fields += ("extra_validation_losses",)
logfile_fields += tuple(
ResourceMonitor.monitored_keys(device.type == "cuda")
)
logfile_fields += tuple(ResourceMonitor.monitored_keys(device == "cuda"))
return logfile_fields
......@@ -154,7 +209,7 @@ def run(
valid_loader,
extra_valid_loaders,
checkpoint_period,
device,
accelerator,
arguments,
output_folder,
monitoring_interval,
......@@ -190,8 +245,8 @@ def run(
save a checkpoint every ``n`` epochs. If set to ``0`` (zero), then do
not save intermediary checkpoints
device : :py:class:`torch.device`
device to use
accelerator : str`
accelerator to use
arguments : dict
start and end epochs
......@@ -213,7 +268,9 @@ def run(
max_epoch = arguments["max_epoch"]
check_gpu(device)
accelerator_processor = AcceleratorProcessor()
check_gpu(accelerator_processor.to_torch(accelerator)[0])
os.makedirs(output_folder, exist_ok=True)
......@@ -225,7 +282,7 @@ def run(
resource_monitor = ResourceMonitor(
interval=monitoring_interval,
has_gpu=(device.type == "cuda"),
has_gpu=torch.cuda.is_available(),
main_pid=os.getpid(),
logging_level=logging.ERROR,
)
......@@ -244,12 +301,14 @@ def run(
# write static information to a CSV file
static_logfile_name = os.path.join(output_folder, "constants.csv")
static_information_to_csv(static_logfile_name, device, n)
static_information_to_csv(
static_logfile_name, accelerator_processor.to_torch(accelerator)[0], n
)
with resource_monitor:
trainer = Trainer(
accelerator="auto",
devices="auto",
accelerator=accelerator_processor.to_torch(accelerator)[0],
devices=accelerator_processor.to_torch(accelerator)[1],
max_epochs=max_epoch,
accumulate_grad_batches=batch_chunk_count,
logger=[csv_logger, tensorboard_logger],
......
......@@ -13,47 +13,6 @@ from pytorch_lightning import seed_everything
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
def setup_pytorch_device(name):
"""Sets-up the pytorch device to use.
Parameters
----------
name : str
The device name (``cpu``, ``cuda:0``, ``cuda:1``, and so on). If you
set a specific cuda device such as ``cuda:1``, then we'll make sure it
is currently set.
Returns
-------
device : :py:class:`torch.device`
The pytorch device to use, pre-configured (and checked)
"""
import torch
if name.startswith("cuda:"):
# In case one has multiple devices, we must first set the one
# we would like to use so pytorch can find it.
logger.info(f"User set device to '{name}' - trying to force device...")
os.environ["CUDA_VISIBLE_DEVICES"] = name.split(":", 1)[1]
if not torch.cuda.is_available():
raise RuntimeError(
f"CUDA is not currently available, but "
f"you set device to '{name}'"
)
# Let pytorch auto-select from environment variable
return torch.device("cuda")
elif name.startswith("cuda"): # use default device
logger.info(f"User set device to '{name}' - using default CUDA device")
assert os.environ.get("CUDA_VISIBLE_DEVICES") is not None
# cuda or cpu
return torch.device(name)
def set_reproducible_cuda():
"""Turns-off all CUDA optimizations that would affect reproducibility.
......@@ -217,12 +176,12 @@ def set_reproducible_cuda():
cls=ResourceOption,
)
@click.option(
"--device",
"-d",
help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
"--accelerator",
"-a",
help='A string indicating the accelerator to use (e.g. "auto", "cpu" or "gpu"). If auto, will select the best one available',
show_default=True,
required=True,
default="cpu",
default="auto",
cls=ResourceOption,
)
@click.option(
......@@ -294,7 +253,7 @@ def train(
criterion_valid,
dataset,
checkpoint_period,
device,
accelerator,
seed,
parallel,
normalization,
......@@ -323,8 +282,6 @@ def train(
from ..configs.datasets import get_positive_weights, get_samples_weights
from ..engine.trainer import run
device = setup_pytorch_device(device)
seed_everything(seed)
use_dataset = dataset
......@@ -517,7 +474,7 @@ def train(
valid_loader=valid_loader,
extra_valid_loaders=extra_valid_loaders,
checkpoint_period=checkpoint_period,
device=device,
accelerator=accelerator,
arguments=arguments,
output_folder=output_folder,
monitoring_interval=monitoring_interval,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment