From bcc1e440d967fffc36d88b9038001415bb1656a9 Mon Sep 17 00:00:00 2001
From: mdelitroz <maxime.delitroz@idiap.ch>
Date: Wed, 2 Aug 2023 16:21:37 +0200
Subject: [PATCH] updated TB-POC dataset and corresponding tests

---
 src/ptbench/data/tbpoc/__init__.py   |  83 -------------
 src/ptbench/data/tbpoc/datamodule.py | 138 +++++++++++++++++++++
 src/ptbench/data/tbpoc/fold_0.py     |  46 ++-----
 src/ptbench/data/tbpoc/fold_1.py     |  46 ++-----
 src/ptbench/data/tbpoc/fold_2.py     |  46 ++-----
 src/ptbench/data/tbpoc/fold_3.py     |  46 ++-----
 src/ptbench/data/tbpoc/fold_4.py     |  46 ++-----
 src/ptbench/data/tbpoc/fold_5.py     |  46 ++-----
 src/ptbench/data/tbpoc/fold_6.py     |  46 ++-----
 src/ptbench/data/tbpoc/fold_7.py     |  46 ++-----
 src/ptbench/data/tbpoc/fold_8.py     |  46 ++-----
 src/ptbench/data/tbpoc/fold_9.py     |  46 ++-----
 tests/test_tbpoc.py                  | 174 +++++++++++++++------------
 13 files changed, 345 insertions(+), 510 deletions(-)
 create mode 100644 src/ptbench/data/tbpoc/datamodule.py

diff --git a/src/ptbench/data/tbpoc/__init__.py b/src/ptbench/data/tbpoc/__init__.py
index 00f5f42c..e69de29b 100644
--- a/src/ptbench/data/tbpoc/__init__.py
+++ b/src/ptbench/data/tbpoc/__init__.py
@@ -1,83 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-"""TB-POC dataset for computer-aided diagnosis.
-
-* Reference: [TB-POC-2018]_
-* Original resolution (height x width or width x height): 2048 x 2500
-* Split reference: none
-* Stratified kfold protocol:
-
-  * Training samples: 72% of TB and healthy CXR (including labels)
-  * Validation samples: 18% of TB and healthy CXR (including labels)
-  * Test samples: 10% of TB and healthy CXR (including labels)
-"""
-
-import importlib.resources
-import os
-
-from ...utils.rc import load_rc
-from .. import make_dataset
-from ..dataset import JSONDataset
-from ..loader import load_pil_grayscale, make_delayed
-
-_protocols = [
-    importlib.resources.files(__name__).joinpath("fold_0.json.bz2"),
-    importlib.resources.files(__name__).joinpath("fold_1.json.bz2"),
-    importlib.resources.files(__name__).joinpath("fold_2.json.bz2"),
-    importlib.resources.files(__name__).joinpath("fold_3.json.bz2"),
-    importlib.resources.files(__name__).joinpath("fold_4.json.bz2"),
-    importlib.resources.files(__name__).joinpath("fold_5.json.bz2"),
-    importlib.resources.files(__name__).joinpath("fold_6.json.bz2"),
-    importlib.resources.files(__name__).joinpath("fold_7.json.bz2"),
-    importlib.resources.files(__name__).joinpath("fold_8.json.bz2"),
-    importlib.resources.files(__name__).joinpath("fold_9.json.bz2"),
-]
-
-_datadir = load_rc().get("datadir.tbpoc", os.path.realpath(os.curdir))
-
-
-def _raw_data_loader(sample):
-    return dict(
-        data=load_pil_grayscale(os.path.join(_datadir, sample["data"])),
-        label=sample["label"],
-    )
-
-
-def _loader(context, sample):
-    # "context" is ignored in this case - database is homogeneous
-    # we returned delayed samples to avoid loading all images at once
-    return make_delayed(sample, _raw_data_loader)
-
-
-json_dataset = JSONDataset(
-    protocols=_protocols,
-    fieldnames=("data", "label"),
-    loader=_loader,
-)
-"""TB-POC dataset object."""
-
-
-def _maker(protocol, resize_size=512, cc_size=512, RGB=False):
-    from torchvision import transforms
-
-    from ..augmentations import ElasticDeformation
-    from ..image_utils import RemoveBlackBorders
-
-    post_transforms = []
-    if RGB:
-        post_transforms = [
-            transforms.Lambda(lambda x: x.convert("RGB")),
-            transforms.ToTensor(),
-        ]
-
-    return make_dataset(
-        [json_dataset.subsets(protocol)],
-        [
-            RemoveBlackBorders(),
-            transforms.Resize(resize_size),
-            transforms.CenterCrop(cc_size),
-        ],
-        [ElasticDeformation(p=0.8)],
-        post_transforms,
-    )
diff --git a/src/ptbench/data/tbpoc/datamodule.py b/src/ptbench/data/tbpoc/datamodule.py
new file mode 100644
index 00000000..35465bac
--- /dev/null
+++ b/src/ptbench/data/tbpoc/datamodule.py
@@ -0,0 +1,138 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import importlib.resources
+import os
+
+import PIL.Image
+
+from torchvision.transforms.functional import center_crop, to_tensor
+
+from ...utils.rc import load_rc
+from ..datamodule import CachingDataModule
+from ..image_utils import load_pil_grayscale, remove_black_borders
+from ..split import JSONDatabaseSplit
+from ..typing import DatabaseSplit
+from ..typing import RawDataLoader as _BaseRawDataLoader
+from ..typing import Sample
+
+
+class RawDataLoader(_BaseRawDataLoader):
+    """A specialized raw-data-loader for the TB-POC dataset.
+
+    Attributes
+    ----------
+
+    datadir
+        This variable contains the base directory where the database raw data
+        is stored.  It is looked up via ``load_rc()`` under the key given by
+        the ``config_variable`` constructor argument (``datadir.tbpoc`` by
+        default), falling back to the current working directory when that
+        entry is not set.
+    """
+
+    datadir: str
+
+    def __init__(self, config_variable: str = "datadir.tbpoc"):
+        self.datadir = load_rc().get(
+            config_variable, os.path.realpath(os.curdir)
+        )
+
+    def sample(self, sample: tuple[str, int]) -> Sample:
+        """Loads a single image sample from the disk.
+
+        Parameters
+        ----------
+
+        sample:
+            A tuple containing the path suffix, within the dataset root folder,
+            where to find the image to be loaded, and an integer, representing the
+            sample label.
+
+        Returns
+        -------
+
+        sample
+            A tuple with the image tensor (black borders removed, center-cropped
+            to a square) and a dictionary holding its ``label`` and ``name``.
+        """
+        image = load_pil_grayscale(os.path.join(self.datadir, sample[0]))
+        image = remove_black_borders(image)
+        tensor = to_tensor(image)
+        tensor = center_crop(tensor, min(*tensor.shape[1:]))  # square crop
+
+        # use the code below to view generated images
+        # from torchvision.transforms.functional import to_pil_image
+        # to_pil_image(tensor).show()
+        # __import__("pdb").set_trace()
+
+        return tensor, dict(label=sample[1], name=sample[0])  # type: ignore[arg-type]
+
+    def label(self, sample: tuple[str, int]) -> int:
+        """Returns the label of a sample, without loading its image from disk.
+
+        Parameters
+        ----------
+
+        sample:
+            A tuple containing the path suffix, within the dataset root folder,
+            where to find the image to be loaded, and an integer, representing the
+            sample label.
+
+
+        Returns
+        -------
+
+        label
+            The integer label associated with the sample
+        """
+        return sample[1]
+
+
+def make_split(basename: str) -> DatabaseSplit:
+    """Returns the TB-POC database split stored in resource ``basename``."""
+    # drops the last component of ``__name__`` to address the enclosing package
+    return JSONDatabaseSplit(
+        importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(basename)
+    )
+
+
+class DataModule(CachingDataModule):
+    """TB-POC dataset for computer-aided diagnosis.
+
+    * Database reference: [TB-POC-2018]_
+    * Original resolution (height x width or width x height): 2048 x 2500 pixels
+      or 2500 x 2048 pixels
+
+    Data specifications:
+
+    * Raw data input (on disk):
+
+        * jpeg 8-bit grayscale images
+        * resolution: fixed to one of the cases above
+
+    * Output image:
+
+        * Transforms:
+
+            * Load raw jpeg with :py:mod:`PIL`
+            * Remove black borders
+            * Convert to torch tensor
+            * Torch center cropping to get square image
+
+        * Final specifications:
+
+            * Grayscale, encoded as a single plane tensor, 32-bit floats,
+              square with varying resolutions, depending on black borders' sizes
+              on the input image
+            * Labels: 0 (healthy), 1 (active tuberculosis)
+    """
+
+    def __init__(self, split_filename: str):
+        """Builds the datamodule for the split stored in ``split_filename``."""
+        # ``split_filename`` names a JSON resource shipped within this package
+        super().__init__(
+            database_split=make_split(split_filename),
+            raw_data_loader=RawDataLoader(),
+        )
diff --git a/src/ptbench/data/tbpoc/fold_0.py b/src/ptbench/data/tbpoc/fold_0.py
index 7a423deb..972e7188 100644
--- a/src/ptbench/data/tbpoc/fold_0.py
+++ b/src/ptbench/data/tbpoc/fold_0.py
@@ -1,45 +1,21 @@
 # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
+
 """TB-POC dataset for TB detection (cross validation fold 0)
 
 * Split reference: none (stratified kfolding)
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.tbpoc` for dataset details
-"""
-
-from clapper.logging import setup
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
 
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+* Stratified kfold protocol:
+    * Training samples: 72% of TB and healthy CXR (including labels)
+    * Validation samples: 18% of TB and healthy CXR (including labels)
+    * Test samples: 10% of TB and healthy CXR (including labels)
 
+* This configuration resolution: varying, depending on the black borders of
+  the original image
+* See :py:class:`ptbench.data.tbpoc.datamodule.DataModule` for dataset details
+"""
 
-class Fold0Module(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold-0")
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
+from .datamodule import DataModule
 
-datamodule = Fold0Module
+datamodule = DataModule("fold-0.json")
diff --git a/src/ptbench/data/tbpoc/fold_1.py b/src/ptbench/data/tbpoc/fold_1.py
index cb4c59ba..79b9bfca 100644
--- a/src/ptbench/data/tbpoc/fold_1.py
+++ b/src/ptbench/data/tbpoc/fold_1.py
@@ -1,45 +1,21 @@
 # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
+
 """TB-POC dataset for TB detection (cross validation fold 1)
 
 * Split reference: none (stratified kfolding)
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.tbpoc` for dataset details
-"""
-
-from clapper.logging import setup
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
 
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+* Stratified kfold protocol:
+    * Training samples: 72% of TB and healthy CXR (including labels)
+    * Validation samples: 18% of TB and healthy CXR (including labels)
+    * Test samples: 10% of TB and healthy CXR (including labels)
 
+* This configuration resolution: varying, depending on the black borders of
+  the original image
+* See :py:class:`ptbench.data.tbpoc.datamodule.DataModule` for dataset details
+"""
 
-class Fold0Module(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold-1")
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
+from .datamodule import DataModule
 
-datamodule = Fold0Module
+datamodule = DataModule("fold-1.json")
diff --git a/src/ptbench/data/tbpoc/fold_2.py b/src/ptbench/data/tbpoc/fold_2.py
index 1bffecea..9d41fb59 100644
--- a/src/ptbench/data/tbpoc/fold_2.py
+++ b/src/ptbench/data/tbpoc/fold_2.py
@@ -1,45 +1,21 @@
 # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
+
 """TB-POC dataset for TB detection (cross validation fold 2)
 
 * Split reference: none (stratified kfolding)
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.tbpoc` for dataset details
-"""
-
-from clapper.logging import setup
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
 
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+* Stratified kfold protocol:
+    * Training samples: 72% of TB and healthy CXR (including labels)
+    * Validation samples: 18% of TB and healthy CXR (including labels)
+    * Test samples: 10% of TB and healthy CXR (including labels)
 
+* This configuration resolution: varying, depending on the black borders of
+  the original image
+* See :py:class:`ptbench.data.tbpoc.datamodule.DataModule` for dataset details
+"""
 
-class Fold0Module(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold-2")
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
+from .datamodule import DataModule
 
-datamodule = Fold0Module
+datamodule = DataModule("fold-2.json")
diff --git a/src/ptbench/data/tbpoc/fold_3.py b/src/ptbench/data/tbpoc/fold_3.py
index 1263d39b..08672b3f 100644
--- a/src/ptbench/data/tbpoc/fold_3.py
+++ b/src/ptbench/data/tbpoc/fold_3.py
@@ -1,45 +1,21 @@
 # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
+
 """TB-POC dataset for TB detection (cross validation fold 3)
 
 * Split reference: none (stratified kfolding)
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.tbpoc` for dataset details
-"""
-
-from clapper.logging import setup
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
 
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+* Stratified kfold protocol:
+    * Training samples: 72% of TB and healthy CXR (including labels)
+    * Validation samples: 18% of TB and healthy CXR (including labels)
+    * Test samples: 10% of TB and healthy CXR (including labels)
 
+* This configuration resolution: varying, depending on the black borders of
+  the original image
+* See :py:class:`ptbench.data.tbpoc.datamodule.DataModule` for dataset details
+"""
 
-class Fold0Module(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold-3")
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
+from .datamodule import DataModule
 
-datamodule = Fold0Module
+datamodule = DataModule("fold-3.json")
diff --git a/src/ptbench/data/tbpoc/fold_4.py b/src/ptbench/data/tbpoc/fold_4.py
index 119adfa9..8354a4c2 100644
--- a/src/ptbench/data/tbpoc/fold_4.py
+++ b/src/ptbench/data/tbpoc/fold_4.py
@@ -1,45 +1,21 @@
 # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
+
 """TB-POC dataset for TB detection (cross validation fold 4)
 
 * Split reference: none (stratified kfolding)
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.tbpoc` for dataset details
-"""
-
-from clapper.logging import setup
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
 
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+* Stratified kfold protocol:
+    * Training samples: 72% of TB and healthy CXR (including labels)
+    * Validation samples: 18% of TB and healthy CXR (including labels)
+    * Test samples: 10% of TB and healthy CXR (including labels)
 
+* This configuration resolution: varying, depending on the black borders of
+  the original image
+* See :py:class:`ptbench.data.tbpoc.datamodule.DataModule` for dataset details
+"""
 
-class Fold0Module(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold-4")
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
+from .datamodule import DataModule
 
-datamodule = Fold0Module
+datamodule = DataModule("fold-4.json")
diff --git a/src/ptbench/data/tbpoc/fold_5.py b/src/ptbench/data/tbpoc/fold_5.py
index 2a90cbdb..cb7f9561 100644
--- a/src/ptbench/data/tbpoc/fold_5.py
+++ b/src/ptbench/data/tbpoc/fold_5.py
@@ -1,45 +1,21 @@
 # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
+
 """TB-POC dataset for TB detection (cross validation fold 5)
 
 * Split reference: none (stratified kfolding)
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.tbpoc` for dataset details
-"""
-
-from clapper.logging import setup
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
 
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+* Stratified kfold protocol:
+    * Training samples: 72% of TB and healthy CXR (including labels)
+    * Validation samples: 18% of TB and healthy CXR (including labels)
+    * Test samples: 10% of TB and healthy CXR (including labels)
 
+* This configuration resolution: varying, depending on the black borders of
+  the original image
+* See :py:class:`ptbench.data.tbpoc.datamodule.DataModule` for dataset details
+"""
 
-class Fold0Module(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold-5")
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
+from .datamodule import DataModule
 
-datamodule = Fold0Module
+datamodule = DataModule("fold-5.json")
diff --git a/src/ptbench/data/tbpoc/fold_6.py b/src/ptbench/data/tbpoc/fold_6.py
index 42ed763d..379211aa 100644
--- a/src/ptbench/data/tbpoc/fold_6.py
+++ b/src/ptbench/data/tbpoc/fold_6.py
@@ -1,45 +1,21 @@
 # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
+
 """TB-POC dataset for TB detection (cross validation fold 6)
 
 * Split reference: none (stratified kfolding)
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.tbpoc` for dataset details
-"""
-
-from clapper.logging import setup
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
 
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+* Stratified kfold protocol:
+    * Training samples: 72% of TB and healthy CXR (including labels)
+    * Validation samples: 18% of TB and healthy CXR (including labels)
+    * Test samples: 10% of TB and healthy CXR (including labels)
 
+* This configuration resolution: varying, depending on the black borders of
+  the original image
+* See :py:class:`ptbench.data.tbpoc.datamodule.DataModule` for dataset details
+"""
 
-class Fold0Module(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold-6")
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
+from .datamodule import DataModule
 
-datamodule = Fold0Module
+datamodule = DataModule("fold-6.json")
diff --git a/src/ptbench/data/tbpoc/fold_7.py b/src/ptbench/data/tbpoc/fold_7.py
index ad7dbe14..b846b88a 100644
--- a/src/ptbench/data/tbpoc/fold_7.py
+++ b/src/ptbench/data/tbpoc/fold_7.py
@@ -1,45 +1,21 @@
 # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
+
 """TB-POC dataset for TB detection (cross validation fold 7)
 
 * Split reference: none (stratified kfolding)
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.tbpoc` for dataset details
-"""
-
-from clapper.logging import setup
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
 
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+* Stratified kfold protocol:
+    * Training samples: 72% of TB and healthy CXR (including labels)
+    * Validation samples: 18% of TB and healthy CXR (including labels)
+    * Test samples: 10% of TB and healthy CXR (including labels)
 
+* This configuration resolution: varying, depending on the black borders of
+  the original image
+* See :py:class:`ptbench.data.tbpoc.datamodule.DataModule` for dataset details
+"""
 
-class Fold0Module(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold-7")
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
+from .datamodule import DataModule
 
-datamodule = Fold0Module
+datamodule = DataModule("fold-7.json")
diff --git a/src/ptbench/data/tbpoc/fold_8.py b/src/ptbench/data/tbpoc/fold_8.py
index 4bcea788..acfd4296 100644
--- a/src/ptbench/data/tbpoc/fold_8.py
+++ b/src/ptbench/data/tbpoc/fold_8.py
@@ -1,45 +1,21 @@
 # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
+
 """TB-POC dataset for TB detection (cross validation fold 8)
 
 * Split reference: none (stratified kfolding)
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.tbpoc` for dataset details
-"""
-
-from clapper.logging import setup
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
 
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+* Stratified kfold protocol:
+    * Training samples: 72% of TB and healthy CXR (including labels)
+    * Validation samples: 18% of TB and healthy CXR (including labels)
+    * Test samples: 10% of TB and healthy CXR (including labels)
 
+* This configuration resolution: varying, depending on the black borders of
+  the original image
+* See :py:class:`ptbench.data.tbpoc.datamodule.DataModule` for dataset details
+"""
 
-class Fold0Module(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold-8")
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
+from .datamodule import DataModule
 
-datamodule = Fold0Module
+datamodule = DataModule("fold-8.json")
diff --git a/src/ptbench/data/tbpoc/fold_9.py b/src/ptbench/data/tbpoc/fold_9.py
index c33eb6bd..4634068e 100644
--- a/src/ptbench/data/tbpoc/fold_9.py
+++ b/src/ptbench/data/tbpoc/fold_9.py
@@ -1,45 +1,21 @@
 # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
+
 """TB-POC dataset for TB detection (cross validation fold 9)
 
 * Split reference: none (stratified kfolding)
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.tbpoc` for dataset details
-"""
-
-from clapper.logging import setup
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
 
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+* Stratified kfold protocol:
+    * Training samples: 72% of TB and healthy CXR (including labels)
+    * Validation samples: 18% of TB and healthy CXR (including labels)
+    * Test samples: 10% of TB and healthy CXR (including labels)
 
+* This configuration resolution: varying, depending on the black borders of
+  the original image
+* See :py:class:`ptbench.data.tbpoc.datamodule.DataModule` for dataset details
+"""
 
-class Fold0Module(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold-9")
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
+from .datamodule import DataModule
 
-datamodule = Fold0Module
+datamodule = DataModule("fold-9.json")
diff --git a/tests/test_tbpoc.py b/tests/test_tbpoc.py
index 9609ea66..ee34d8d0 100644
--- a/tests/test_tbpoc.py
+++ b/tests/test_tbpoc.py
@@ -4,106 +4,126 @@
 """Tests for TB-POC dataset."""
 
 import pytest
+import torch
 
-dataset = None
+from ptbench.data.tbpoc.datamodule import make_split
 
 
-@pytest.mark.skip(reason="Test need to be updated")
-def test_protocol_consistency():
-    # Cross-validation fold 0-6
-    for f in range(7):
-        subset = dataset.subsets("fold_" + str(f))
-        assert len(subset) == 3
+def _check_split(
+    split_filename: str,
+    lengths: dict[str, int],
+    prefix: str = "TBPOC_CXR/",
+    extension: str = ".jpeg",
+    possible_labels: list[int] = [0, 1],
+):
+    """Runs a simple consistency check on the data split.
 
-        assert "train" in subset
-        assert len(subset["train"]) == 292
-        for s in subset["train"]:
-            assert s.key.upper().startswith("TBPOC_CXR/TBPOC-")
+    Parameters
+    ----------
 
-        assert "validation" in subset
-        assert len(subset["validation"]) == 74
-        for s in subset["validation"]:
-            assert s.key.upper().startswith("TBPOC_CXR/TBPOC-")
+    split_filename
+        This is the split we will check
 
-        assert "test" in subset
-        assert len(subset["test"]) == 41
-        for s in subset["test"]:
-            assert s.key.upper().startswith("TBPOC_CXR/TBPOC-")
+    lengths
+        A dictionary that contains keys matching those of the split (this will
+        be checked).  The values of the dictionary should correspond to the
+        sizes of each of the datasets in the split.
 
-        # Check labels
-        for s in subset["train"]:
-            assert s.label in [0.0, 1.0]
+    prefix
+        Each file named in a split should start with this prefix.
 
-        for s in subset["validation"]:
-            assert s.label in [0.0, 1.0]
+    extension
+        Each file named in a split should end with this extension.
 
-        for s in subset["test"]:
-            assert s.label in [0.0, 1.0]
+    possible_labels
+        These are the list of possible labels contained in any split.
+    """
 
-    # Cross-validation fold 7-9
-    for f in range(7, 10):
-        subset = dataset.subsets("fold_" + str(f))
-        assert len(subset) == 3
+    split = make_split(split_filename)
 
-        assert "train" in subset
-        assert len(subset["train"]) == 293
-        for s in subset["train"]:
-            assert s.key.upper().startswith("TBPOC_CXR/TBPOC-")
+    assert len(split) == len(lengths)
 
-        assert "validation" in subset
-        assert len(subset["validation"]) == 74
-        for s in subset["validation"]:
-            assert s.key.upper().startswith("TBPOC_CXR/TBPOC-")
+    for k in lengths.keys():
+        # dataset must have been declared
+        assert k in split
 
-        assert "test" in subset
-        assert len(subset["test"]) == 40
-        for s in subset["test"]:
-            assert s.key.upper().startswith("TBPOC_CXR/TBPOC-")
+        assert len(split[k]) == lengths[k]
+        for s in split[k]:
+            # assert s[0].startswith(prefix)
+            assert s[0].endswith(extension)
+            assert s[1] in possible_labels
 
-        # Check labels
-        for s in subset["train"]:
-            assert s.label in [0.0, 1.0]
 
-        for s in subset["validation"]:
-            assert s.label in [0.0, 1.0]
+def _check_loaded_batch(
+    batch,
+    size: int = 1,
+    prefix: str = "TBPOC_CXR/",
+    extension: str = ".jpeg",
+    possible_labels: list[int] = [0, 1],
+):
+    """Checks the consistency of an individual (loaded) batch.
 
-        for s in subset["test"]:
-            assert s.label in [0.0, 1.0]
+    Parameters
+    ----------
 
+    batch
+        The loaded batch to be checked.
 
-@pytest.mark.skip(reason="Test need to be updated")
-@pytest.mark.skip_if_rc_var_not_set("datadir.tbpoc")
-def test_loading():
-    image_size_portrait = (2048, 2500)
-    image_size_landscape = (2500, 2048)
+    prefix
+        Each file named in a split should start with this prefix.
+
+    extension
+        Each file named in a split should end with this extension.
+
+    possible_labels
+        These are the list of possible labels contained in any split.
+    """
 
-    def _check_size(size):
-        if size == image_size_portrait:
-            return True
-        elif size == image_size_landscape:
-            return True
-        return False
+    assert len(batch) == 2  # data, metadata
 
-    def _check_sample(s):
-        data = s.data
-        assert isinstance(data, dict)
-        assert len(data) == 2
+    assert isinstance(batch[0], torch.Tensor)
+    assert batch[0].shape[0] == size  # mini-batch size
+    assert batch[0].shape[1] == 1  # grayscale images
+    assert batch[0].shape[2] == batch[0].shape[3]  # image is square
 
-        assert "data" in data
-        assert _check_size(data["data"].size)  # Check size
-        assert data["data"].mode, "L"  # Check colors
+    assert isinstance(batch[1], dict)  # metadata
+    assert len(batch[1]) == 2  # label and name
 
-        assert "label" in data
-        assert data["label"] in [0, 1]  # Check labels
+    assert "label" in batch[1]
+    assert all([k in possible_labels for k in batch[1]["label"]])
 
-    limit = 30  # use this to limit testing to first images only, else None
+    assert "name" in batch[1]
+    # assert all([k.startswith(prefix) for k in batch[1]["name"]])
+    assert all([k.endswith(extension) for k in batch[1]["name"]])
 
-    subset = dataset.subsets("fold_0")
-    for s in subset["train"][:limit]:
-        _check_sample(s)
+
+def test_protocol_consistency():
+    # Cross-validation folds 0-6: 292 train, 74 validation and 41 test samples
+    for k in range(7):
+        _check_split(
+            f"fold-{k}.json",
+            lengths=dict(train=292, validation=74, test=41),
+        )
+
+    # Cross-validation folds 7-9: 293 train, 74 validation and 40 test samples
+    for k in range(7, 10):
+        _check_split(
+            f"fold-{k}.json",
+            lengths=dict(train=293, validation=74, test=40),
+        )
 
 
-@pytest.mark.skip(reason="Test need to be updated")
-@pytest.mark.skip_if_rc_var_not_set("datadir.tbpoc")
-def test_check():
-    assert dataset.check() == 0
+@pytest.mark.skip_if_rc_var_not_set("datadir.tbpoc")
+def test_loading():
+    """Checks a few loaded batches from each fold-0 prediction dataloader."""
+    from ptbench.data.tbpoc.fold_0 import datamodule
+    datamodule.model_transforms = []  # should be done before setup()
+    datamodule.setup("predict")  # sets up all datasets
+
+    for loader in datamodule.predict_dataloader().values():
+        limit = 5  # only check the first batches to keep the test fast
+        for batch in loader:
+            if limit == 0:
+                break
+            _check_loaded_batch(batch)
+            limit -= 1
-- 
GitLab