diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py index 82c2d19a56382ff82aaa4d5997943084a395dcc4..8b158c284da769f5f3a65994dc2bfe4fba71d78c 100644 --- a/src/ptbench/data/datamodule.py +++ b/src/ptbench/data/datamodule.py @@ -528,11 +528,11 @@ class ConcatDataModule(lightning.LightningDataModule): self.parallel = parallel # immutable, otherwise would need to call self.pin_memory = ( - torch.cuda.is_available() or torch.backends.mps.is_available() + torch.cuda.is_available() or torch.backends.mps.is_available() # type: ignore ) # should only be true if GPU available and using it # datasets that have been setup() for the current stage - self._datasets: CachingDataModule.DatasetDictionary = {} + self._datasets: ConcatDataModule.DatasetDictionary = {} @property def parallel(self) -> int: