From 7be6a4eeb05f6da43b780e22ecd1b253b87ef694 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Tue, 1 Aug 2023 09:07:46 +0200
Subject: [PATCH] [data.nih_cxr14_re] Update datamodule; Prepare framework for
 multi-class classification

---
 src/ptbench/data/datamodule.py                |  15 +-
 src/ptbench/data/nih_cxr14_re/__init__.py     |   1 -
 .../data/nih_cxr14_re/cardiomegaly.json       |  86 ++++++++++
 .../data/nih_cxr14_re/cardiomegaly.json.bz2   | Bin 392 -> 0 bytes
 src/ptbench/data/nih_cxr14_re/cardiomegaly.py |  45 +----
 src/ptbench/data/nih_cxr14_re/datamodule.py   | 157 ++++++++++++++++++
 src/ptbench/data/nih_cxr14_re/default.py      |  44 +----
 src/ptbench/data/typing.py                    |   4 +-
 8 files changed, 257 insertions(+), 95 deletions(-)
 create mode 100644 src/ptbench/data/nih_cxr14_re/cardiomegaly.json
 delete mode 100644 src/ptbench/data/nih_cxr14_re/cardiomegaly.json.bz2
 create mode 100644 src/ptbench/data/nih_cxr14_re/datamodule.py

diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py
index 2cbc4b84..9b5eab61 100644
--- a/src/ptbench/data/datamodule.py
+++ b/src/ptbench/data/datamodule.py
@@ -106,7 +106,7 @@ class _DelayedLoadingDataset(Dataset):
         sample_size_mb = _sample_size_bytes(first_sample) / (1024.0 * 1024.0)
         logger.info(f"Estimated sample size: {sample_size_mb:.1f} Mb")
 
-    def labels(self) -> list[int]:
+    def labels(self) -> list[int | list[int]]:
         """Returns the integer labels for all samples in the dataset."""
         return [self.loader.label(k) for k in self.raw_dataset]
 
@@ -223,7 +223,7 @@ class _CachedDataset(Dataset):
             f"{sample_size_mb:.1f} / {(len(self.data)*sample_size_mb):.1f} Mb"
         )
 
-    def labels(self) -> list[int]:
+    def labels(self) -> list[int | list[int]]:
         """Returns the integer labels for all samples in the dataset."""
         return [k[1]["label"] for k in self.data]
 
@@ -256,7 +256,7 @@ class _ConcatDataset(Dataset):
             for j in range(len(datasets[i]))
         ]
 
-    def labels(self) -> list[int]:
+    def labels(self) -> list[int | list[int]]:
         """Returns the integer labels for all samples in the dataset."""
         return list(itertools.chain(*[k.labels() for k in self._datasets]))
 
@@ -379,11 +379,11 @@ def _make_balanced_random_sampler(
                 for ds in dataset.datasets
                 for k in typing.cast(Dataset, ds).labels()
             ]
-            weights = _calculate_weights(targets)
+            weights = _calculate_weights(targets)  # type: ignore
         else:
             logger.warning(
                 f"Balancing samples **and** concatenated-datasets "
-                f"WITHOUT metadata targets (`{target}` not available)"
+                f"by using dataset totals as `{target}: int` is not true"
             )
             weights = [
                 k
@@ -403,10 +403,11 @@ def _make_balanced_random_sampler(
                 f"Balancing samples from dataset using metadata "
                 f"targets `{target}`"
             )
-            weights = _calculate_weights(dataset.labels())
+            weights = _calculate_weights(dataset.labels())  # type: ignore
         else:
             raise RuntimeError(
-                f"Cannot balance samples without metadata targets `{target}`"
+                f"Cannot balance samples with multiple class labels "
+                f"({target}: list[int]) or without metadata targets `{target}`"
             )
 
     return torch.utils.data.WeightedRandomSampler(
diff --git a/src/ptbench/data/nih_cxr14_re/__init__.py b/src/ptbench/data/nih_cxr14_re/__init__.py
index 27d1903c..b9954cf1 100644
--- a/src/ptbench/data/nih_cxr14_re/__init__.py
+++ b/src/ptbench/data/nih_cxr14_re/__init__.py
@@ -1,7 +1,6 @@
 # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
-
 """NIH CXR14 (relabeled) dataset for computer-aided diagnosis.
 
 This dataset was extracted from the clinical PACS database at the National
diff --git a/src/ptbench/data/nih_cxr14_re/cardiomegaly.json b/src/ptbench/data/nih_cxr14_re/cardiomegaly.json
new file mode 100644
index 00000000..b9af6ad7
--- /dev/null
+++ b/src/ptbench/data/nih_cxr14_re/cardiomegaly.json
@@ -0,0 +1,86 @@
+{
+  "train": [
+    ["images/00000001_000.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
+    ["images/00000001_001.png", [1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
+    ["images/00000001_002.png", [1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]],
+    ["images/00000007_000.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
+    ["images/00000010_000.png", [1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0]],
+    ["images/00000011_000.png", [1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
+    ["images/00000011_001.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]],
+    ["images/00000011_002.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]],
+    ["images/00000011_003.png", [1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]],
+    ["images/00000013_011.png", [1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0]],
+    ["images/00000013_014.png", [1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0]],
+    ["images/00000013_018.png", [1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0]],
+    ["images/00000013_022.png", [1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]],
+    ["images/00000013_024.png", [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]],
+    ["images/00000013_025.png", [1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]],
+    ["images/00000013_026.png", [1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]],
+    ["images/00000013_027.png", [1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1]],
+    ["images/00000013_028.png", [1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1]],
+    ["images/00000013_029.png", [1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1]],
+    ["images/00000013_030.png", [1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0]],
+    ["images/00000013_031.png", [1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0]],
+    ["images/00000013_032.png", [1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0]],
+    ["images/00000013_034.png", [1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]],
+    ["images/00000013_037.png", [1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]],
+    ["images/00000013_038.png", [1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]],
+    ["images/00000013_040.png", [1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0]],
+    ["images/00000013_041.png", [1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0]],
+    ["images/00000013_043.png", [1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0]],
+    ["images/00000013_044.png", [1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0]],
+    ["images/00000013_045.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0]],
+    ["images/00000013_046.png", [1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
+    ["images/00000031_000.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
+    ["images/00000033_000.png", [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0]],
+    ["images/00000044_002.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
+    ["images/00000045_000.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
+    ["images/00000046_000.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0]],
+    ["images/00000054_003.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
+    ["images/00000059_000.png", [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
+    ["images/00000066_000.png", [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0]],
+    ["images/00000069_000.png", [1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
+  ],
+  "validation": [
+    ["images/00000001_000.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
+    ["images/00000001_001.png", [1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
+    ["images/00000001_002.png", [1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]],
+    ["images/00000007_000.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
+    ["images/00000010_000.png", [1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0]],
+    ["images/00000011_000.png", [1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
+    ["images/00000011_001.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]],
+    ["images/00000011_002.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]],
+    ["images/00000011_003.png", [1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]],
+    ["images/00000013_011.png", [1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0]],
+    ["images/00000013_014.png", [1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0]],
+    ["images/00000013_018.png", [1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0]],
+    ["images/00000013_022.png", [1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]],
+    ["images/00000013_024.png", [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]],
+    ["images/00000013_025.png", [1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]],
+    ["images/00000013_026.png", [1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]],
+    ["images/00000013_027.png", [1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1]],
+    ["images/00000013_028.png", [1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1]],
+    ["images/00000013_029.png", [1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1]],
+    ["images/00000013_030.png", [1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0]],
+    ["images/00000013_031.png", [1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0]],
+    ["images/00000013_032.png", [1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0]],
+    ["images/00000013_034.png", [1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]],
+    ["images/00000013_037.png", [1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]],
+    ["images/00000013_038.png", [1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]],
+    ["images/00000013_040.png", [1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0]],
+    ["images/00000013_041.png", [1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0]],
+    ["images/00000013_043.png", [1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0]],
+    ["images/00000013_044.png", [1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0]],
+    ["images/00000013_045.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0]],
+    ["images/00000013_046.png", [1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
+    ["images/00000031_000.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
+    ["images/00000033_000.png", [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0]],
+    ["images/00000044_002.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
+    ["images/00000045_000.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
+    ["images/00000046_000.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0]],
+    ["images/00000054_003.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
+    ["images/00000059_000.png", [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
+    ["images/00000066_000.png", [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0]],
+    ["images/00000069_000.png", [1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
+  ]
+}
diff --git a/src/ptbench/data/nih_cxr14_re/cardiomegaly.json.bz2 b/src/ptbench/data/nih_cxr14_re/cardiomegaly.json.bz2
deleted file mode 100644
index 13b6d810cd7bb430b332476e363970a5a728fa3e..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001

literal 392
zcmV;30eAjFT4*^jL0KkKS=7hkqW}wZTYwl4PzC?+01Bq3-3mYgyZ`_OFqoPQ5r|}A
zFvb%TL4q+1j3yYuVrVc%A(4c@sj8Dg6g0+@X+1`n8L8kI*kjE%=Kdbn0fF;)hHx4E
z27R6ZoV#tY?|Bnx6G+JfCXzui1dS$58f`HqX_HMEG-R1HB#^{u5+qGDiKd%L5h6^I
zX|{=^Ns=^;5u$A-B1lAyB*~K*B#kDRjUy&a6KRtWnoXoc+G&y=;K}BXig~A3N6FE;
zrw&~nj;^O?3G+I=9i82ix=viakHBXko{Y`1Z18zFY}xAB>KQiRnrXB{b!I%D2L=*h
z88RUfMnn;&h#@wKlSwqwmq%v~@1ulfj;+i2b|jON+?=L6qq<4DI-5GWxO8$lq?4Du
zJ31U3I5>OVoV*?mj)!+w!Q{g}FFCd^)Z4|T+3<O~JX6H|9i5!F?1?!xeZxN<-*5NS
mGt@Kg8UK8PIQF{+es5sQ@)_6IGq-o6{}*yaI8cz($Ks>l&#)o@

diff --git a/src/ptbench/data/nih_cxr14_re/cardiomegaly.py b/src/ptbench/data/nih_cxr14_re/cardiomegaly.py
index 1904ebfa..0715650d 100644
--- a/src/ptbench/data/nih_cxr14_re/cardiomegaly.py
+++ b/src/ptbench/data/nih_cxr14_re/cardiomegaly.py
@@ -2,47 +2,6 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-"""NIH CXR14 dataset for computer-aided diagnosis.
+from .datamodule import DataModule
 
-First 40 images with cardiomegaly.
-
-* See :py:mod:`ptbench.data.nih_cxr14_re` for split details
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.nih_cxr14_re` for dataset details
-"""
-
-from clapper.logging import setup
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
-
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class Fold0Module(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("cardiomegaly")
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = Fold0Module
+datamodule = DataModule("cardiomegaly.json")
diff --git a/src/ptbench/data/nih_cxr14_re/datamodule.py b/src/ptbench/data/nih_cxr14_re/datamodule.py
new file mode 100644
index 00000000..66a4379c
--- /dev/null
+++ b/src/ptbench/data/nih_cxr14_re/datamodule.py
@@ -0,0 +1,157 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import importlib.resources
+import os
+
+import PIL.Image
+
+from torchvision.transforms.functional import to_tensor
+
+from ...utils.rc import load_rc
+from ..datamodule import CachingDataModule
+from ..split import JSONDatabaseSplit
+from ..typing import DatabaseSplit
+from ..typing import RawDataLoader as _BaseRawDataLoader
+from ..typing import Sample
+
+
+class RawDataLoader(_BaseRawDataLoader):
+    """A specialized raw-data-loader for the Montgomery dataset.
+
+    Attributes
+    ----------
+
+    datadir
+        This variable contains the base directory where the database raw data
+        is stored.
+
+    idiap_file_organisation
+        This variable will be ``True``, if the user has set the configuration
+        parameter ``nih_cxr14_re.idiap_file_organisation`` in the global
+        configuration file.  It will cause internal loader to search for files
+        in a slightly different folder structure, that was adapted to Idiap's
+        requirements (number of files per folder to be less than 10k).
+    """
+
+    datadir: str
+    idiap_file_organisation: bool
+
+    def __init__(self):
+        rc = load_rc()
+        self.datadir = rc.get(
+            "datadir.nih_cxr14_re", os.path.realpath(os.curdir)
+        )
+        self.idiap_file_organisation = rc.get(
+            "nih_cxr14_re.idiap_folder_structure", False
+        )
+
+    def sample(self, sample: tuple[str, list[int]]) -> Sample:
+        """Loads a single image sample from the disk.
+
+        Parameters
+        ----------
+
+        sample:
+            A tuple containing the path suffix, within the dataset root folder,
+            where to find the image to be loaded, and an integer, representing the
+            sample label.
+
+
+        Returns
+        -------
+
+        sample
+            The sample representation
+        """
+        file_path = sample[0]  # default
+        if self.idiap_file_organisation:
+            # for folder lookup efficiency, data is split into subfolders
+            # each original file is on the subfolder `f[:5]/f`, where f
+            # is the original file basename
+            basename = os.path.basename(sample[0])
+            file_path = os.path.join(
+                os.path.dirname(sample[0]),
+                basename[:5],
+                basename,
+            )
+
+        # N.B.: NIH CXR-14 images are encoded as color PNGs
+        image = PIL.Image.open(os.path.join(self.datadir, file_path))
+        tensor = to_tensor(image)
+
+        # use the code below to view generated images
+        # from torchvision.transforms.functional import to_pil_image
+        # to_pil_image(tensor).show()
+        # __import__("pdb").set_trace()
+
+        return tensor, dict(label=sample[1], name=sample[0])  # type: ignore[arg-type]
+
+    def label(self, sample: tuple[str, list[int]]) -> list[int]:
+        """Loads a single image sample label from the disk.
+
+        Parameters
+        ----------
+
+        sample:
+            A tuple containing the path suffix, within the dataset root folder,
+            where to find the image to be loaded, and an integer, representing the
+            sample label.
+
+
+        Returns
+        -------
+
+        labels
+            The integer labels associated with the sample
+        """
+        return sample[1]
+
+
+def make_split(basename: str) -> DatabaseSplit:
+    """Returns a database split for the Montgomery database."""
+
+    return JSONDatabaseSplit(
+        importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(basename)
+    )
+
+
+class DataModule(CachingDataModule):
+    """NIH CXR14 (relabeled) datamodule for computer-aided diagnosis.
+
+    This dataset was extracted from the clinical PACS database at the National
+    Institutes of Health Clinical Center (USA) and represents 60% of all their
+    radiographs. It contains labels for 14 common radiological signs in this
+    order: cardiomegaly, emphysema, effusion, hernia, infiltration, mass,
+    nodule, atelectasis, pneumothorax, pleural thickening, pneumonia, fibrosis,
+    edema and consolidation. This is the relabeled version created in the
+    CheXNeXt study.
+
+    * Reference: [NIH-CXR14-2017]_
+    * Original resolution (height x width): 1024 x 1024
+    * Labels: [CHEXNEXT-2018]_
+    * Split reference: [CHEXNEXT-2018]_
+    * Protocol ``default``:
+
+      * Training samples: 98637
+      * Validation samples: 6350
+      * Test samples: 4355
+
+    * Output image:
+
+        * Transforms:
+
+            * Load raw PNG with :py:mod:`PIL`
+
+        * Final specifications
+
+            * RGB, encoded as a 3-plane image, 8 bits
+            * Square (1024x1024 px)
+    """
+
+    def __init__(self, split_filename: str):
+        super().__init__(
+            database_split=make_split(split_filename),
+            raw_data_loader=RawDataLoader(),
+        )
diff --git a/src/ptbench/data/nih_cxr14_re/default.py b/src/ptbench/data/nih_cxr14_re/default.py
index 0ea6ef5a..7fe993a9 100644
--- a/src/ptbench/data/nih_cxr14_re/default.py
+++ b/src/ptbench/data/nih_cxr14_re/default.py
@@ -2,46 +2,6 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-"""NIH CXR14 (relabeled) dataset for computer-aided diagnosis (default
-protocol)
+from .datamodule import DataModule
 
-* See :py:mod:`ptbench.data.nih_cxr14_re` for split details
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.nih_cxr14_re` for dataset details
-"""
-
-from clapper.logging import setup
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
-
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("default")
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = DefaultModule
+datamodule = DataModule("default.json.bz2")
diff --git a/src/ptbench/data/typing.py b/src/ptbench/data/typing.py
index bf821068..6f41b39e 100644
--- a/src/ptbench/data/typing.py
+++ b/src/ptbench/data/typing.py
@@ -28,7 +28,7 @@ class RawDataLoader:
         """Loads whole samples from media."""
         raise NotImplementedError("You must implement the `sample()` method")
 
-    def label(self, k: typing.Any) -> int:
+    def label(self, k: typing.Any) -> int | list[int]:
         """Loads only sample label from media.
 
         If you do not override this implementation, then, by default,
@@ -79,7 +79,7 @@ class Dataset(torch.utils.data.Dataset[Sample], typing.Iterable, typing.Sized):
     provide a dunder len method.
     """
 
-    def labels(self) -> list[int]:
+    def labels(self) -> list[int | list[int]]:
         """Returns the integer labels for all samples in the dataset."""
         raise NotImplementedError("You must implement the `labels()` method")
 
-- 
GitLab