Skip to content
Snippets Groups Projects
Commit a518c23d authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[ptbench.data.data.module] Pin memory when using MPS as well

parent c25e5008
No related branches found
No related tags found
1 merge request!7Reviewed DataModule design+docs+types
This commit is part of merge request !6. Comments created here will be created in the context of that merge request.
......@@ -10,6 +10,7 @@ import typing
import lightning
import torch
import torch.backends
import torch.utils.data
import torchvision.transforms
import tqdm
......@@ -458,7 +459,7 @@ class CachingDataModule(lightning.LightningDataModule):
self.parallel = parallel # immutable, otherwise would need to call
self.pin_memory = (
torch.cuda.is_available()
torch.cuda.is_available() or torch.backends.mps.is_available()
) # should only be true if GPU available and using it
# datasets that have been setup() for the current stage
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment