diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py index a7f0a0fe32735a273ea307bbb5ceb35faf5557cc..1ae9531eb68406f3a90dcf7f4e727381945a0d39 100644 --- a/src/ptbench/data/datamodule.py +++ b/src/ptbench/data/datamodule.py @@ -10,6 +10,7 @@ import typing import lightning import torch +import torch.backends import torch.utils.data import torchvision.transforms import tqdm @@ -458,7 +459,7 @@ class CachingDataModule(lightning.LightningDataModule): self.parallel = parallel # immutable, otherwise would need to call self.pin_memory = ( - torch.cuda.is_available() + torch.cuda.is_available() or torch.backends.mps.is_available() ) # should only be true if GPU available and using it # datasets that have been setup() for the current stage