diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py
index 2cbc4b848d080de2d4500bd84111efb8905efcfa..9b5eab61b988e65d5b7d199d1cfcfc854cd784be 100644
--- a/src/ptbench/data/datamodule.py
+++ b/src/ptbench/data/datamodule.py
@@ -106,7 +106,7 @@ class _DelayedLoadingDataset(Dataset):
         sample_size_mb = _sample_size_bytes(first_sample) / (1024.0 * 1024.0)
         logger.info(f"Estimated sample size: {sample_size_mb:.1f} Mb")
 
-    def labels(self) -> list[int]:
+    def labels(self) -> list[int | list[int]]:
         """Returns the integer labels for all samples in the dataset."""
         return [self.loader.label(k) for k in self.raw_dataset]
 
@@ -223,7 +223,7 @@ class _CachedDataset(Dataset):
             f"{sample_size_mb:.1f} / {(len(self.data)*sample_size_mb):.1f} Mb"
         )
 
-    def labels(self) -> list[int]:
+    def labels(self) -> list[int | list[int]]:
         """Returns the integer labels for all samples in the dataset."""
         return [k[1]["label"] for k in self.data]
 
@@ -256,7 +256,7 @@ class _ConcatDataset(Dataset):
             for j in range(len(datasets[i]))
         ]
 
-    def labels(self) -> list[int]:
+    def labels(self) -> list[int | list[int]]:
         """Returns the integer labels for all samples in the dataset."""
         return list(itertools.chain(*[k.labels() for k in self._datasets]))
 
@@ -379,11 +379,11 @@ def _make_balanced_random_sampler(
                 for ds in dataset.datasets
                 for k in typing.cast(Dataset, ds).labels()
             ]
-            weights = _calculate_weights(targets)
+            weights = _calculate_weights(targets)  # type: ignore
         else:
             logger.warning(
                 f"Balancing samples **and** concatenated-datasets "
-                f"WITHOUT metadata targets (`{target}` not available)"
+                f"by using dataset totals as `{target}: int` is not true"
             )
             weights = [
                 k
@@ -403,10 +403,11 @@ def _make_balanced_random_sampler(
                 f"Balancing samples from dataset using metadata "
                 f"targets `{target}`"
             )
-            weights = _calculate_weights(dataset.labels())
+            weights = _calculate_weights(dataset.labels())  # type: ignore
         else:
             raise RuntimeError(
-                f"Cannot balance samples without metadata targets `{target}`"
+                f"Cannot balance samples with multiple class labels "
+                f"({target}: list[int]) or without metadata targets `{target}`"
             )
 
     return torch.utils.data.WeightedRandomSampler(
diff --git a/src/ptbench/data/nih_cxr14_re/__init__.py b/src/ptbench/data/nih_cxr14_re/__init__.py
index 27d1903c5a25a1ccc99520867ad407b3186a3694..b9954cf126eae1670c87296ad86f0ca6f4f9e758 100644
--- a/src/ptbench/data/nih_cxr14_re/__init__.py
+++ b/src/ptbench/data/nih_cxr14_re/__init__.py
@@ -1,7 +1,6 @@
 # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
-
 """NIH CXR14 (relabeled) dataset for computer-aided diagnosis.
 
 This dataset was extracted from the clinical PACS database at the National
diff --git a/src/ptbench/data/nih_cxr14_re/cardiomegaly.json b/src/ptbench/data/nih_cxr14_re/cardiomegaly.json
new file mode 100644
index 0000000000000000000000000000000000000000..b9af6ad7b85245631f3ed5825f74a8e1ce2654d5
--- /dev/null
+++ b/src/ptbench/data/nih_cxr14_re/cardiomegaly.json
@@ -0,0 +1,86 @@
+{
+  "train": [
+    ["images/00000001_000.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
+    ["images/00000001_001.png", [1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
+    ["images/00000001_002.png", [1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]],
+    ["images/00000007_000.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
+    ["images/00000010_000.png", [1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0]],
+    ["images/00000011_000.png", [1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
+    ["images/00000011_001.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]],
+    ["images/00000011_002.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]],
+    ["images/00000011_003.png", [1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]],
+    ["images/00000013_011.png", [1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0]],
+    ["images/00000013_014.png", [1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0]],
+    ["images/00000013_018.png", [1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0]],
+    ["images/00000013_022.png", [1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]],
+    ["images/00000013_024.png", [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]],
+    ["images/00000013_025.png", [1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]],
+    ["images/00000013_026.png", [1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]],
+    ["images/00000013_027.png", [1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1]],
+    ["images/00000013_028.png", [1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1]],
+    ["images/00000013_029.png", [1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1]],
+    ["images/00000013_030.png", [1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0]],
+    ["images/00000013_031.png", [1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0]],
+    ["images/00000013_032.png", [1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0]],
+    ["images/00000013_034.png", [1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]],
+    ["images/00000013_037.png", [1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]],
+    ["images/00000013_038.png", [1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]],
+    ["images/00000013_040.png", [1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0]],
+    ["images/00000013_041.png", [1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0]],
+    ["images/00000013_043.png", [1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0]],
+    ["images/00000013_044.png", [1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0]],
+    ["images/00000013_045.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0]],
+    ["images/00000013_046.png", [1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
+    ["images/00000031_000.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
+    ["images/00000033_000.png", [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0]],
+    ["images/00000044_002.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
+    ["images/00000045_000.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
+    ["images/00000046_000.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0]],
+    ["images/00000054_003.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
+    ["images/00000059_000.png", [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
+    ["images/00000066_000.png", [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0]],
+    ["images/00000069_000.png", [1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
+  ],
+  "validation": [
+    ["images/00000001_000.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
+    ["images/00000001_001.png", [1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
+    ["images/00000001_002.png", [1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]],
+    ["images/00000007_000.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
+    ["images/00000010_000.png", [1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0]],
+    ["images/00000011_000.png", [1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
+    ["images/00000011_001.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]],
+    ["images/00000011_002.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]],
+    ["images/00000011_003.png", [1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]],
+    ["images/00000013_011.png", [1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0]],
+    ["images/00000013_014.png", [1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0]],
+    ["images/00000013_018.png", [1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0]],
+    ["images/00000013_022.png", [1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]],
+    ["images/00000013_024.png", [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]],
+    ["images/00000013_025.png", [1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]],
+    ["images/00000013_026.png", [1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]],
+    ["images/00000013_027.png", [1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1]],
+    ["images/00000013_028.png", [1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1]],
+    ["images/00000013_029.png", [1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1]],
+    ["images/00000013_030.png", [1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0]],
+    ["images/00000013_031.png", [1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0]],
+    ["images/00000013_032.png", [1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0]],
+    ["images/00000013_034.png", [1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]],
+    ["images/00000013_037.png", [1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]],
+    ["images/00000013_038.png", [1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]],
+    ["images/00000013_040.png", [1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0]],
+    ["images/00000013_041.png", [1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0]],
+    ["images/00000013_043.png", [1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0]],
+    ["images/00000013_044.png", [1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0]],
+    ["images/00000013_045.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0]],
+    ["images/00000013_046.png", [1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
+    ["images/00000031_000.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
+    ["images/00000033_000.png", [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0]],
+    ["images/00000044_002.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
+    ["images/00000045_000.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
+    ["images/00000046_000.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0]],
+    ["images/00000054_003.png", [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
+    ["images/00000059_000.png", [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
+    ["images/00000066_000.png", [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0]],
+    ["images/00000069_000.png", [1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
+  ]
+}
diff --git a/src/ptbench/data/nih_cxr14_re/cardiomegaly.json.bz2 b/src/ptbench/data/nih_cxr14_re/cardiomegaly.json.bz2
deleted file mode 100644
index 13b6d810cd7bb430b332476e363970a5a728fa3e..0000000000000000000000000000000000000000
Binary files a/src/ptbench/data/nih_cxr14_re/cardiomegaly.json.bz2 and /dev/null differ
diff --git a/src/ptbench/data/nih_cxr14_re/cardiomegaly.py b/src/ptbench/data/nih_cxr14_re/cardiomegaly.py
index 1904ebfa60dade4ff59f770da7f1310a099c798b..0715650d7ec5b9394249b885172f2f5af646dd2d 100644
--- a/src/ptbench/data/nih_cxr14_re/cardiomegaly.py
+++ b/src/ptbench/data/nih_cxr14_re/cardiomegaly.py
@@ -2,47 +2,6 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-"""NIH CXR14 dataset for computer-aided diagnosis.
+from .datamodule import DataModule
 
-First 40 images with cardiomegaly.
-
-* See :py:mod:`ptbench.data.nih_cxr14_re` for split details
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.nih_cxr14_re` for dataset details
-"""
-
-from clapper.logging import setup
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
-
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class Fold0Module(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("cardiomegaly")
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = Fold0Module
+datamodule = DataModule("cardiomegaly.json")
diff --git a/src/ptbench/data/nih_cxr14_re/datamodule.py b/src/ptbench/data/nih_cxr14_re/datamodule.py
new file mode 100644
index 0000000000000000000000000000000000000000..66a4379cf0880817c8035ff62cd3fea1d28af193
--- /dev/null
+++ b/src/ptbench/data/nih_cxr14_re/datamodule.py
@@ -0,0 +1,157 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import importlib.resources
+import os
+
+import PIL.Image
+
+from torchvision.transforms.functional import to_tensor
+
+from ...utils.rc import load_rc
+from ..datamodule import CachingDataModule
+from ..split import JSONDatabaseSplit
+from ..typing import DatabaseSplit
+from ..typing import RawDataLoader as _BaseRawDataLoader
+from ..typing import Sample
+
+
+class RawDataLoader(_BaseRawDataLoader):
+    """A specialized raw-data-loader for the Montgomery dataset.
+
+    Attributes
+    ----------
+
+    datadir
+        This variable contains the base directory where the database raw data
+        is stored.
+
+    idiap_file_organisation
+        This variable will be ``True`` if the user has set the configuration
+        parameter ``nih_cxr14_re.idiap_folder_structure`` in the global
+        configuration file.  It causes the internal loader to search for files
+        in a slightly different folder structure, adapted to Idiap's
+        requirements (fewer than 10k files per folder).
+    """
+
+    datadir: str
+    idiap_file_organisation: bool
+
+    def __init__(self):
+        rc = load_rc()
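+        # dataset root comes from the user configuration, else the current directory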
+        self.datadir = rc.get(
+            "datadir.nih_cxr14_re", os.path.realpath(os.curdir)
+        )
+        self.idiap_file_organisation = rc.get(
+            "nih_cxr14_re.idiap_folder_structure", False
+        )
+
+    def sample(self, sample: tuple[str, list[int]]) -> Sample:
+        """Loads a single image sample from the disk.
+
+        Parameters
+        ----------
+
+        sample:
+            A tuple containing the path suffix, within the dataset root folder,
+            where to find the image to be loaded, and a list of binary integers
+            (one per radiological sign) representing the sample labels.
+
+
+        Returns
+        -------
+
+        sample
+            The sample representation
+        """
+        file_path = sample[0]  # default
+        if self.idiap_file_organisation:
+            # for folder lookup efficiency, data is split into subfolders
+            # each original file is in the subfolder `f[:5]/f`, where f
+            # is the original file basename
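+            # e.g. "images/00000001_000.png" -> "images/00000/00000001_000.png"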
+            basename = os.path.basename(sample[0])
+            file_path = os.path.join(
+                os.path.dirname(sample[0]),
+                basename[:5],
+                basename,
+            )
+
+        # N.B.: NIH CXR-14 images are encoded as color PNGs
+        image = PIL.Image.open(os.path.join(self.datadir, file_path))
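+        # to_tensor() yields a (C, H, W) float tensor with values scaled to [0, 1]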
+        tensor = to_tensor(image)
+
+        # use the code below to view generated images
+        # from torchvision.transforms.functional import to_pil_image
+        # to_pil_image(tensor).show()
+        # __import__("pdb").set_trace()
+
+        return tensor, dict(label=sample[1], name=sample[0])  # type: ignore[arg-type]
+
+    def label(self, sample: tuple[str, list[int]]) -> list[int]:
+        """Loads a single image sample label from the disk.
+
+        Parameters
+        ----------
+
+        sample:
+            A tuple containing the path suffix, within the dataset root folder,
+            where to find the image to be loaded, and a list of binary integers
+            (one per radiological sign) representing the sample labels.
+
+
+        Returns
+        -------
+
+        labels
+            The list of binary integer labels associated with the sample
+        """
+        return sample[1]
+
+
+def make_split(basename: str) -> DatabaseSplit:
+    """Returns a database split for the Montgomery database."""
+
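+    # the JSON split file ships inside this package, next to this module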
+    return JSONDatabaseSplit(
+        importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(basename)
+    )
+
+
+class DataModule(CachingDataModule):
+    """NIH CXR14 (relabeled) datamodule for computer-aided diagnosis.
+
+    This dataset was extracted from the clinical PACS database at the National
+    Institutes of Health Clinical Center (USA) and represents 60% of all their
+    radiographs. It contains labels for 14 common radiological signs in this
+    order: cardiomegaly, emphysema, effusion, hernia, infiltration, mass,
+    nodule, atelectasis, pneumothorax, pleural thickening, pneumonia, fibrosis,
+    edema and consolidation. This is the relabeled version created in the
+    CheXNeXt study.
+
+    * Reference: [NIH-CXR14-2017]_
+    * Original resolution (height x width): 1024 x 1024
+    * Labels: [CHEXNEXT-2018]_
+    * Split reference: [CHEXNEXT-2018]_
+    * Protocol ``default``:
+
+      * Training samples: 98637
+      * Validation samples: 6350
+      * Test samples: 4355
+
+    * Output image:
+
+        * Transforms:
+
+            * Load raw PNG with :py:mod:`PIL`
+
+        * Final specifications
+
+            * RGB, encoded as a 3-plane image, 8 bits
+            * Square (1024x1024 px)
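+
+    A minimal usage sketch (this is how the shipped ``default`` and
+    ``cardiomegaly`` configurations instantiate it):
+
+    .. code-block:: python
+
+       from ptbench.data.nih_cxr14_re.datamodule import DataModule
+
+       datamodule = DataModule("default.json.bz2")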
+    """
+
+    def __init__(self, split_filename: str):
+        super().__init__(
+            database_split=make_split(split_filename),
+            raw_data_loader=RawDataLoader(),
+        )
diff --git a/src/ptbench/data/nih_cxr14_re/default.py b/src/ptbench/data/nih_cxr14_re/default.py
index 0ea6ef5acc55560ae8db115f3585b40da3cf58b8..7fe993a981c86c0161327d1ddb4498e08a90313c 100644
--- a/src/ptbench/data/nih_cxr14_re/default.py
+++ b/src/ptbench/data/nih_cxr14_re/default.py
@@ -2,46 +2,6 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-"""NIH CXR14 (relabeled) dataset for computer-aided diagnosis (default
-protocol)
+from .datamodule import DataModule
 
-* See :py:mod:`ptbench.data.nih_cxr14_re` for split details
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.nih_cxr14_re` for dataset details
-"""
-
-from clapper.logging import setup
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
-
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("default")
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
-
-datamodule = DefaultModule
+datamodule = DataModule("default.json.bz2")
diff --git a/src/ptbench/data/typing.py b/src/ptbench/data/typing.py
index bf821068eee6d6a724150a5c56e9b9ad374374da..6f41b39eb33d2a91c51008623388bc2900032665 100644
--- a/src/ptbench/data/typing.py
+++ b/src/ptbench/data/typing.py
@@ -28,7 +28,7 @@ class RawDataLoader:
         """Loads whole samples from media."""
         raise NotImplementedError("You must implement the `sample()` method")
 
-    def label(self, k: typing.Any) -> int:
+    def label(self, k: typing.Any) -> int | list[int]:
         """Loads only sample label from media.
 
         If you do not override this implementation, then, by default,
@@ -79,7 +79,7 @@ class Dataset(torch.utils.data.Dataset[Sample], typing.Iterable, typing.Sized):
     provide a dunder len method.
     """
 
-    def labels(self) -> list[int]:
+    def labels(self) -> list[int | list[int]]:
         """Returns the integer labels for all samples in the dataset."""
         raise NotImplementedError("You must implement the `labels()` method")