From a518c23dc1a14569b18a2d5e3586076e20b2e2d2 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Mon, 10 Jul 2023 17:57:52 +0200
Subject: [PATCH] [ptbench.data.datamodule] Pin memory when using MPS as well

---
 src/ptbench/data/datamodule.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py
index a7f0a0fe..1ae9531e 100644
--- a/src/ptbench/data/datamodule.py
+++ b/src/ptbench/data/datamodule.py
@@ -10,6 +10,7 @@ import typing
 
 import lightning
 import torch
+import torch.backends
 import torch.utils.data
 import torchvision.transforms
 import tqdm
@@ -458,7 +459,7 @@ class CachingDataModule(lightning.LightningDataModule):
         self.parallel = parallel  # immutable, otherwise would need to call
 
         self.pin_memory = (
-            torch.cuda.is_available()
+            torch.cuda.is_available() or torch.backends.mps.is_available()
         )  # should only be true if GPU available and using it
 
         # datasets that have been setup() for the current stage
-- 
GitLab