From dd4dd5d1bf072f9adb880ca39fb2f83c0f00610d Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Tue, 6 Jun 2023 15:29:18 +0200
Subject: [PATCH] Moved nih_cxr14_re configs to data

---
 pyproject.toml                                |  6 +--
 .../configs/datasets/nih_cxr14_re/__init__.py | 22 ---------
 .../datasets/nih_cxr14_re/cardiomegaly.py     | 16 -------
 .../configs/datasets/nih_cxr14_re/default.py  | 15 ------
 src/ptbench/data/nih_cxr14_re/__init__.py     | 20 +++++++-
 src/ptbench/data/nih_cxr14_re/cardiomegaly.py | 48 +++++++++++++++++++
 src/ptbench/data/nih_cxr14_re/default.py      | 47 ++++++++++++++++++
 7 files changed, 117 insertions(+), 57 deletions(-)
 delete mode 100644 src/ptbench/configs/datasets/nih_cxr14_re/__init__.py
 delete mode 100644 src/ptbench/configs/datasets/nih_cxr14_re/cardiomegaly.py
 delete mode 100644 src/ptbench/configs/datasets/nih_cxr14_re/default.py
 create mode 100644 src/ptbench/data/nih_cxr14_re/cardiomegaly.py
 create mode 100644 src/ptbench/data/nih_cxr14_re/default.py

diff --git a/pyproject.toml b/pyproject.toml
index a9b1abf1..e43899d2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -471,10 +471,10 @@ mc_ch_in_pc_rgb = "ptbench.data.mc_ch_in_pc.rgb"
 # (with radiological signs)
 mc_ch_in_pc_rs = "ptbench.configs.datasets.mc_ch_in_pc_RS.default"
 # NIH CXR14 (relabeled)
-nih_cxr14 = "ptbench.configs.datasets.nih_cxr14_re.default"
-nih_cxr14_cm = "ptbench.configs.datasets.nih_cxr14_re.cardiomegaly"
+nih_cxr14 = "ptbench.data.nih_cxr14_re.default"
+nih_cxr14_cm = "ptbench.data.nih_cxr14_re.cardiomegaly"
 # NIH CXR14 / PadChest aggregated dataset
-nih_cxr14_pc_idiap = "ptbench.configs.datasets.nih_cxr14_re_pc.idiap"
+nih_cxr14_pc_idiap = "ptbench.data.nih_cxr14_re_pc.idiap"
 # PadChest
 padchest_idiap = "ptbench.data.padchest.idiap"
 padchest_tb_idiap = "ptbench.data.padchest.tb_idiap"
diff --git a/src/ptbench/configs/datasets/nih_cxr14_re/__init__.py b/src/ptbench/configs/datasets/nih_cxr14_re/__init__.py
deleted file mode 100644
index d0e7117c..00000000
--- a/src/ptbench/configs/datasets/nih_cxr14_re/__init__.py
+++ /dev/null
@@ -1,22 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-
-def _maker(protocol, size=512):
-    import torchvision.transforms as transforms
-
-    from ....data.nih_cxr14_re import dataset as raw
-    from .. import make_dataset as mk
-
-    # ImageNet normalization
-    normalize = transforms.Normalize(
-        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
-    )
-
-    return mk(
-        [raw.subsets(protocol)],
-        [transforms.Resize((size, size))],
-        [transforms.RandomHorizontalFlip()],
-        [transforms.ToTensor(), normalize],
-    )
diff --git a/src/ptbench/configs/datasets/nih_cxr14_re/cardiomegaly.py b/src/ptbench/configs/datasets/nih_cxr14_re/cardiomegaly.py
deleted file mode 100644
index 0a63a6fd..00000000
--- a/src/ptbench/configs/datasets/nih_cxr14_re/cardiomegaly.py
+++ /dev/null
@@ -1,16 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-"""NIH CXR14 dataset for computer-aided diagnosis.
-
-First 40 images with cardiomegaly.
-
-* See :py:mod:`ptbench.data.nih_cxr14_re` for split details
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.nih_cxr14_re` for dataset details
-"""
-
-from . import _maker
-
-dataset = _maker("cardiomegaly")
diff --git a/src/ptbench/configs/datasets/nih_cxr14_re/default.py b/src/ptbench/configs/datasets/nih_cxr14_re/default.py
deleted file mode 100644
index c1e472d0..00000000
--- a/src/ptbench/configs/datasets/nih_cxr14_re/default.py
+++ /dev/null
@@ -1,15 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-"""NIH CXR14 (relabeled) dataset for computer-aided diagnosis (default
-protocol)
-
-* See :py:mod:`ptbench.data.nih_cxr14_re` for split details
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.nih_cxr14_re` for dataset details
-"""
-
-from . import _maker
-
-dataset = _maker("default")
diff --git a/src/ptbench/data/nih_cxr14_re/__init__.py b/src/ptbench/data/nih_cxr14_re/__init__.py
index 8dc3f223..27d1903c 100644
--- a/src/ptbench/data/nih_cxr14_re/__init__.py
+++ b/src/ptbench/data/nih_cxr14_re/__init__.py
@@ -72,9 +72,27 @@ def _loader(context, sample):
     return make_delayed(sample, _raw_data_loader)
 
 
-dataset = JSONDataset(
+json_dataset = JSONDataset(
     protocols=_protocols,
     fieldnames=("data", "label"),
     loader=_loader,
 )
 """NIH CXR14 (relabeled) dataset object."""
+
+
+def _maker(protocol, size=512):
+    """Assemble the ``protocol`` subsets with ``size`` x ``size`` resizing."""
+    import torchvision.transforms as transforms
+
+    from .. import make_dataset
+
+    # ImageNet normalization (per-channel mean and standard deviation)
+    normalize = transforms.Normalize(
+        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+    )
+    return make_dataset(
+        [json_dataset.subsets(protocol)],
+        [transforms.Resize((size, size))],
+        [transforms.RandomHorizontalFlip()],
+        [transforms.ToTensor(), normalize],
+    )
diff --git a/src/ptbench/data/nih_cxr14_re/cardiomegaly.py b/src/ptbench/data/nih_cxr14_re/cardiomegaly.py
new file mode 100644
index 00000000..1904ebfa
--- /dev/null
+++ b/src/ptbench/data/nih_cxr14_re/cardiomegaly.py
@@ -0,0 +1,58 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+"""NIH CXR14 dataset for computer-aided diagnosis.
+
+First 40 images with cardiomegaly.
+
+* See :py:mod:`ptbench.data.nih_cxr14_re` for split details
+* This configuration resolution: 512 x 512 (default)
+* See :py:mod:`ptbench.data.nih_cxr14_re` for dataset details
+"""
+
+from clapper.logging import setup
+
+from .. import return_subsets
+from ..base_datamodule import BaseDataModule
+from . import _maker
+
+logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+
+
+class CardiomegalyModule(BaseDataModule):
+    """Data module serving the ``cardiomegaly`` protocol of NIH CXR14.
+
+    All constructor arguments are forwarded unchanged to
+    :py:class:`ptbench.data.base_datamodule.BaseDataModule`.
+    """
+
+    def __init__(
+        self,
+        train_batch_size=1,
+        predict_batch_size=1,
+        drop_incomplete_batch=False,
+        multiproc_kwargs=None,
+    ):
+        super().__init__(
+            train_batch_size=train_batch_size,
+            predict_batch_size=predict_batch_size,
+            drop_incomplete_batch=drop_incomplete_batch,
+            multiproc_kwargs=multiproc_kwargs,
+        )
+
+    def setup(self, stage: str):
+        """Create the dataset and unpack its subsets onto this instance.
+
+        ``stage`` is accepted but not used by this implementation.
+        """
+        self.dataset = _maker("cardiomegaly")
+        subsets = return_subsets(self.dataset)
+        train, validation, extra_validation, predict = subsets
+        self.train_dataset = train
+        self.validation_dataset = validation
+        self.extra_validation_datasets = extra_validation
+        self.predict_dataset = predict
+
+
+datamodule = CardiomegalyModule
diff --git a/src/ptbench/data/nih_cxr14_re/default.py b/src/ptbench/data/nih_cxr14_re/default.py
new file mode 100644
index 00000000..0ea6ef5a
--- /dev/null
+++ b/src/ptbench/data/nih_cxr14_re/default.py
@@ -0,0 +1,57 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+"""NIH CXR14 (relabeled) dataset for computer-aided diagnosis (default
+protocol)
+
+* See :py:mod:`ptbench.data.nih_cxr14_re` for split details
+* This configuration resolution: 512 x 512 (default)
+* See :py:mod:`ptbench.data.nih_cxr14_re` for dataset details
+"""
+
+from clapper.logging import setup
+
+from .. import return_subsets
+from ..base_datamodule import BaseDataModule
+from . import _maker
+
+logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+
+
+class DefaultModule(BaseDataModule):
+    """Data module serving the ``default`` protocol of NIH CXR14 (relabeled).
+
+    All constructor arguments are forwarded unchanged to
+    :py:class:`ptbench.data.base_datamodule.BaseDataModule`.
+    """
+
+    def __init__(
+        self,
+        train_batch_size=1,
+        predict_batch_size=1,
+        drop_incomplete_batch=False,
+        multiproc_kwargs=None,
+    ):
+        super().__init__(
+            train_batch_size=train_batch_size,
+            predict_batch_size=predict_batch_size,
+            drop_incomplete_batch=drop_incomplete_batch,
+            multiproc_kwargs=multiproc_kwargs,
+        )
+
+    def setup(self, stage: str):
+        """Create the dataset and unpack its subsets onto this instance.
+
+        ``stage`` is accepted but not used by this implementation.
+        """
+        self.dataset = _maker("default")
+        subsets = return_subsets(self.dataset)
+        train, validation, extra_validation, predict = subsets
+        self.train_dataset = train
+        self.validation_dataset = validation
+        self.extra_validation_datasets = extra_validation
+        self.predict_dataset = predict
+
+
+datamodule = DefaultModule
-- 
GitLab