diff --git a/src/ptbench/utils/accelerator.py b/src/ptbench/utils/accelerator.py index dcfa2f733e1d091c5bb9a4e5785ee47f8e49497c..42e87a7e94049d5701e3a6e407470951b9ef23a3 100644 --- a/src/ptbench/utils/accelerator.py +++ b/src/ptbench/utils/accelerator.py @@ -18,8 +18,10 @@ class AcceleratorProcessor: """ def __init__(self, name): - # Note: "auto" is a valid accelerator in lightning, but there doesn't seem to be a way to check which accelerator it will actually use so we don't take it into account for now. - self.torch_to_lightning = {"cpu": "cpu", "cuda": "gpu"} + # Note: "auto" is a valid accelerator in lightning, but there doesn't + # seem to be a way to check which accelerator it will actually use so + # we don't take it into account for now. + self.torch_to_lightning = {"cpu": "cpu", "cuda": "gpu", "mps": "mps"} self.lightning_to_torch = { v: k for k, v in self.torch_to_lightning.items() @@ -57,6 +59,8 @@ class AcceleratorProcessor: "Environment variable 'CUDA_VISIBLE_DEVICES' is not set." "Please set 'CUDA_VISIBLE_DEVICES' of specify a device to use, e.g. cuda:0" ) + elif self.accelerator == "mps": + self.device = 1 else: # No need to check the CUDA_VISIBLE_DEVICES environment variable if cpu pass