From 0632bb724404585e87bec878a3a2985a79c129ff Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Wed, 2 Aug 2023 11:37:02 +0200
Subject: [PATCH] [data.nih_cxr14_padchest] Reimplements aggregated database

---
 pyproject.toml                                | 12 +--
 .../data/nih_cxr14_padchest/__init__.py       |  0
 .../data/nih_cxr14_padchest/datamodule.py     | 37 +++++++++
 .../idiap.py}                                 |  4 +
 src/ptbench/data/nih_cxr14_re_pc/idiap.py     | 76 -------------------
 5 files changed, 47 insertions(+), 82 deletions(-)
 create mode 100644 src/ptbench/data/nih_cxr14_padchest/__init__.py
 create mode 100644 src/ptbench/data/nih_cxr14_padchest/datamodule.py
 rename src/ptbench/data/{nih_cxr14_re_pc/__init__.py => nih_cxr14_padchest/idiap.py} (56%)
 delete mode 100644 src/ptbench/data/nih_cxr14_re_pc/idiap.py

diff --git a/pyproject.toml b/pyproject.toml
index 23a5908c..915717ae 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -193,7 +193,7 @@ montgomery-shenzhen-indian-tbx11k-v2-f7 = "ptbench.data.montgomery_shenzhen_indi
 montgomery-shenzhen-indian-tbx11k-v2-f8 = "ptbench.data.montgomery_shenzhen_indian_tbx11k.v2_fold_8"
 montgomery-shenzhen-indian-tbx11k-v2-f9 = "ptbench.data.montgomery_shenzhen_indian_tbx11k.v2_fold_9"
 
-# tbpoc dataset (and cross-validation folds)
+# tbpoc dataset (only cross-validation folds)
 tbpoc_f0 = "ptbench.data.tbpoc.fold_0"
 tbpoc_f1 = "ptbench.data.tbpoc.fold_1"
 tbpoc_f2 = "ptbench.data.tbpoc.fold_2"
@@ -205,7 +205,7 @@ tbpoc_f7 = "ptbench.data.tbpoc.fold_7"
 tbpoc_f8 = "ptbench.data.tbpoc.fold_8"
 tbpoc_f9 = "ptbench.data.tbpoc.fold_9"
 
-# hivtb dataset (and cross-validation folds)
+# hivtb dataset (only cross-validation folds)
 hivtb_f0 = "ptbench.data.hivtb.fold_0"
 hivtb_f1 = "ptbench.data.hivtb.fold_1"
 hivtb_f2 = "ptbench.data.hivtb.fold_2"
@@ -217,9 +217,6 @@ hivtb_f7 = "ptbench.data.hivtb.fold_7"
 hivtb_f8 = "ptbench.data.hivtb.fold_8"
 hivtb_f9 = "ptbench.data.hivtb.fold_9"
 
-# montgomery-shenzhen-indian-padchest aggregated dataset
-mc_ch_in_pc = "ptbench.data.mc_ch_in_pc.default"
-
 # NIH CXR14 (relabeled), multi-class (14 labels)
 nih-cxr14 = "ptbench.data.nih_cxr14.default"
 nih-cxr14-cardiomegaly = "ptbench.data.nih_cxr14.cardiomegaly"
@@ -231,7 +228,10 @@ padchest-no-tb-idiap = "ptbench.data.padchest.no_tb_idiap"
 padchest-cardiomegaly-idiap = "ptbench.data.padchest.cardiomegaly_idiap"
 
 # NIH CXR14 / PadChest aggregated dataset
-nih_cxr14_pc_idiap = "ptbench.data.nih_cxr14_re_pc.idiap"
+nih-cxr14-padchest = "ptbench.data.nih_cxr14_padchest.idiap"
+
+# montgomery-shenzhen-indian-padchest aggregated dataset
+mc_ch_in_pc = "ptbench.data.mc_ch_in_pc.default"
 
 [tool.setuptools]
 zip-safe = true
diff --git a/src/ptbench/data/nih_cxr14_padchest/__init__.py b/src/ptbench/data/nih_cxr14_padchest/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/ptbench/data/nih_cxr14_padchest/datamodule.py b/src/ptbench/data/nih_cxr14_padchest/datamodule.py
new file mode 100644
index 00000000..335679bf
--- /dev/null
+++ b/src/ptbench/data/nih_cxr14_padchest/datamodule.py
@@ -0,0 +1,37 @@
+# SPDX-FileCopyrightText: Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+from ..datamodule import ConcatDataModule
+from ..nih_cxr14.datamodule import RawDataLoader as CXR14Loader
+from ..nih_cxr14.datamodule import make_split as make_cxr14_split
+from ..padchest.datamodule import RawDataLoader as PadchestLoader
+from ..padchest.datamodule import make_split as make_padchest_split
+
+
+class DataModule(ConcatDataModule):
+    """Aggregated dataset composed of NIH CXR14 relabeled and PadChest
+    (normalized) datasets.
+
+    Both arguments are split filenames, forwarded to the ``make_split``
+    function of the NIH CXR14 and PadChest packages, respectively.
+    """
+
+    def __init__(
+        self, cxr14_split_filename: str, padchest_split_filename: str
+    ):
+        cxr14_loader = CXR14Loader()
+        cxr14_split = make_cxr14_split(cxr14_split_filename)
+        padchest_loader = PadchestLoader()
+        padchest_split = make_padchest_split(padchest_split_filename)
+
+        # Every subset lists both databases, each paired with its own loader.
+        super().__init__(
+            splits={
+                subset: [
+                    (cxr14_split[subset], cxr14_loader),
+                    (padchest_split[subset], padchest_loader),
+                ]
+                for subset in ("train", "validation", "test")
+            }
+        )
diff --git a/src/ptbench/data/nih_cxr14_re_pc/__init__.py b/src/ptbench/data/nih_cxr14_padchest/idiap.py
similarity index 56%
rename from src/ptbench/data/nih_cxr14_re_pc/__init__.py
rename to src/ptbench/data/nih_cxr14_padchest/idiap.py
index 84b9088e..45754ea6 100644
--- a/src/ptbench/data/nih_cxr14_re_pc/__init__.py
+++ b/src/ptbench/data/nih_cxr14_padchest/idiap.py
@@ -1,3 +1,7 @@
 # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
+"""Aggregated NIH CXR14 (default split) and PadChest (no-tb-idiap split)."""
+from .datamodule import DataModule
+
+datamodule = DataModule("default.json.bz2", "no-tb-idiap.json.bz2")
diff --git a/src/ptbench/data/nih_cxr14_re_pc/idiap.py b/src/ptbench/data/nih_cxr14_re_pc/idiap.py
deleted file mode 100644
index 72a9c466..00000000
--- a/src/ptbench/data/nih_cxr14_re_pc/idiap.py
+++ /dev/null
@@ -1,76 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-"""Aggregated dataset composed of NIH CXR14 relabeld and PadChest (normalized)
-datasets."""
-
-from clapper.logging import setup
-from torch.utils.data.dataset import ConcatDataset
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule, get_dataset_from_module
-from ..nih_cxr14_re.default import datamodule as nih_cxr14_re_datamodule
-from ..padchest.no_tb_idiap import datamodule as padchest_no_tb_datamodule
-
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        self.train_batch_size = train_batch_size
-        self.predict_batch_size = predict_batch_size
-        self.drop_incomplete_batch = drop_incomplete_batch
-        self.multiproc_kwargs = multiproc_kwargs
-
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        module_args = {
-            "train_batch_size": self.train_batch_size,
-            "predict_batch_size": self.predict_batch_size,
-            "drop_incomplete_batch": self.drop_incomplete_batch,
-            "multiproc_kwargs": self.multiproc_kwargs,
-        }
-
-        nih_cxr14_re = get_dataset_from_module(
-            nih_cxr14_re_datamodule, stage, **module_args
-        )
-        padchest_no_tb = get_dataset_from_module(
-            padchest_no_tb_datamodule, stage, **module_args
-        )
-
-        self.dataset = {}
-        self.dataset["__train__"] = ConcatDataset(
-            [nih_cxr14_re["__train__"], padchest_no_tb["__train__"]]
-        )
-        self.dataset["train"] = ConcatDataset(
-            [nih_cxr14_re["train"], padchest_no_tb["train"]]
-        )
-        self.dataset["__valid__"] = ConcatDataset(
-            [nih_cxr14_re["__valid__"], padchest_no_tb["__valid__"]]
-        )
-        self.dataset["validation"] = ConcatDataset(
-            [nih_cxr14_re["validation"], padchest_no_tb["validation"]]
-        )
-        self.dataset["test"] = nih_cxr14_re["test"]
-
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = DefaultModule
-- 
GitLab