diff --git a/src/ptbench/data/montgomery_shenzhen/datamodule.py b/src/ptbench/data/montgomery_shenzhen/datamodule.py
index 0335a2e691ed2fcf10378e501d7715aa474c4475..e1173824eaf5f4a1f14712986abaeb304bcabf5b 100644
--- a/src/ptbench/data/montgomery_shenzhen/datamodule.py
+++ b/src/ptbench/data/montgomery_shenzhen/datamodule.py
@@ -14,9 +14,9 @@ class DataModule(ConcatDataModule):
 
     def __init__(self, split_filename: str):
         montgomery_loader = MontgomeryLoader()
-        montgomery_split = make_montgomery_split("default.json")
+        montgomery_split = make_montgomery_split(split_filename)
         shenzhen_loader = ShenzhenLoader()
-        shenzhen_split = make_shenzhen_split("default.json")
+        shenzhen_split = make_shenzhen_split(split_filename)
 
         super().__init__(
             splits={