From a518c23dc1a14569b18a2d5e3586076e20b2e2d2 Mon Sep 17 00:00:00 2001 From: Andre Anjos <andre.dos.anjos@gmail.com> Date: Mon, 10 Jul 2023 17:57:52 +0200 Subject: [PATCH] [ptbench.data.data.module] Pin memory when using MPS as well --- src/ptbench/data/datamodule.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py index a7f0a0fe..1ae9531e 100644 --- a/src/ptbench/data/datamodule.py +++ b/src/ptbench/data/datamodule.py @@ -10,6 +10,7 @@ import typing import lightning import torch +import torch.backends import torch.utils.data import torchvision.transforms import tqdm @@ -458,7 +459,7 @@ class CachingDataModule(lightning.LightningDataModule): self.parallel = parallel # immutable, otherwise would need to call self.pin_memory = ( - torch.cuda.is_available() + torch.cuda.is_available() or torch.backends.mps.is_available() ) # should only be true if GPU available and using it # datasets that have been setup() for the current stage -- GitLab