diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py index 440c1ea4505efe58d8bf7b17b6b48779bafa9c38..0fb4edcfab0a349546e5f7704c0783a1fe54b919 100644 --- a/src/ptbench/data/datamodule.py +++ b/src/ptbench/data/datamodule.py @@ -727,7 +727,7 @@ class CachingDataModule(lightning.LightningDataModule): self._setup_dataset("test") elif stage == "predict": - for k in self.database_split.keys(): + for k in self.database_split: self._setup_dataset(k) def teardown(self, stage: str) -> None: