Commit 8852fa7f authored by Daniel CARRON

Re-added device/accelerator selection

parent c2c4286d
1 merge request: !4 Moved code to lightning
@@ -7,6 +7,8 @@ import logging
 import os
 import shutil
+import torch
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger
@@ -18,6 +20,61 @@ from .callbacks import LoggingCallback
 logger = logging.getLogger(__name__)
+
+class AcceleratorProcessor:
+    """This class is used to convert torch devices into lightning accelerators
+    and vice versa, as they do not use the same conventions."""
+
+    def __init__(self):
+        # Note: "auto" is a valid accelerator in lightning, but there does not
+        # seem to be a way to check which accelerator it will actually use, so
+        # we do not take it into account for now.
+        self.torch_to_lightning = {"cpu": "cpu", "cuda": "gpu"}
+
+        self.lightning_to_torch = {
+            v: k for k, v in self.torch_to_lightning.items()
+        }
+
+        self.valid_accelerators = set(
+            list(self.torch_to_lightning.keys())
+            + list(self.lightning_to_torch.keys())
+        )
+
+    def _split_accelerator_name(self, accelerator_name):
+        # Splits strings such as "cuda:0" into the accelerator name ("cuda")
+        # and the requested device indices (["0"]); without an explicit index,
+        # devices defaults to "auto".
+        split_accelerator = accelerator_name.split(":")
+        accelerator = split_accelerator[0]
+
+        if len(split_accelerator) > 1:
+            devices = split_accelerator[1:]
+        else:
+            devices = "auto"
+
+        return accelerator, devices
+
+    def to_torch(self, accelerator_name):
+        # Maps a lightning or torch accelerator name to its torch equivalent.
+        accelerator_name, devices = self._split_accelerator_name(
+            accelerator_name
+        )
+
+        assert accelerator_name in self.valid_accelerators
+
+        if accelerator_name in self.lightning_to_torch:
+            return self.lightning_to_torch[accelerator_name], devices
+        elif accelerator_name in self.torch_to_lightning:
+            return accelerator_name, devices
+        else:
+            raise ValueError("Unknown accelerator.")
+
+    def to_lightning(self, accelerator_name):
+        # Maps a lightning or torch accelerator name to its lightning
+        # equivalent.
+        accelerator_name, devices = self._split_accelerator_name(
+            accelerator_name
+        )
+
+        assert accelerator_name in self.valid_accelerators
+
+        if accelerator_name in self.torch_to_lightning:
+            return self.torch_to_lightning[accelerator_name], devices
+        elif accelerator_name in self.lightning_to_torch:
+            return accelerator_name, devices
+        else:
+            raise ValueError("Unknown accelerator.")
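For illustration only (not part of the commit), a minimal sketch of how the class added above resolves accelerator strings; it assumes AcceleratorProcessor can be imported from the trainer module in this diff, and the concrete strings are just examples:

# Hypothetical usage of AcceleratorProcessor (not part of the commit).
processor = AcceleratorProcessor()

processor.to_torch("gpu:0")     # ("cuda", ["0"]): lightning name with an explicit device index
processor.to_torch("cuda")      # ("cuda", "auto"): torch names pass through unchanged
processor.to_lightning("cuda")  # ("gpu", "auto")
processor.to_torch("tpu")       # AssertionError: "tpu" is not a known accelerator here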
 def check_gpu(device):
     """Check the device type and the availability of GPU.
@@ -27,7 +84,7 @@ def check_gpu(device):
     device : :py:class:`torch.device`
         device to use
     """
-    if device.type == "cuda":
+    if device == "cuda":
         # asserts we do have a GPU
         assert bool(
             gpu_constants()
@@ -78,7 +135,7 @@ def static_information_to_csv(static_logfile_name, device, n):
         shutil.move(static_logfile_name, backup)
     with open(static_logfile_name, "w", newline="") as f:
         logdata = cpu_constants()
-        if device.type == "cuda":
+        if device == "cuda":
             logdata += gpu_constants()
         logdata += (("model_size", n),)
         logwriter = csv.DictWriter(f, fieldnames=[k[0] for k in logdata])
@@ -142,9 +199,7 @@ def create_logfile_fields(valid_loader, extra_valid_loaders, device):
         logfile_fields += ("validation_loss",)
     if extra_valid_loaders:
         logfile_fields += ("extra_validation_losses",)
-    logfile_fields += tuple(
-        ResourceMonitor.monitored_keys(device.type == "cuda")
-    )
+    logfile_fields += tuple(ResourceMonitor.monitored_keys(device == "cuda"))
     return logfile_fields
@@ -154,7 +209,7 @@ def run(
     valid_loader,
     extra_valid_loaders,
     checkpoint_period,
-    device,
+    accelerator,
     arguments,
     output_folder,
     monitoring_interval,
@@ -190,8 +245,8 @@ def run(
         save a checkpoint every ``n`` epochs. If set to ``0`` (zero), then do
         not save intermediary checkpoints

-    device : :py:class:`torch.device`
-        device to use
+    accelerator : str
+        accelerator to use

     arguments : dict
         start and end epochs
@@ -213,7 +268,9 @@ def run(
     max_epoch = arguments["max_epoch"]

-    check_gpu(device)
+    accelerator_processor = AcceleratorProcessor()
+
+    check_gpu(accelerator_processor.to_torch(accelerator)[0])

     os.makedirs(output_folder, exist_ok=True)
@@ -225,7 +282,7 @@ def run(
     resource_monitor = ResourceMonitor(
         interval=monitoring_interval,
-        has_gpu=(device.type == "cuda"),
+        has_gpu=torch.cuda.is_available(),
         main_pid=os.getpid(),
         logging_level=logging.ERROR,
     )
@@ -244,12 +301,14 @@ def run(
     # write static information to a CSV file
     static_logfile_name = os.path.join(output_folder, "constants.csv")
-    static_information_to_csv(static_logfile_name, device, n)
+    static_information_to_csv(
+        static_logfile_name, accelerator_processor.to_torch(accelerator)[0], n
+    )

     with resource_monitor:
         trainer = Trainer(
-            accelerator="auto",
-            devices="auto",
+            accelerator=accelerator_processor.to_torch(accelerator)[0],
+            devices=accelerator_processor.to_torch(accelerator)[1],
             max_epochs=max_epoch,
             accumulate_grad_batches=batch_chunk_count,
             logger=[csv_logger, tensorboard_logger],
...
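As a hedged illustration (not part of the commit), the following sketch walks through how the accelerator string handed to run() turns into the Trainer arguments shown in the last hunk above; the value "gpu" is only an example:

# Hypothetical walk-through of the accelerator resolution performed in run().
accelerator = "gpu"  # e.g. the value passed down from the command line

accelerator_processor = AcceleratorProcessor()
torch_name, devices = accelerator_processor.to_torch(accelerator)
# torch_name == "cuda", devices == "auto"

check_gpu(torch_name)  # verifies a GPU is actually available

# run() then forwards exactly these two values to lightning:
#   Trainer(accelerator=torch_name, devices=devices, ...)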
@@ -13,47 +13,6 @@ from pytorch_lightning import seed_everything
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-def setup_pytorch_device(name):
-    """Sets-up the pytorch device to use.
-
-    Parameters
-    ----------
-
-    name : str
-        The device name (``cpu``, ``cuda:0``, ``cuda:1``, and so on). If you
-        set a specific cuda device such as ``cuda:1``, then we'll make sure it
-        is currently set.
-
-    Returns
-    -------
-
-    device : :py:class:`torch.device`
-        The pytorch device to use, pre-configured (and checked)
-    """
-    import torch
-
-    if name.startswith("cuda:"):
-        # In case one has multiple devices, we must first set the one
-        # we would like to use so pytorch can find it.
-        logger.info(f"User set device to '{name}' - trying to force device...")
-        os.environ["CUDA_VISIBLE_DEVICES"] = name.split(":", 1)[1]
-        if not torch.cuda.is_available():
-            raise RuntimeError(
-                f"CUDA is not currently available, but "
-                f"you set device to '{name}'"
-            )
-        # Let pytorch auto-select from environment variable
-        return torch.device("cuda")
-
-    elif name.startswith("cuda"):  # use default device
-        logger.info(f"User set device to '{name}' - using default CUDA device")
-        assert os.environ.get("CUDA_VISIBLE_DEVICES") is not None
-
-    # cuda or cpu
-    return torch.device(name)
 def set_reproducible_cuda():
     """Turns-off all CUDA optimizations that would affect reproducibility.
@@ -217,12 +176,12 @@ def set_reproducible_cuda():
     cls=ResourceOption,
 )
 @click.option(
-    "--device",
-    "-d",
-    help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
+    "--accelerator",
+    "-a",
+    help='A string indicating the accelerator to use (e.g. "auto", "cpu" or "gpu"). If set to "auto", the best available accelerator is selected',
     show_default=True,
     required=True,
-    default="cpu",
+    default="auto",
     cls=ResourceOption,
 )
 @click.option(
@@ -294,7 +253,7 @@ def train(
     criterion_valid,
     dataset,
     checkpoint_period,
-    device,
+    accelerator,
     seed,
     parallel,
     normalization,
@@ -323,8 +282,6 @@ def train(
     from ..configs.datasets import get_positive_weights, get_samples_weights
     from ..engine.trainer import run

-    device = setup_pytorch_device(device)
-
     seed_everything(seed)

     use_dataset = dataset
@@ -517,7 +474,7 @@ def train(
         valid_loader=valid_loader,
         extra_valid_loaders=extra_valid_loaders,
         checkpoint_period=checkpoint_period,
-        device=device,
+        accelerator=accelerator,
         arguments=arguments,
         output_folder=output_folder,
         monitoring_interval=monitoring_interval,
...