From 748acd57a109e1c25cec7ee96846e40316999c2d Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Thu, 27 Jul 2023 21:26:08 +0200
Subject: [PATCH] [data.montgomery_shenzhen] Create first concatenated
 datamodule

---
 src/ptbench/data/mc_ch/__init__.py            |  3 -
 src/ptbench/data/mc_ch/default.py             | 73 ------------------
 src/ptbench/data/mc_ch/fold_0.py              | 74 -------------------
 src/ptbench/data/mc_ch/fold_1.py              | 74 -------------------
 src/ptbench/data/mc_ch/fold_2.py              | 73 ------------------
 src/ptbench/data/mc_ch/fold_3.py              | 73 ------------------
 src/ptbench/data/mc_ch/fold_4.py              | 73 ------------------
 src/ptbench/data/mc_ch/fold_5.py              | 73 ------------------
 src/ptbench/data/mc_ch/fold_6.py              | 73 ------------------
 src/ptbench/data/mc_ch/fold_7.py              | 73 ------------------
 src/ptbench/data/mc_ch/fold_8.py              | 73 ------------------
 src/ptbench/data/mc_ch/fold_9.py              | 73 ------------------
 .../data/montgomery_shenzhen/__init__.py      |  0
 .../data/montgomery_shenzhen/datamodule.py    | 36 +++++++++
 .../data/montgomery_shenzhen/default.py       |  7 ++
 .../data/montgomery_shenzhen/fold_0.py        |  7 ++
 .../data/montgomery_shenzhen/fold_1.py        |  7 ++
 .../data/montgomery_shenzhen/fold_2.py        |  7 ++
 .../data/montgomery_shenzhen/fold_3.py        |  7 ++
 .../data/montgomery_shenzhen/fold_4.py        |  7 ++
 .../data/montgomery_shenzhen/fold_5.py        |  7 ++
 .../data/montgomery_shenzhen/fold_6.py        |  7 ++
 .../data/montgomery_shenzhen/fold_7.py        |  7 ++
 .../data/montgomery_shenzhen/fold_8.py        |  7 ++
 .../data/montgomery_shenzhen/fold_9.py        |  7 ++
 25 files changed, 113 insertions(+), 808 deletions(-)
 delete mode 100644 src/ptbench/data/mc_ch/__init__.py
 delete mode 100644 src/ptbench/data/mc_ch/default.py
 delete mode 100644 src/ptbench/data/mc_ch/fold_0.py
 delete mode 100644 src/ptbench/data/mc_ch/fold_1.py
 delete mode 100644 src/ptbench/data/mc_ch/fold_2.py
 delete mode 100644 src/ptbench/data/mc_ch/fold_3.py
 delete mode 100644 src/ptbench/data/mc_ch/fold_4.py
 delete mode 100644 src/ptbench/data/mc_ch/fold_5.py
 delete mode 100644 src/ptbench/data/mc_ch/fold_6.py
 delete mode 100644 src/ptbench/data/mc_ch/fold_7.py
 delete mode 100644 src/ptbench/data/mc_ch/fold_8.py
 delete mode 100644 src/ptbench/data/mc_ch/fold_9.py
 create mode 100644 src/ptbench/data/montgomery_shenzhen/__init__.py
 create mode 100644 src/ptbench/data/montgomery_shenzhen/datamodule.py
 create mode 100644 src/ptbench/data/montgomery_shenzhen/default.py
 create mode 100644 src/ptbench/data/montgomery_shenzhen/fold_0.py
 create mode 100644 src/ptbench/data/montgomery_shenzhen/fold_1.py
 create mode 100644 src/ptbench/data/montgomery_shenzhen/fold_2.py
 create mode 100644 src/ptbench/data/montgomery_shenzhen/fold_3.py
 create mode 100644 src/ptbench/data/montgomery_shenzhen/fold_4.py
 create mode 100644 src/ptbench/data/montgomery_shenzhen/fold_5.py
 create mode 100644 src/ptbench/data/montgomery_shenzhen/fold_6.py
 create mode 100644 src/ptbench/data/montgomery_shenzhen/fold_7.py
 create mode 100644 src/ptbench/data/montgomery_shenzhen/fold_8.py
 create mode 100644 src/ptbench/data/montgomery_shenzhen/fold_9.py

diff --git a/src/ptbench/data/mc_ch/__init__.py b/src/ptbench/data/mc_ch/__init__.py
deleted file mode 100644
index 662d5c13..00000000
--- a/src/ptbench/data/mc_ch/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-# Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
diff --git a/src/ptbench/data/mc_ch/default.py b/src/ptbench/data/mc_ch/default.py
deleted file mode 100644
index cf901fdb..00000000
--- a/src/ptbench/data/mc_ch/default.py
+++ /dev/null
@@ -1,73 +0,0 @@
-# Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-"""Aggregated dataset composed of Montgomery and Shenzhen datasets."""
-
-from clapper.logging import setup
-from torch.utils.data.dataset import ConcatDataset
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule, get_dataset_from_module
-from ..montgomery.default import datamodule as mc_datamodule
-from ..shenzhen.default import datamodule as ch_datamodule
-
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        self.train_batch_size = train_batch_size
-        self.predict_batch_size = predict_batch_size
-        self.drop_incomplete_batch = drop_incomplete_batch
-        self.multiproc_kwargs = multiproc_kwargs
-
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        # Instantiate other datamodules and get their datasets
-
-        module_args = {
-            "train_batch_size": self.train_batch_size,
-            "predict_batch_size": self.predict_batch_size,
-            "drop_incomplete_batch": self.drop_incomplete_batch,
-            "multiproc_kwargs": self.multiproc_kwargs,
-        }
-
-        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
-        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
-
-        # Combine datasets
-        self.dataset = {}
-        self.dataset["__train__"] = ConcatDataset(
-            [mc["__train__"], ch["__train__"]]
-        )
-        self.dataset["train"] = ConcatDataset([mc["train"], ch["train"]])
-        self.dataset["__valid__"] = ConcatDataset(
-            [mc["__valid__"], ch["__valid__"]]
-        )
-        self.dataset["validation"] = ConcatDataset(
-            [mc["validation"], ch["validation"]]
-        )
-        self.dataset["test"] = ConcatDataset([mc["test"], ch["test"]])
-
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = DefaultModule
diff --git a/src/ptbench/data/mc_ch/fold_0.py b/src/ptbench/data/mc_ch/fold_0.py
deleted file mode 100644
index d3717d6f..00000000
--- a/src/ptbench/data/mc_ch/fold_0.py
+++ /dev/null
@@ -1,74 +0,0 @@
-# Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-"""Aggregated dataset composed of Montgomery and Shenzhen datasets (cross
-validation fold 0)"""
-
-
-from clapper.logging import setup
-from torch.utils.data.dataset import ConcatDataset
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule, get_dataset_from_module
-from ..montgomery.fold_0 import datamodule as mc_datamodule
-from ..shenzhen.fold_0 import datamodule as ch_datamodule
-
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        self.train_batch_size = train_batch_size
-        self.predict_batch_size = predict_batch_size
-        self.drop_incomplete_batch = drop_incomplete_batch
-        self.multiproc_kwargs = multiproc_kwargs
-
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        # Instantiate other datamodules and get their datasets
-        module_args = {
-            "train_batch_size": self.train_batch_size,
-            "predict_batch_size": self.predict_batch_size,
-            "drop_incomplete_batch": self.drop_incomplete_batch,
-            "multiproc_kwargs": self.multiproc_kwargs,
-        }
-
-        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
-        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
-
-        # Combine datasets
-        self.dataset = {}
-        self.dataset["__train__"] = ConcatDataset(
-            [mc["__train__"], ch["__train__"]]
-        )
-        self.dataset["train"] = ConcatDataset([mc["train"], ch["train"]])
-        self.dataset["__valid__"] = ConcatDataset(
-            [mc["__valid__"], ch["__valid__"]]
-        )
-        self.dataset["validation"] = ConcatDataset(
-            [mc["validation"], ch["validation"]]
-        )
-        self.dataset["test"] = ConcatDataset([mc["test"], ch["test"]])
-
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = DefaultModule
diff --git a/src/ptbench/data/mc_ch/fold_1.py b/src/ptbench/data/mc_ch/fold_1.py
deleted file mode 100644
index 2fb8136b..00000000
--- a/src/ptbench/data/mc_ch/fold_1.py
+++ /dev/null
@@ -1,74 +0,0 @@
-# Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-"""Aggregated dataset composed of Montgomery and Shenzhen datasets (cross
-validation fold 1)"""
-
-
-from clapper.logging import setup
-from torch.utils.data.dataset import ConcatDataset
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule, get_dataset_from_module
-from ..montgomery.fold_1 import datamodule as mc_datamodule
-from ..shenzhen.fold_1 import datamodule as ch_datamodule
-
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        self.train_batch_size = train_batch_size
-        self.predict_batch_size = predict_batch_size
-        self.drop_incomplete_batch = drop_incomplete_batch
-        self.multiproc_kwargs = multiproc_kwargs
-
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        # Instantiate other datamodules and get their datasets
-        module_args = {
-            "train_batch_size": self.train_batch_size,
-            "predict_batch_size": self.predict_batch_size,
-            "drop_incomplete_batch": self.drop_incomplete_batch,
-            "multiproc_kwargs": self.multiproc_kwargs,
-        }
-
-        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
-        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
-
-        # Combine datasets
-        self.dataset = {}
-        self.dataset["__train__"] = ConcatDataset(
-            [mc["__train__"], ch["__train__"]]
-        )
-        self.dataset["train"] = ConcatDataset([mc["train"], ch["train"]])
-        self.dataset["__valid__"] = ConcatDataset(
-            [mc["__valid__"], ch["__valid__"]]
-        )
-        self.dataset["validation"] = ConcatDataset(
-            [mc["validation"], ch["validation"]]
-        )
-        self.dataset["test"] = ConcatDataset([mc["test"], ch["test"]])
-
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = DefaultModule
diff --git a/src/ptbench/data/mc_ch/fold_2.py b/src/ptbench/data/mc_ch/fold_2.py
deleted file mode 100644
index 4ef0dd55..00000000
--- a/src/ptbench/data/mc_ch/fold_2.py
+++ /dev/null
@@ -1,73 +0,0 @@
-# Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-"""Aggregated dataset composed of Montgomery and Shenzhen datasets (cross
-validation fold 2)"""
-
-from clapper.logging import setup
-from torch.utils.data.dataset import ConcatDataset
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule, get_dataset_from_module
-from ..montgomery.fold_2 import datamodule as mc_datamodule
-from ..shenzhen.fold_2 import datamodule as ch_datamodule
-
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        self.train_batch_size = train_batch_size
-        self.predict_batch_size = predict_batch_size
-        self.drop_incomplete_batch = drop_incomplete_batch
-        self.multiproc_kwargs = multiproc_kwargs
-
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        # Instantiate other datamodules and get their datasets
-        module_args = {
-            "train_batch_size": self.train_batch_size,
-            "predict_batch_size": self.predict_batch_size,
-            "drop_incomplete_batch": self.drop_incomplete_batch,
-            "multiproc_kwargs": self.multiproc_kwargs,
-        }
-
-        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
-        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
-
-        # Combine datasets
-        self.dataset = {}
-        self.dataset["__train__"] = ConcatDataset(
-            [mc["__train__"], ch["__train__"]]
-        )
-        self.dataset["train"] = ConcatDataset([mc["train"], ch["train"]])
-        self.dataset["__valid__"] = ConcatDataset(
-            [mc["__valid__"], ch["__valid__"]]
-        )
-        self.dataset["validation"] = ConcatDataset(
-            [mc["validation"], ch["validation"]]
-        )
-        self.dataset["test"] = ConcatDataset([mc["test"], ch["test"]])
-
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = DefaultModule
diff --git a/src/ptbench/data/mc_ch/fold_3.py b/src/ptbench/data/mc_ch/fold_3.py
deleted file mode 100644
index 3c86f11c..00000000
--- a/src/ptbench/data/mc_ch/fold_3.py
+++ /dev/null
@@ -1,73 +0,0 @@
-# Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-"""Aggregated dataset composed of Montgomery and Shenzhen datasets (cross
-validation fold 3)"""
-
-from clapper.logging import setup
-from torch.utils.data.dataset import ConcatDataset
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule, get_dataset_from_module
-from ..montgomery.fold_3 import datamodule as mc_datamodule
-from ..shenzhen.fold_3 import datamodule as ch_datamodule
-
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        self.train_batch_size = train_batch_size
-        self.predict_batch_size = predict_batch_size
-        self.drop_incomplete_batch = drop_incomplete_batch
-        self.multiproc_kwargs = multiproc_kwargs
-
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        # Instantiate other datamodules and get their datasets
-        module_args = {
-            "train_batch_size": self.train_batch_size,
-            "predict_batch_size": self.predict_batch_size,
-            "drop_incomplete_batch": self.drop_incomplete_batch,
-            "multiproc_kwargs": self.multiproc_kwargs,
-        }
-
-        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
-        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
-
-        # Combine datasets
-        self.dataset = {}
-        self.dataset["__train__"] = ConcatDataset(
-            [mc["__train__"], ch["__train__"]]
-        )
-        self.dataset["train"] = ConcatDataset([mc["train"], ch["train"]])
-        self.dataset["__valid__"] = ConcatDataset(
-            [mc["__valid__"], ch["__valid__"]]
-        )
-        self.dataset["validation"] = ConcatDataset(
-            [mc["validation"], ch["validation"]]
-        )
-        self.dataset["test"] = ConcatDataset([mc["test"], ch["test"]])
-
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = DefaultModule
diff --git a/src/ptbench/data/mc_ch/fold_4.py b/src/ptbench/data/mc_ch/fold_4.py
deleted file mode 100644
index c0e08532..00000000
--- a/src/ptbench/data/mc_ch/fold_4.py
+++ /dev/null
@@ -1,73 +0,0 @@
-# Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-"""Aggregated dataset composed of Montgomery and Shenzhen datasets (cross
-validation fold 4)"""
-
-from clapper.logging import setup
-from torch.utils.data.dataset import ConcatDataset
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule, get_dataset_from_module
-from ..montgomery.fold_4 import datamodule as mc_datamodule
-from ..shenzhen.fold_4 import datamodule as ch_datamodule
-
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        self.train_batch_size = train_batch_size
-        self.predict_batch_size = predict_batch_size
-        self.drop_incomplete_batch = drop_incomplete_batch
-        self.multiproc_kwargs = multiproc_kwargs
-
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        # Instantiate other datamodules and get their datasets
-        module_args = {
-            "train_batch_size": self.train_batch_size,
-            "predict_batch_size": self.predict_batch_size,
-            "drop_incomplete_batch": self.drop_incomplete_batch,
-            "multiproc_kwargs": self.multiproc_kwargs,
-        }
-
-        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
-        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
-
-        # Combine datasets
-        self.dataset = {}
-        self.dataset["__train__"] = ConcatDataset(
-            [mc["__train__"], ch["__train__"]]
-        )
-        self.dataset["train"] = ConcatDataset([mc["train"], ch["train"]])
-        self.dataset["__valid__"] = ConcatDataset(
-            [mc["__valid__"], ch["__valid__"]]
-        )
-        self.dataset["validation"] = ConcatDataset(
-            [mc["validation"], ch["validation"]]
-        )
-        self.dataset["test"] = ConcatDataset([mc["test"], ch["test"]])
-
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = DefaultModule
diff --git a/src/ptbench/data/mc_ch/fold_5.py b/src/ptbench/data/mc_ch/fold_5.py
deleted file mode 100644
index ca205eb3..00000000
--- a/src/ptbench/data/mc_ch/fold_5.py
+++ /dev/null
@@ -1,73 +0,0 @@
-# Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-"""Aggregated dataset composed of Montgomery and Shenzhen datasets (cross
-validation fold 5)"""
-
-from clapper.logging import setup
-from torch.utils.data.dataset import ConcatDataset
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule, get_dataset_from_module
-from ..montgomery.fold_5 import datamodule as mc_datamodule
-from ..shenzhen.fold_5 import datamodule as ch_datamodule
-
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        self.train_batch_size = train_batch_size
-        self.predict_batch_size = predict_batch_size
-        self.drop_incomplete_batch = drop_incomplete_batch
-        self.multiproc_kwargs = multiproc_kwargs
-
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        # Instantiate other datamodules and get their datasets
-        module_args = {
-            "train_batch_size": self.train_batch_size,
-            "predict_batch_size": self.predict_batch_size,
-            "drop_incomplete_batch": self.drop_incomplete_batch,
-            "multiproc_kwargs": self.multiproc_kwargs,
-        }
-
-        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
-        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
-
-        # Combine datasets
-        self.dataset = {}
-        self.dataset["__train__"] = ConcatDataset(
-            [mc["__train__"], ch["__train__"]]
-        )
-        self.dataset["train"] = ConcatDataset([mc["train"], ch["train"]])
-        self.dataset["__valid__"] = ConcatDataset(
-            [mc["__valid__"], ch["__valid__"]]
-        )
-        self.dataset["validation"] = ConcatDataset(
-            [mc["validation"], ch["validation"]]
-        )
-        self.dataset["test"] = ConcatDataset([mc["test"], ch["test"]])
-
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = DefaultModule
diff --git a/src/ptbench/data/mc_ch/fold_6.py b/src/ptbench/data/mc_ch/fold_6.py
deleted file mode 100644
index 42a6bb58..00000000
--- a/src/ptbench/data/mc_ch/fold_6.py
+++ /dev/null
@@ -1,73 +0,0 @@
-# Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-"""Aggregated dataset composed of Montgomery and Shenzhen datasets (cross
-validation fold 6)"""
-
-from clapper.logging import setup
-from torch.utils.data.dataset import ConcatDataset
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule, get_dataset_from_module
-from ..montgomery.fold_6 import datamodule as mc_datamodule
-from ..shenzhen.fold_6 import datamodule as ch_datamodule
-
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        self.train_batch_size = train_batch_size
-        self.predict_batch_size = predict_batch_size
-        self.drop_incomplete_batch = drop_incomplete_batch
-        self.multiproc_kwargs = multiproc_kwargs
-
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        # Instantiate other datamodules and get their datasets
-        module_args = {
-            "train_batch_size": self.train_batch_size,
-            "predict_batch_size": self.predict_batch_size,
-            "drop_incomplete_batch": self.drop_incomplete_batch,
-            "multiproc_kwargs": self.multiproc_kwargs,
-        }
-
-        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
-        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
-
-        # Combine datasets
-        self.dataset = {}
-        self.dataset["__train__"] = ConcatDataset(
-            [mc["__train__"], ch["__train__"]]
-        )
-        self.dataset["train"] = ConcatDataset([mc["train"], ch["train"]])
-        self.dataset["__valid__"] = ConcatDataset(
-            [mc["__valid__"], ch["__valid__"]]
-        )
-        self.dataset["validation"] = ConcatDataset(
-            [mc["validation"], ch["validation"]]
-        )
-        self.dataset["test"] = ConcatDataset([mc["test"], ch["test"]])
-
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = DefaultModule
diff --git a/src/ptbench/data/mc_ch/fold_7.py b/src/ptbench/data/mc_ch/fold_7.py
deleted file mode 100644
index 082c0267..00000000
--- a/src/ptbench/data/mc_ch/fold_7.py
+++ /dev/null
@@ -1,73 +0,0 @@
-# Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-"""Aggregated dataset composed of Montgomery and Shenzhen datasets (cross
-validation fold 7)"""
-
-from clapper.logging import setup
-from torch.utils.data.dataset import ConcatDataset
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule, get_dataset_from_module
-from ..montgomery.fold_7 import datamodule as mc_datamodule
-from ..shenzhen.fold_7 import datamodule as ch_datamodule
-
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        self.train_batch_size = train_batch_size
-        self.predict_batch_size = predict_batch_size
-        self.drop_incomplete_batch = drop_incomplete_batch
-        self.multiproc_kwargs = multiproc_kwargs
-
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        # Instantiate other datamodules and get their datasets
-        module_args = {
-            "train_batch_size": self.train_batch_size,
-            "predict_batch_size": self.predict_batch_size,
-            "drop_incomplete_batch": self.drop_incomplete_batch,
-            "multiproc_kwargs": self.multiproc_kwargs,
-        }
-
-        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
-        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
-
-        # Combine datasets
-        self.dataset = {}
-        self.dataset["__train__"] = ConcatDataset(
-            [mc["__train__"], ch["__train__"]]
-        )
-        self.dataset["train"] = ConcatDataset([mc["train"], ch["train"]])
-        self.dataset["__valid__"] = ConcatDataset(
-            [mc["__valid__"], ch["__valid__"]]
-        )
-        self.dataset["validation"] = ConcatDataset(
-            [mc["validation"], ch["validation"]]
-        )
-        self.dataset["test"] = ConcatDataset([mc["test"], ch["test"]])
-
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = DefaultModule
diff --git a/src/ptbench/data/mc_ch/fold_8.py b/src/ptbench/data/mc_ch/fold_8.py
deleted file mode 100644
index 814fd9a4..00000000
--- a/src/ptbench/data/mc_ch/fold_8.py
+++ /dev/null
@@ -1,73 +0,0 @@
-# Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-"""Aggregated dataset composed of Montgomery and Shenzhen datasets (cross
-validation fold 8)"""
-
-from clapper.logging import setup
-from torch.utils.data.dataset import ConcatDataset
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule, get_dataset_from_module
-from ..montgomery.fold_8 import datamodule as mc_datamodule
-from ..shenzhen.fold_8 import datamodule as ch_datamodule
-
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        self.train_batch_size = train_batch_size
-        self.predict_batch_size = predict_batch_size
-        self.drop_incomplete_batch = drop_incomplete_batch
-        self.multiproc_kwargs = multiproc_kwargs
-
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        # Instantiate other datamodules and get their datasets
-        module_args = {
-            "train_batch_size": self.train_batch_size,
-            "predict_batch_size": self.predict_batch_size,
-            "drop_incomplete_batch": self.drop_incomplete_batch,
-            "multiproc_kwargs": self.multiproc_kwargs,
-        }
-
-        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
-        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
-
-        # Combine datasets
-        self.dataset = {}
-        self.dataset["__train__"] = ConcatDataset(
-            [mc["__train__"], ch["__train__"]]
-        )
-        self.dataset["train"] = ConcatDataset([mc["train"], ch["train"]])
-        self.dataset["__valid__"] = ConcatDataset(
-            [mc["__valid__"], ch["__valid__"]]
-        )
-        self.dataset["validation"] = ConcatDataset(
-            [mc["validation"], ch["validation"]]
-        )
-        self.dataset["test"] = ConcatDataset([mc["test"], ch["test"]])
-
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = DefaultModule
diff --git a/src/ptbench/data/mc_ch/fold_9.py b/src/ptbench/data/mc_ch/fold_9.py
deleted file mode 100644
index a1d564e6..00000000
--- a/src/ptbench/data/mc_ch/fold_9.py
+++ /dev/null
@@ -1,73 +0,0 @@
-# Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-"""Aggregated dataset composed of Montgomery and Shenzhen datasets (cross
-validation fold 9)"""
-
-from clapper.logging import setup
-from torch.utils.data.dataset import ConcatDataset
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule, get_dataset_from_module
-from ..montgomery.fold_9 import datamodule as mc_datamodule
-from ..shenzhen.fold_9 import datamodule as ch_datamodule
-
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        self.train_batch_size = train_batch_size
-        self.predict_batch_size = predict_batch_size
-        self.drop_incomplete_batch = drop_incomplete_batch
-        self.multiproc_kwargs = multiproc_kwargs
-
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        # Instantiate other datamodules and get their datasets
-        module_args = {
-            "train_batch_size": self.train_batch_size,
-            "predict_batch_size": self.predict_batch_size,
-            "drop_incomplete_batch": self.drop_incomplete_batch,
-            "multiproc_kwargs": self.multiproc_kwargs,
-        }
-
-        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
-        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
-
-        # Combine datasets
-        self.dataset = {}
-        self.dataset["__train__"] = ConcatDataset(
-            [mc["__train__"], ch["__train__"]]
-        )
-        self.dataset["train"] = ConcatDataset([mc["train"], ch["train"]])
-        self.dataset["__valid__"] = ConcatDataset(
-            [mc["__valid__"], ch["__valid__"]]
-        )
-        self.dataset["validation"] = ConcatDataset(
-            [mc["validation"], ch["validation"]]
-        )
-        self.dataset["test"] = ConcatDataset([mc["test"], ch["test"]])
-
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = DefaultModule
diff --git a/src/ptbench/data/montgomery_shenzhen/__init__.py b/src/ptbench/data/montgomery_shenzhen/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/ptbench/data/montgomery_shenzhen/datamodule.py b/src/ptbench/data/montgomery_shenzhen/datamodule.py
new file mode 100644
index 00000000..35ba9b9a
--- /dev/null
+++ b/src/ptbench/data/montgomery_shenzhen/datamodule.py
@@ -0,0 +1,36 @@
+# Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+from ..datamodule import ConcatDataModule
+from ..montgomery.datamodule import RawDataLoader as MontgomeryLoader
+from ..montgomery.datamodule import make_split as make_montgomery_split
+from ..shenzhen.datamodule import RawDataLoader as ShenzhenLoader
+from ..shenzhen.datamodule import make_split as make_shenzhen_split
+
+
+class DataModule(ConcatDataModule):
+    """Aggregated datamodule composed of Montgomery and Shenzhen datasets."""
+
+    def __init__(self, split_filename: str):
+        montgomery_loader = MontgomeryLoader()
+        montgomery_split = make_montgomery_split("default.json")
+        shenzen_loader = ShenzhenLoader()
+        shenzen_split = make_shenzhen_split("default.json")
+
+        super().__init__(
+            splits={
+                "train": [
+                    (montgomery_split["train"], montgomery_loader),
+                    (shenzen_split["train"], shenzen_loader),
+                ],
+                "validation": [
+                    (montgomery_split["validation"], montgomery_loader),
+                    (shenzen_split["validation"], shenzen_loader),
+                ],
+                "test": [
+                    (montgomery_split["test"], montgomery_loader),
+                    (shenzen_split["test"], shenzen_loader),
+                ],
+            }
+        )
diff --git a/src/ptbench/data/montgomery_shenzhen/default.py b/src/ptbench/data/montgomery_shenzhen/default.py
new file mode 100644
index 00000000..2b8a8fb2
--- /dev/null
+++ b/src/ptbench/data/montgomery_shenzhen/default.py
@@ -0,0 +1,7 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+from .datamodule import DataModule
+
+datamodule = DataModule("default.json")
diff --git a/src/ptbench/data/montgomery_shenzhen/fold_0.py b/src/ptbench/data/montgomery_shenzhen/fold_0.py
new file mode 100644
index 00000000..3d114d07
--- /dev/null
+++ b/src/ptbench/data/montgomery_shenzhen/fold_0.py
@@ -0,0 +1,7 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+from .datamodule import DataModule
+
+datamodule = DataModule("fold_0.json")
diff --git a/src/ptbench/data/montgomery_shenzhen/fold_1.py b/src/ptbench/data/montgomery_shenzhen/fold_1.py
new file mode 100644
index 00000000..cd3a8cb6
--- /dev/null
+++ b/src/ptbench/data/montgomery_shenzhen/fold_1.py
@@ -0,0 +1,7 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+from .datamodule import DataModule
+
+datamodule = DataModule("fold_1.json")
diff --git a/src/ptbench/data/montgomery_shenzhen/fold_2.py b/src/ptbench/data/montgomery_shenzhen/fold_2.py
new file mode 100644
index 00000000..44eeda80
--- /dev/null
+++ b/src/ptbench/data/montgomery_shenzhen/fold_2.py
@@ -0,0 +1,7 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+from .datamodule import DataModule
+
+datamodule = DataModule("fold_2.json")
diff --git a/src/ptbench/data/montgomery_shenzhen/fold_3.py b/src/ptbench/data/montgomery_shenzhen/fold_3.py
new file mode 100644
index 00000000..f24fb314
--- /dev/null
+++ b/src/ptbench/data/montgomery_shenzhen/fold_3.py
@@ -0,0 +1,7 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+from .datamodule import DataModule
+
+datamodule = DataModule("fold_3.json")
diff --git a/src/ptbench/data/montgomery_shenzhen/fold_4.py b/src/ptbench/data/montgomery_shenzhen/fold_4.py
new file mode 100644
index 00000000..58456d38
--- /dev/null
+++ b/src/ptbench/data/montgomery_shenzhen/fold_4.py
@@ -0,0 +1,7 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+from .datamodule import DataModule
+
+datamodule = DataModule("fold_4.json")
diff --git a/src/ptbench/data/montgomery_shenzhen/fold_5.py b/src/ptbench/data/montgomery_shenzhen/fold_5.py
new file mode 100644
index 00000000..92796746
--- /dev/null
+++ b/src/ptbench/data/montgomery_shenzhen/fold_5.py
@@ -0,0 +1,7 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+from .datamodule import DataModule
+
+datamodule = DataModule("fold_5.json")
diff --git a/src/ptbench/data/montgomery_shenzhen/fold_6.py b/src/ptbench/data/montgomery_shenzhen/fold_6.py
new file mode 100644
index 00000000..9566b7cf
--- /dev/null
+++ b/src/ptbench/data/montgomery_shenzhen/fold_6.py
@@ -0,0 +1,7 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+from .datamodule import DataModule
+
+datamodule = DataModule("fold_6.json")
diff --git a/src/ptbench/data/montgomery_shenzhen/fold_7.py b/src/ptbench/data/montgomery_shenzhen/fold_7.py
new file mode 100644
index 00000000..25cbfe1b
--- /dev/null
+++ b/src/ptbench/data/montgomery_shenzhen/fold_7.py
@@ -0,0 +1,7 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+from .datamodule import DataModule
+
+datamodule = DataModule("fold_7.json")
diff --git a/src/ptbench/data/montgomery_shenzhen/fold_8.py b/src/ptbench/data/montgomery_shenzhen/fold_8.py
new file mode 100644
index 00000000..fb5332ce
--- /dev/null
+++ b/src/ptbench/data/montgomery_shenzhen/fold_8.py
@@ -0,0 +1,7 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+from .datamodule import DataModule
+
+datamodule = DataModule("fold_8.json")
diff --git a/src/ptbench/data/montgomery_shenzhen/fold_9.py b/src/ptbench/data/montgomery_shenzhen/fold_9.py
new file mode 100644
index 00000000..d1626586
--- /dev/null
+++ b/src/ptbench/data/montgomery_shenzhen/fold_9.py
@@ -0,0 +1,7 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+from .datamodule import DataModule
+
+datamodule = DataModule("fold_9.json")
-- 
GitLab