diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py
index 440c1ea4505efe58d8bf7b17b6b48779bafa9c38..0fb4edcfab0a349546e5f7704c0783a1fe54b919 100644
--- a/src/ptbench/data/datamodule.py
+++ b/src/ptbench/data/datamodule.py
@@ -727,7 +727,7 @@ class CachingDataModule(lightning.LightningDataModule):
             self._setup_dataset("test")
 
         elif stage == "predict":
-            for k in self.database_split.keys():
+            for k in self.database_split:
                 self._setup_dataset(k)
 
     def teardown(self, stage: str) -> None: