From 37f89a95e0e481ea152d81596d99defe34c33f4e Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Wed, 2 Aug 2023 11:57:50 +0200
Subject: [PATCH] [data.montgomery_shenzhen_indian_padchest] Port to new
 lightning infrastructure

---
 pyproject.toml                                |  2 +-
 src/ptbench/data/mc_ch_in_pc/__init__.py      |  3 -
 src/ptbench/data/mc_ch_in_pc/default.py       | 90 -------------------
 .../__init__.py                               |  0
 .../datamodule.py                             | 50 +++++++++++
 .../default.py                                |  9 ++
 6 files changed, 60 insertions(+), 94 deletions(-)
 delete mode 100644 src/ptbench/data/mc_ch_in_pc/__init__.py
 delete mode 100644 src/ptbench/data/mc_ch_in_pc/default.py
 create mode 100644 src/ptbench/data/montgomery_shenzhen_indian_padchest/__init__.py
 create mode 100644 src/ptbench/data/montgomery_shenzhen_indian_padchest/datamodule.py
 create mode 100644 src/ptbench/data/montgomery_shenzhen_indian_padchest/default.py

diff --git a/pyproject.toml b/pyproject.toml
index 846cf96b..30531b7c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -231,7 +231,7 @@ padchest-cardiomegaly-idiap = "ptbench.data.padchest.cardiomegaly_idiap"
 nih-cxr14-padchest = "ptbench.data.nih_cxr14_padchest.idiap"
 
 # montgomery-shenzhen-indian-padchest aggregated dataset
-mc_ch_in_pc = "ptbench.data.mc_ch_in_pc.default"
+montgomery-shenzhen-indian-padchest = "ptbench.data.montgomery_shenzhen_indian_padchest.default"
 
 [tool.setuptools]
 zip-safe = true
diff --git a/src/ptbench/data/mc_ch_in_pc/__init__.py b/src/ptbench/data/mc_ch_in_pc/__init__.py
deleted file mode 100644
index 662d5c13..00000000
--- a/src/ptbench/data/mc_ch_in_pc/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-# Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
diff --git a/src/ptbench/data/mc_ch_in_pc/default.py b/src/ptbench/data/mc_ch_in_pc/default.py
deleted file mode 100644
index 4715b13d..00000000
--- a/src/ptbench/data/mc_ch_in_pc/default.py
+++ /dev/null
@@ -1,90 +0,0 @@
-# Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and Padchest
-datasets."""
-
-from clapper.logging import setup
-from torch.utils.data.dataset import ConcatDataset
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule, get_dataset_from_module
-from ..indian.default import datamodule as indian_datamodule
-from ..montgomery.default import datamodule as mc_datamodule
-from ..padchest.tb_idiap import datamodule as pc_datamodule
-from ..shenzhen.default import datamodule as ch_datamodule
-
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        self.train_batch_size = train_batch_size
-        self.predict_batch_size = predict_batch_size
-        self.drop_incomplete_batch = drop_incomplete_batch
-        self.multiproc_kwargs = multiproc_kwargs
-
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        # Instantiate other datamodules and get their datasets
-
-        module_args = {
-            "train_batch_size": self.train_batch_size,
-            "predict_batch_size": self.predict_batch_size,
-            "drop_incomplete_batch": self.drop_incomplete_batch,
-            "multiproc_kwargs": self.multiproc_kwargs,
-        }
-
-        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
-        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
-        indian = get_dataset_from_module(
-            indian_datamodule, stage, **module_args
-        )
-        pc = get_dataset_from_module(pc_datamodule, stage, **module_args)
-
-        # Combine datasets
-        self.dataset = {}
-        self.dataset["__train__"] = ConcatDataset(
-            [
-                mc["__train__"],
-                ch["__train__"],
-                indian["__train__"],
-                pc["__train__"],
-            ]
-        )
-        self.dataset["train"] = ConcatDataset(
-            [mc["train"], ch["train"], indian["train"], pc["train"]]
-        )
-        self.dataset["__valid__"] = ConcatDataset(
-            [
-                mc["__valid__"],
-                ch["__valid__"],
-                indian["__valid__"],
-                pc["__valid__"],
-            ]
-        )
-        self.dataset["test"] = ConcatDataset(
-            [mc["test"], ch["test"], indian["test"], pc["test"]]
-        )
-
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = DefaultModule
diff --git a/src/ptbench/data/montgomery_shenzhen_indian_padchest/__init__.py b/src/ptbench/data/montgomery_shenzhen_indian_padchest/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/ptbench/data/montgomery_shenzhen_indian_padchest/datamodule.py b/src/ptbench/data/montgomery_shenzhen_indian_padchest/datamodule.py
new file mode 100644
index 00000000..2fdcfc67
--- /dev/null
+++ b/src/ptbench/data/montgomery_shenzhen_indian_padchest/datamodule.py
@@ -0,0 +1,50 @@
+# Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+from ..datamodule import ConcatDataModule
+from ..indian.datamodule import RawDataLoader as IndianLoader
+from ..indian.datamodule import make_split as make_indian_split
+from ..montgomery.datamodule import RawDataLoader as MontgomeryLoader
+from ..montgomery.datamodule import make_split as make_montgomery_split
+from ..padchest.datamodule import RawDataLoader as PadchestLoader
+from ..padchest.datamodule import make_split as make_padchest_split
+from ..shenzhen.datamodule import RawDataLoader as ShenzhenLoader
+from ..shenzhen.datamodule import make_split as make_shenzhen_split
+
+
+class DataModule(ConcatDataModule):
+    """Aggregated datamodule composed of Montgomery and Shenzhen datasets."""
+
+    def __init__(self, split_filename: str, padchest_split_filename: str):
+        montgomery_loader = MontgomeryLoader()
+        montgomery_split = make_montgomery_split(split_filename)
+        shenzhen_loader = ShenzhenLoader()
+        shenzhen_split = make_shenzhen_split(split_filename)
+        indian_loader = IndianLoader()
+        indian_split = make_indian_split(split_filename)
+        padchest_loader = PadchestLoader()
+        padchest_split = make_padchest_split(padchest_split_filename)
+
+        super().__init__(
+            splits={
+                "train": [
+                    (montgomery_split["train"], montgomery_loader),
+                    (shenzhen_split["train"], shenzhen_loader),
+                    (indian_split["train"], indian_loader),
+                    (padchest_split["train"], padchest_loader),
+                ],
+                "validation": [
+                    (montgomery_split["validation"], montgomery_loader),
+                    (shenzhen_split["validation"], shenzhen_loader),
+                    (indian_split["validation"], indian_loader),
+                    (padchest_split["validation"], padchest_loader),
+                ],
+                "test": [
+                    (montgomery_split["test"], montgomery_loader),
+                    (shenzhen_split["test"], shenzhen_loader),
+                    (indian_split["test"], indian_loader),
+                    (padchest_split["test"], padchest_loader),
+                ],
+            }
+        )
diff --git a/src/ptbench/data/montgomery_shenzhen_indian_padchest/default.py b/src/ptbench/data/montgomery_shenzhen_indian_padchest/default.py
new file mode 100644
index 00000000..5368466a
--- /dev/null
+++ b/src/ptbench/data/montgomery_shenzhen_indian_padchest/default.py
@@ -0,0 +1,9 @@
+# Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+from .datamodule import DataModule
+
+datamodule = DataModule("default.json", "tb-idiap.json")
+"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and Padchest
+datasets."""
-- 
GitLab