From a0f264f06bb2084554ce6d89c73c9d7df45973c0 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Mon, 10 Jul 2023 11:20:13 +0200
Subject: [PATCH] [ptbench.utils.accelerator] Add support for mps backend

---
 src/ptbench/utils/accelerator.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/ptbench/utils/accelerator.py b/src/ptbench/utils/accelerator.py
index dcfa2f73..42e87a7e 100644
--- a/src/ptbench/utils/accelerator.py
+++ b/src/ptbench/utils/accelerator.py
@@ -18,8 +18,10 @@ class AcceleratorProcessor:
     """
 
     def __init__(self, name):
-        # Note: "auto" is a valid accelerator in lightning, but there doesn't seem to be a way to check which accelerator it will actually use so we don't take it into account for now.
-        self.torch_to_lightning = {"cpu": "cpu", "cuda": "gpu"}
+        # Note: "auto" is a valid accelerator in lightning, but there doesn't
+        # seem to be a way to check which accelerator it will actually use so
+        # we don't take it into account for now.
+        self.torch_to_lightning = {"cpu": "cpu", "cuda": "gpu", "mps": "mps"}
 
         self.lightning_to_torch = {
             v: k for k, v in self.torch_to_lightning.items()
@@ -57,6 +59,8 @@ class AcceleratorProcessor:
                         "Environment variable 'CUDA_VISIBLE_DEVICES' is not set."
                         "Please set 'CUDA_VISIBLE_DEVICES' of specify a device to use, e.g. cuda:0"
                     )
+        elif self.accelerator == "mps":
+            self.device = 1
         else:
             # No need to check the CUDA_VISIBLE_DEVICES environment variable if cpu
             pass
-- 
GitLab