Skip to content
Snippets Groups Projects
Commit a0f264f0 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[ptbench.utils.accelerator] Add support for mps backend

parent 99f52320
No related branches found
No related tags found
1 merge request!7Reviewed DataModule design+docs+types
This commit is part of merge request !6. Comments created here will be created in the context of that merge request.
...@@ -18,8 +18,10 @@ class AcceleratorProcessor: ...@@ -18,8 +18,10 @@ class AcceleratorProcessor:
""" """
def __init__(self, name): def __init__(self, name):
# Note: "auto" is a valid accelerator in lightning, but there doesn't seem to be a way to check which accelerator it will actually use so we don't take it into account for now. # Note: "auto" is a valid accelerator in lightning, but there doesn't
self.torch_to_lightning = {"cpu": "cpu", "cuda": "gpu"} # seem to be a way to check which accelerator it will actually use so
# we don't take it into account for now.
self.torch_to_lightning = {"cpu": "cpu", "cuda": "gpu", "mps": "mps"}
self.lightning_to_torch = { self.lightning_to_torch = {
v: k for k, v in self.torch_to_lightning.items() v: k for k, v in self.torch_to_lightning.items()
...@@ -57,6 +59,8 @@ class AcceleratorProcessor: ...@@ -57,6 +59,8 @@ class AcceleratorProcessor:
"Environment variable 'CUDA_VISIBLE_DEVICES' is not set." "Environment variable 'CUDA_VISIBLE_DEVICES' is not set."
"Please set 'CUDA_VISIBLE_DEVICES' of specify a device to use, e.g. cuda:0" "Please set 'CUDA_VISIBLE_DEVICES' of specify a device to use, e.g. cuda:0"
) )
elif self.accelerator == "mps":
self.device = 1
else: else:
# No need to check the CUDA_VISIBLE_DEVICES environment variable if cpu # No need to check the CUDA_VISIBLE_DEVICES environment variable if cpu
pass pass
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment