Skip to content
Snippets Groups Projects

Making use of LightningDataModule and simplification of data loading

Merged Daniel CARRON requested to merge add-datamodule into main
1 file
+ 2
1
Compare changes
  • Side-by-side
  • Inline
@@ -10,6 +10,7 @@ import typing
@@ -10,6 +10,7 @@ import typing
import lightning
import lightning
import torch
import torch
 
import torch.backends
import torch.utils.data
import torch.utils.data
import torchvision.transforms
import torchvision.transforms
import tqdm
import tqdm
@@ -458,7 +459,7 @@ class CachingDataModule(lightning.LightningDataModule):
@@ -458,7 +459,7 @@ class CachingDataModule(lightning.LightningDataModule):
self.parallel = parallel # immutable, otherwise would need to call
self.parallel = parallel # immutable, otherwise would need to call
self.pin_memory = (
self.pin_memory = (
torch.cuda.is_available()
torch.cuda.is_available() or torch.backends.mps.is_available()
) # should only be true if GPU available and using it
) # should only be true if GPU available and using it
# datasets that have been setup() for the current stage
# datasets that have been setup() for the current stage
Loading