diff --git a/src/ptbench/engine/trainer.py b/src/ptbench/engine/trainer.py
index bb7fedab436e4ec68db632c05194db2bea61ae16..0db5386cba01a34724b08c714c7e999983f90a3d 100644
--- a/src/ptbench/engine/trainer.py
+++ b/src/ptbench/engine/trainer.py
@@ -7,6 +7,8 @@ import logging
 import os
 import shutil
 
+import torch
+
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger
@@ -18,6 +20,64 @@ from .callbacks import LoggingCallback
 
 logger = logging.getLogger(__name__)
 
 
+class AcceleratorProcessor:
+    """Converts between torch device names and lightning accelerator names,
+    as the two libraries do not use the same conventions."""
+
+    def __init__(self):
+        # Lightning also accepts "auto", but it does not expose which
+        # accelerator "auto" will resolve to, so we resolve it ourselves
+        # in _resolve_auto() using torch.cuda.is_available().
+        self.torch_to_lightning = {"cpu": "cpu", "cuda": "gpu"}
+        self.lightning_to_torch = {
+            v: k for k, v in self.torch_to_lightning.items()
+        }
+
+    def _split_accelerator_name(self, accelerator_name):
+        # "cuda:1" -> ("cuda", [1]); "cuda" -> ("cuda", "auto")
+        split_accelerator = accelerator_name.split(":")
+        accelerator = split_accelerator[0]
+
+        if len(split_accelerator) > 1:
+            # lightning expects device indices as integers, not strings
+            devices = [int(i) for i in split_accelerator[1:]]
+        else:
+            devices = "auto"
+
+        return accelerator, devices
+
+    def _resolve_auto(self, accelerator_name):
+        if accelerator_name == "auto":
+            return "cuda" if torch.cuda.is_available() else "cpu"
+        return accelerator_name
+
+    def to_torch(self, accelerator_name):
+        accelerator_name, devices = self._split_accelerator_name(
+            accelerator_name
+        )
+        accelerator_name = self._resolve_auto(accelerator_name)
+
+        if accelerator_name in self.lightning_to_torch:
+            return self.lightning_to_torch[accelerator_name], devices
+        elif accelerator_name in self.torch_to_lightning:
+            return accelerator_name, devices
+        else:
+            raise ValueError(f"Unknown accelerator: {accelerator_name}")
+
+    def to_lightning(self, accelerator_name):
+        accelerator_name, devices = self._split_accelerator_name(
+            accelerator_name
+        )
+        accelerator_name = self._resolve_auto(accelerator_name)
+
+        if accelerator_name in self.torch_to_lightning:
+            return self.torch_to_lightning[accelerator_name], devices
+        elif accelerator_name in self.lightning_to_torch:
+            return accelerator_name, devices
+        else:
+            raise ValueError(f"Unknown accelerator: {accelerator_name}")
+
+
 def check_gpu(device):
     """Check the device type and the availability of GPU.
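
Reviewer note: a quick sketch of how the new `AcceleratorProcessor` behaves, exercising only what this patch defines (illustrative, not part of the patch):

```python
# Illustrative usage of AcceleratorProcessor as defined above.
from ptbench.engine.trainer import AcceleratorProcessor

processor = AcceleratorProcessor()

# lightning-style and torch-style names convert both ways
assert processor.to_torch("gpu") == ("cuda", "auto")
assert processor.to_lightning("cuda") == ("gpu", "auto")

# an explicit device index is split off and returned as integers
assert processor.to_torch("cuda:1") == ("cuda", [1])

# "auto" resolves to "cuda" when a GPU is visible, "cpu" otherwise
print(processor.to_torch("auto"))
```
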
@@ -27,7 +87,7 @@ def check_gpu(device):
-    device : :py:class:`torch.device`
-        device to use
+    device : str
+        device type to check (``cpu`` or ``cuda``)
     """
-    if device.type == "cuda":
+    if device == "cuda":
         # asserts we do have a GPU
         assert bool(
             gpu_constants()
@@ -78,7 +138,7 @@ def static_information_to_csv(static_logfile_name, device, n):
         shutil.move(static_logfile_name, backup)
     with open(static_logfile_name, "w", newline="") as f:
         logdata = cpu_constants()
-        if device.type == "cuda":
+        if device == "cuda":
             logdata += gpu_constants()
         logdata += (("model_size", n),)
         logwriter = csv.DictWriter(f, fieldnames=[k[0] for k in logdata])
@@ -142,9 +202,7 @@ def create_logfile_fields(valid_loader, extra_valid_loaders, device):
     logfile_fields += ("validation_loss",)
     if extra_valid_loaders:
         logfile_fields += ("extra_validation_losses",)
-    logfile_fields += tuple(
-        ResourceMonitor.monitored_keys(device.type == "cuda")
-    )
+    logfile_fields += tuple(ResourceMonitor.monitored_keys(device == "cuda"))
     return logfile_fields
 
 
@@ -154,7 +212,7 @@ def run(
     valid_loader,
     extra_valid_loaders,
     checkpoint_period,
-    device,
+    accelerator,
     arguments,
     output_folder,
     monitoring_interval,
@@ -190,8 +248,8 @@ def run(
        save a checkpoint every ``n`` epochs.  If set to ``0`` (zero), then do
        not save intermediary checkpoints
 
-    device : :py:class:`torch.device`
-        device to use
+    accelerator : str
+        the accelerator to use (e.g. ``cpu``, ``gpu``, ``cuda:1`` or ``auto``)
 
     arguments : dict
         start and end epochs
@@ -213,7 +271,10 @@ def run(
 
     max_epoch = arguments["max_epoch"]
 
-    check_gpu(device)
+    accelerator_processor = AcceleratorProcessor()
+    torch_device, devices = accelerator_processor.to_torch(accelerator)
+
+    check_gpu(torch_device)
 
     os.makedirs(output_folder, exist_ok=True)
 
@@ -225,7 +286,7 @@ def run(
 
     resource_monitor = ResourceMonitor(
         interval=monitoring_interval,
-        has_gpu=(device.type == "cuda"),
+        has_gpu=(torch_device == "cuda"),
         main_pid=os.getpid(),
         logging_level=logging.ERROR,
     )
@@ -244,12 +305,12 @@ def run(
 
     # write static information to a CSV file
     static_logfile_name = os.path.join(output_folder, "constants.csv")
-    static_information_to_csv(static_logfile_name, device, n)
+    static_information_to_csv(static_logfile_name, torch_device, n)
 
     with resource_monitor:
         trainer = Trainer(
-            accelerator="auto",
-            devices="auto",
+            accelerator=torch_device,
+            devices=devices,
             max_epochs=max_epoch,
             accumulate_grad_batches=batch_chunk_count,
             logger=[csv_logger, tensorboard_logger],
diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py
index 64f66a6a03c713521463907c340027948c0a07ad..b6ef37bcade2708b5507eef5f31f27b8eee25bc7 100644
--- a/src/ptbench/scripts/train.py
+++ b/src/ptbench/scripts/train.py
@@ -13,47 +13,6 @@ from pytorch_lightning import seed_everything
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
 
-def setup_pytorch_device(name):
-    """Sets-up the pytorch device to use.
-
-    Parameters
-    ----------
-
-    name : str
-        The device name (``cpu``, ``cuda:0``, ``cuda:1``, and so on). If you
-        set a specific cuda device such as ``cuda:1``, then we'll make sure it
-        is currently set.
-
-
-    Returns
-    -------
-
-    device : :py:class:`torch.device`
-        The pytorch device to use, pre-configured (and checked)
-    """
-    import torch
-
-    if name.startswith("cuda:"):
-        # In case one has multiple devices, we must first set the one
-        # we would like to use so pytorch can find it.
- logger.info(f"User set device to '{name}' - trying to force device...") - os.environ["CUDA_VISIBLE_DEVICES"] = name.split(":", 1)[1] - if not torch.cuda.is_available(): - raise RuntimeError( - f"CUDA is not currently available, but " - f"you set device to '{name}'" - ) - # Let pytorch auto-select from environment variable - return torch.device("cuda") - - elif name.startswith("cuda"): # use default device - logger.info(f"User set device to '{name}' - using default CUDA device") - assert os.environ.get("CUDA_VISIBLE_DEVICES") is not None - - # cuda or cpu - return torch.device(name) - - def set_reproducible_cuda(): """Turns-off all CUDA optimizations that would affect reproducibility. @@ -217,12 +176,12 @@ def set_reproducible_cuda(): cls=ResourceOption, ) @click.option( - "--device", - "-d", - help='A string indicating the device to use (e.g. "cpu" or "cuda:0")', + "--accelerator", + "-a", + help='A string indicating the accelerator to use (e.g. "auto", "cpu" or "gpu"). If auto, will select the best one available', show_default=True, required=True, - default="cpu", + default="auto", cls=ResourceOption, ) @click.option( @@ -294,7 +253,7 @@ def train( criterion_valid, dataset, checkpoint_period, - device, + accelerator, seed, parallel, normalization, @@ -323,8 +282,6 @@ def train( from ..configs.datasets import get_positive_weights, get_samples_weights from ..engine.trainer import run - device = setup_pytorch_device(device) - seed_everything(seed) use_dataset = dataset @@ -517,7 +474,7 @@ def train( valid_loader=valid_loader, extra_valid_loaders=extra_valid_loaders, checkpoint_period=checkpoint_period, - device=device, + accelerator=accelerator, arguments=arguments, output_folder=output_folder, monitoring_interval=monitoring_interval,