diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py
index 82c2d19a56382ff82aaa4d5997943084a395dcc4..8b158c284da769f5f3a65994dc2bfe4fba71d78c 100644
--- a/src/ptbench/data/datamodule.py
+++ b/src/ptbench/data/datamodule.py
@@ -528,11 +528,11 @@ class ConcatDataModule(lightning.LightningDataModule):
         self.parallel = parallel  # immutable, otherwise would need to call
 
         self.pin_memory = (
-            torch.cuda.is_available() or torch.backends.mps.is_available()
+            torch.cuda.is_available() or torch.backends.mps.is_available()  # type: ignore
         )  # should only be true if GPU available and using it
 
         # datasets that have been setup() for the current stage
-        self._datasets: CachingDataModule.DatasetDictionary = {}
+        self._datasets: ConcatDataModule.DatasetDictionary = {}
 
     @property
     def parallel(self) -> int:
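
Note on the two hunks above, with an illustrative sketch: the first only silences a mypy complaint about `torch.backends.mps` (older type stubs do not know the attribute), and the second annotates the per-stage dataset dictionary with the class's own `ConcatDataModule.DatasetDictionary` alias instead of the unrelated `CachingDataModule` one. The snippet below is an assumption-laden sketch, not ptbench code; the `DatasetDictionary` alias and the `should_pin_memory` helper are hypothetical names used only to show the intent of the change.

```python
# Illustrative sketch only -- not ptbench code.  It mirrors the intent of the
# patch: pin host memory only when a CUDA or Apple-MPS accelerator is
# available, and type the per-stage dataset dictionary with an alias owned by
# the data module itself.  ``DatasetDictionary`` here is an assumed stand-in
# for ``ConcatDataModule.DatasetDictionary``.
import typing

import torch
import torch.utils.data

# assumed shape: maps a split name ("train", "validation", ...) to its dataset
DatasetDictionary: typing.TypeAlias = dict[str, torch.utils.data.Dataset]


def should_pin_memory() -> bool:
    """Return True when an accelerator that benefits from pinned memory exists.

    Pinned (page-locked) host memory only pays off when tensors are actually
    copied to a device, hence the availability check.
    """
    # mypy may not know about ``torch.backends.mps`` with older stubs, which is
    # what the ``# type: ignore`` in the patch works around.
    return torch.cuda.is_available() or torch.backends.mps.is_available()


if __name__ == "__main__":
    datasets: DatasetDictionary = {}  # filled in by setup() for the current stage
    print("pin_memory:", should_pin_memory())
```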