diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py
index a7f0a0fe32735a273ea307bbb5ceb35faf5557cc..1ae9531eb68406f3a90dcf7f4e727381945a0d39 100644
--- a/src/ptbench/data/datamodule.py
+++ b/src/ptbench/data/datamodule.py
@@ -10,6 +10,7 @@ import typing
 
 import lightning
 import torch
+import torch.backends
 import torch.utils.data
 import torchvision.transforms
 import tqdm
@@ -458,7 +459,7 @@ class CachingDataModule(lightning.LightningDataModule):
         self.parallel = parallel  # immutable, otherwise would need to call
 
         self.pin_memory = (
-            torch.cuda.is_available()
+            torch.cuda.is_available() or torch.backends.mps.is_available()
         )  # should only be true if GPU available and using it
 
         # datasets that have been setup() for the current stage