From d711a4faaa265a61a6fba8d59ede6cd12cc32f77 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Tue, 18 Apr 2023 13:42:57 +0200
Subject: [PATCH] Fixed accelerator assignment

The CUDA_VISIBLE_DEVICES environment variable is now set when a device is
specified; if no device is given and the variable is not already set, an error
is raised. Code for accelerator handling has been moved to its own utils module.
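
A minimal usage sketch of the new helper (illustrative only; the "cuda:0" case
assumes a CUDA-capable machine):

    from ptbench.utils.accelerator import AcceleratorProcessor

    # Explicit device: sets CUDA_VISIBLE_DEVICES="0" and selects the "gpu" accelerator
    processor = AcceleratorProcessor("cuda:0")
    assert processor.accelerator == "gpu"
    assert processor.to_torch() == "cuda"

    # No device given: raises an error unless CUDA_VISIBLE_DEVICES is already set
    AcceleratorProcessor("cuda")

    # CPU: no environment variable handling is performed
    AcceleratorProcessor("cpu")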
---
 src/ptbench/engine/trainer.py    |  77 +++----------------
 src/ptbench/utils/accelerator.py | 127 +++++++++++++++++++++++++++++++
 2 files changed, 138 insertions(+), 66 deletions(-)
 create mode 100644 src/ptbench/utils/accelerator.py

diff --git a/src/ptbench/engine/trainer.py b/src/ptbench/engine/trainer.py
index 723c6d66..14e815b0 100644
--- a/src/ptbench/engine/trainer.py
+++ b/src/ptbench/engine/trainer.py
@@ -7,76 +7,18 @@ import logging
 import os
 import shutil
 
-import torch
-
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger
 from pytorch_lightning.utilities.model_summary import ModelSummary
 
+from ..utils.accelerator import AcceleratorProcessor
 from ..utils.resources import ResourceMonitor, cpu_constants, gpu_constants
 from .callbacks import LoggingCallback
 
 logger = logging.getLogger(__name__)
 
 
-class AcceleratorProcessor:
-    """This class is used to convert torch devices into lightning accelerators
-    and vice versa, as they do not use the same conventions."""
-
-    def __init__(self):
-        # Note: "auto" is a valid accelerator in lightning, but there doesn't seem to be a way to check which accelerator it will actually use so we don't take it into account for now.
-        self.torch_to_lightning = {"cpu": "cpu", "cuda": "gpu"}
-        self.lightning_to_torch = {
-            v: k for k, v in self.torch_to_lightning.items()
-        }
-        self.valid_accelerators = set(
-            list(self.torch_to_lightning.keys())
-            + list(self.lightning_to_torch.keys())
-        )
-
-    def _split_accelerator_name(self, accelerator_name):
-        split_accelerator = accelerator_name.split(":")
-        accelerator = split_accelerator[0]
-
-        if len(split_accelerator) > 1:
-            devices = split_accelerator[1:]
-            devices = [int(i) for i in devices]
-            os.environ["CUDA_VISIBLE_DEVICES"] = devices
-        else:
-            devices = "auto"
-
-        return accelerator, devices
-
-    def to_torch(self, accelerator_name):
-        accelerator_name, devices = self._split_accelerator_name(
-            accelerator_name
-        )
-
-        assert accelerator_name in self.valid_accelerators
-
-        if accelerator_name in self.lightning_to_torch:
-            return self.lightning_to_torch[accelerator_name], devices
-        elif accelerator_name in self.torch_to_lightning:
-            return accelerator_name, devices
-        else:
-            raise ValueError("Unknown accelerator.")
-
-    def to_lightning(self, accelerator_name):
-        accelerator_name, devices = self._split_accelerator_name(
-            accelerator_name
-        )
-
-        assert accelerator_name in self.valid_accelerators
-
-        if accelerator_name in self.torch_to_lightning:
-            return self.lightning_to_torch[accelerator_name], devices
-        elif accelerator_name in self.lightning_to_torch:
-            return accelerator_name, devices
-        else:
-            raise ValueError("Unknown accelerator.")
-
-
 def check_gpu(device):
     """Check the device type and the availability of GPU.
 
@@ -270,9 +212,7 @@ def run(
 
     max_epoch = arguments["max_epoch"]
 
-    accelerator_processor = AcceleratorProcessor()
-
-    check_gpu(accelerator_processor.to_torch(accelerator)[0])
+    accelerator_processor = AcceleratorProcessor(accelerator)
 
     os.makedirs(output_folder, exist_ok=True)
 
@@ -284,7 +224,7 @@ def run(
 
     resource_monitor = ResourceMonitor(
         interval=monitoring_interval,
-        has_gpu=torch.cuda.is_available(),
+        has_gpu=(accelerator_processor.accelerator == "gpu"),
         main_pid=os.getpid(),
         logging_level=logging.ERROR,
     )
@@ -304,13 +244,18 @@ def run(
     # write static information to a CSV file
     static_logfile_name = os.path.join(output_folder, "constants.csv")
     static_information_to_csv(
-        static_logfile_name, accelerator_processor.to_torch(accelerator)[0], n
+        static_logfile_name, accelerator_processor.to_torch(), n
     )
 
+    if accelerator_processor.device is None:
+        devices = "auto"
+    else:
+        devices = accelerator_processor.device
+
     with resource_monitor:
         trainer = Trainer(
-            accelerator=accelerator_processor.to_torch(accelerator)[0],
-            devices=accelerator_processor.to_torch(accelerator)[1],
+            accelerator=accelerator_processor.accelerator,
+            devices=devices,
             max_epochs=max_epoch,
             accumulate_grad_batches=batch_chunk_count,
             logger=[csv_logger, tensorboard_logger],
diff --git a/src/ptbench/utils/accelerator.py b/src/ptbench/utils/accelerator.py
new file mode 100644
index 00000000..dcfa2f73
--- /dev/null
+++ b/src/ptbench/utils/accelerator.py
@@ -0,0 +1,127 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import logging
+import os
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+class AcceleratorProcessor:
+    """This class is used to convert the torch device naming convention to
+    lightning's device convention and vice versa.
+
+    It also sets the CUDA_VISIBLE_DEVICES environment variable if a gpu
+    """
+
+    def __init__(self, name):
+        # Note: "auto" is a valid accelerator in lightning, but there doesn't seem to be a way to check which accelerator it will actually use so we don't take it into account for now.
+        self.torch_to_lightning = {"cpu": "cpu", "cuda": "gpu"}
+
+        self.lightning_to_torch = {
+            v: k for k, v in self.torch_to_lightning.items()
+        }
+
+        self.valid_accelerators = set(
+            list(self.torch_to_lightning.keys())
+            + list(self.lightning_to_torch.keys())
+        )
+
+        self.accelerator, self.device = self._split_accelerator_name(name)
+
+        if self.accelerator not in self.valid_accelerators:
+            raise ValueError(f"Unknown accelerator {self.accelerator}")
+
+        # Keep lightning's convention by default
+        self.accelerator = self.to_lightning()
+        self.setup_accelerator()
+
+    def setup_accelerator(self):
+        """If a gpu accelerator is chosen, checks the CUDA_VISIBLE_DEVICES
+        environment variable exists or sets its value if specified."""
+        if self.accelerator == "gpu":
+            if not torch.cuda.is_available():
+                raise RuntimeError(
+                    f"CUDA is not currently available, but "
+                    f"you set accelerator to '{self.accelerator}'"
+                )
+
+            if self.device is not None:
+                os.environ["CUDA_VISIBLE_DEVICES"] = str(self.device[0])
+            else:
+                if os.environ.get("CUDA_VISIBLE_DEVICES") is None:
+                    raise ValueError(
+                        "Environment variable 'CUDA_VISIBLE_DEVICES' is not set."
+                        "Please set 'CUDA_VISIBLE_DEVICES' of specify a device to use, e.g. cuda:0"
+                    )
+        else:
+            # No need to check the CUDA_VISIBLE_DEVICES environment variable if cpu
+            pass
+
+        logger.info(
+            f"Accelerator set to {self.accelerator} and device to {self.device}"
+        )
+
+    def _split_accelerator_name(self, accelerator_name):
+        """Splits an accelerator string into accelerator and device components.
+
+        Parameters
+        ----------
+
+        accelerator_name: str
+            The accelerator (or device in pytorch convention) string (e.g. cuda:0)
+
+        Returns
+        -------
+
+        accelerator: str
+            The accelerator name
+        device: list[int] or None
+            The selected device index wrapped in a list, or None if no device was specified
+        """
+
+        split_accelerator = accelerator_name.split(":")
+        accelerator = split_accelerator[0]
+
+        if len(split_accelerator) > 1:
+            device = split_accelerator[1]
+            device = [int(device)]
+        else:
+            device = None
+
+        return accelerator, device
+
+    def to_torch(self):
+        """Converts the accelerator string to torch convention.
+
+        Returns
+        -------
+
+        accelerator: str
+            The accelerator name in pytorch convention
+        """
+        if self.accelerator in self.lightning_to_torch:
+            return self.lightning_to_torch[self.accelerator]
+        elif self.accelerator in self.torch_to_lightning:
+            return self.accelerator
+        else:
+            raise ValueError("Unknown accelerator.")
+
+    def to_lightning(self):
+        """Converts the accelerator string to lightning convention.
+
+        Returns
+        -------
+
+        accelerator: str
+            The accelerator name in lightning convention
+        """
+        if self.accelerator in self.torch_to_lightning:
+            return self.torch_to_lightning[self.accelerator]
+        elif self.accelerator in self.lightning_to_torch:
+            return self.accelerator
+        else:
+            raise ValueError("Unknown accelerator.")
-- 
GitLab