From a0f264f06bb2084554ce6d89c73c9d7df45973c0 Mon Sep 17 00:00:00 2001 From: Andre Anjos <andre.dos.anjos@gmail.com> Date: Mon, 10 Jul 2023 11:20:13 +0200 Subject: [PATCH] [ptbench.utils.accelerator] Add support for mps backend --- src/ptbench/utils/accelerator.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/ptbench/utils/accelerator.py b/src/ptbench/utils/accelerator.py index dcfa2f73..42e87a7e 100644 --- a/src/ptbench/utils/accelerator.py +++ b/src/ptbench/utils/accelerator.py @@ -18,8 +18,10 @@ class AcceleratorProcessor: """ def __init__(self, name): - # Note: "auto" is a valid accelerator in lightning, but there doesn't seem to be a way to check which accelerator it will actually use so we don't take it into account for now. - self.torch_to_lightning = {"cpu": "cpu", "cuda": "gpu"} + # Note: "auto" is a valid accelerator in lightning, but there doesn't + # seem to be a way to check which accelerator it will actually use so + # we don't take it into account for now. + self.torch_to_lightning = {"cpu": "cpu", "cuda": "gpu", "mps": "mps"} self.lightning_to_torch = { v: k for k, v in self.torch_to_lightning.items() @@ -57,6 +59,8 @@ class AcceleratorProcessor: "Environment variable 'CUDA_VISIBLE_DEVICES' is not set." "Please set 'CUDA_VISIBLE_DEVICES' of specify a device to use, e.g. cuda:0" ) + elif self.accelerator == "mps": + self.device = 1 else: # No need to check the CUDA_VISIBLE_DEVICES environment variable if cpu pass -- GitLab