diff --git a/src/ptbench/data/montgomery_shenzhen/datamodule.py b/src/ptbench/data/montgomery_shenzhen/datamodule.py index 0335a2e691ed2fcf10378e501d7715aa474c4475..e1173824eaf5f4a1f14712986abaeb304bcabf5b 100644 --- a/src/ptbench/data/montgomery_shenzhen/datamodule.py +++ b/src/ptbench/data/montgomery_shenzhen/datamodule.py @@ -14,9 +14,9 @@ class DataModule(ConcatDataModule): def __init__(self, split_filename: str): montgomery_loader = MontgomeryLoader() - montgomery_split = make_montgomery_split("default.json") + montgomery_split = make_montgomery_split(split_filename) shenzhen_loader = ShenzhenLoader() - shenzhen_split = make_shenzhen_split("default.json") + shenzhen_split = make_shenzhen_split(split_filename) super().__init__( splits={