From 75f98d0c95cb91e5333b134e40e6ad414b4c026b Mon Sep 17 00:00:00 2001
From: mdelitroz <maxime.delitroz@idiap.ch>
Date: Wed, 2 Aug 2023 15:09:48 +0200
Subject: [PATCH] updated HIV-TB dataset and related tests

---
 src/ptbench/data/hivtb/__init__.py   |  82 -------------
 src/ptbench/data/hivtb/datamodule.py | 132 ++++++++++++++++++++
 src/ptbench/data/hivtb/fold_0.py     |  44 ++-----
 src/ptbench/data/hivtb/fold_1.py     |  44 ++-----
 src/ptbench/data/hivtb/fold_2.py     |  44 ++-----
 src/ptbench/data/hivtb/fold_3.py     |  44 ++-----
 src/ptbench/data/hivtb/fold_4.py     |  44 ++-----
 src/ptbench/data/hivtb/fold_5.py     |  44 ++-----
 src/ptbench/data/hivtb/fold_6.py     |  44 ++-----
 src/ptbench/data/hivtb/fold_7.py     |  44 ++-----
 src/ptbench/data/hivtb/fold_8.py     |  44 ++-----
 src/ptbench/data/hivtb/fold_9.py     |  44 ++-----
 tests/test_hivtb.py                  | 172 +++++++++++++++------------
 13 files changed, 328 insertions(+), 498 deletions(-)
 create mode 100644 src/ptbench/data/hivtb/datamodule.py

diff --git a/src/ptbench/data/hivtb/__init__.py b/src/ptbench/data/hivtb/__init__.py
index 0fca31ae..e69de29b 100644
--- a/src/ptbench/data/hivtb/__init__.py
+++ b/src/ptbench/data/hivtb/__init__.py
@@ -1,82 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-"""HIV-TB dataset for computer-aided diagnosis (only BMP files)
-
-* Reference: [HIV-TB-2019]_
-* Original resolution (height x width or width x height): 2048 x 2500
-* Split reference: none
-* Stratified kfold protocol:
-
-  * Training samples: 72% of TB and healthy CXR (including labels)
-  * Validation samples: 18% of TB and healthy CXR (including labels)
-  * Test samples: 10% of TB and healthy CXR (including labels)
-"""
-
-import importlib.resources
-import os
-
-from ...utils.rc import load_rc
-from .. import make_dataset
-from ..dataset import JSONDataset
-from ..loader import load_pil_grayscale, make_delayed
-
-_protocols = [
-    importlib.resources.files(__name__).joinpath("fold_0.json.bz2"),
-    importlib.resources.files(__name__).joinpath("fold_1.json.bz2"),
-    importlib.resources.files(__name__).joinpath("fold_2.json.bz2"),
-    importlib.resources.files(__name__).joinpath("fold_3.json.bz2"),
-    importlib.resources.files(__name__).joinpath("fold_4.json.bz2"),
-    importlib.resources.files(__name__).joinpath("fold_5.json.bz2"),
-    importlib.resources.files(__name__).joinpath("fold_6.json.bz2"),
-    importlib.resources.files(__name__).joinpath("fold_7.json.bz2"),
-    importlib.resources.files(__name__).joinpath("fold_8.json.bz2"),
-    importlib.resources.files(__name__).joinpath("fold_9.json.bz2"),
-]
-
-_datadir = load_rc().get("datadir.hivtb", os.path.realpath(os.curdir))
-
-
-def _raw_data_loader(sample):
-    return dict(
-        data=load_pil_grayscale(os.path.join(_datadir, sample["data"])),
-        label=sample["label"],
-    )
-
-
-def _loader(context, sample):
-    # "context" is ignored in this case - database is homogeneous
-    # we returned delayed samples to avoid loading all images at once
-    return make_delayed(sample, _raw_data_loader)
-
-
-json_dataset = JSONDataset(
-    protocols=_protocols,
-    fieldnames=("data", "label"),
-    loader=_loader,
-)
-"""HIV-TB dataset object."""
-
-
-def _maker(protocol, resize_size=512, cc_size=512, RGB=False):
-    from torchvision import transforms
-
-    from ..augmentations import ElasticDeformation, RemoveBlackBorders
-
-    post_transforms = []
-    if RGB:
-        post_transforms = [
-            transforms.Lambda(lambda x: x.convert("RGB")),
-            transforms.ToTensor(),
-        ]
-
-    return make_dataset(
-        [json_dataset.subsets(protocol)],
-        [
-            RemoveBlackBorders(),
-            transforms.Resize(resize_size),
-            transforms.CenterCrop(cc_size),
-        ],
-        [ElasticDeformation(p=0.8)],
-        post_transforms,
-    )
diff --git a/src/ptbench/data/hivtb/datamodule.py b/src/ptbench/data/hivtb/datamodule.py
new file mode 100644
index 00000000..63075c61
--- /dev/null
+++ b/src/ptbench/data/hivtb/datamodule.py
@@ -0,0 +1,132 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import importlib.resources
+import os
+
+import PIL.Image
+
+from torchvision.transforms.functional import center_crop, to_tensor
+
+from ...utils.rc import load_rc
+from ..datamodule import CachingDataModule
+from ..image_utils import load_pil_grayscale, remove_black_borders
+from ..split import JSONDatabaseSplit
+from ..typing import DatabaseSplit
+from ..typing import RawDataLoader as _BaseRawDataLoader
+from ..typing import Sample
+
+
+class RawDataLoader(_BaseRawDataLoader):
+    """A specialized raw-data-loader for the HIV-TB dataset.
+
+    Attributes
+    ----------
+
+    datadir
+        This variable contains the base directory where the database raw data
+        is stored.
+    """
+
+    datadir: str
+
+    def __init__(self):
+        self.datadir = load_rc().get(
+            "datadir.hivtb", os.path.realpath(os.curdir)
+        )
+
+    def sample(self, sample: tuple[str, int]) -> Sample:
+        """Loads a single image sample from the disk.
+
+        Parameters
+        ----------
+
+        sample:
+            A tuple containing the path suffix, within the dataset root folder,
+            where to find the image to be loaded, and an integer, representing the
+            sample label.
+
+
+        Returns
+        -------
+
+        sample
+            The sample representation
+        """
+        image = load_pil_grayscale(os.path.join(self.datadir, sample[0]))
+        image = remove_black_borders(image)
+        tensor = to_tensor(image)
+        tensor = center_crop(tensor, min(*tensor.shape[1:]))
+
+        # use the code below to view generated images
+        # from torchvision.transforms.functional import to_pil_image
+        # to_pil_image(tensor).show()
+        # __import__("pdb").set_trace()
+
+        return tensor, dict(label=sample[1], name=sample[0])  # type: ignore[arg-type]
+
+    def label(self, sample: tuple[str, int]) -> int:
+        """Returns the label of a single sample, without loading its image.
+
+        Parameters
+        ----------
+
+        sample:
+            A tuple containing the path suffix, within the dataset root folder,
+            where to find the image to be loaded, and an integer, representing the
+            sample label.
+
+
+        Returns
+        -------
+
+        label
+            The integer label associated with the sample
+        """
+        return sample[1]
+
+
+def make_split(basename: str) -> DatabaseSplit:
+    """Returns a database split for the HIV-TB database."""
+
+    return JSONDatabaseSplit(
+        importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(basename)
+    )
+
+
+class DataModule(CachingDataModule):
+    """HIV-TB dataset for computer-aided diagnosis (only BMP files)
+
+    * Database reference: [HIV-TB-2019]_
+    * Original resolution (height x width or width x height): 2048 x 2500 pixels
+      or 2500 x 2048 pixels
+
+    Data specifications:
+
+    * Raw data input (on disk):
+
+        * BMP images 8 bit grayscale
+        * resolution fixed to one of the cases above
+
+    * Output image:
+
+        * Transforms:
+
+            * Load raw BMP with :py:mod:`PIL`
+            * Remove black borders
+            * Convert to torch tensor
+            * Torch center cropping to get square image
+
+    * Final specifications
+
+        * Grayscale, encoded as a single plane tensor, 32-bit floats,
+          square at 2048 x 2048 pixels
+        * Labels: 0 (healthy), 1 (active tuberculosis)
+    """
+
+    def __init__(self, split_filename: str):
+        super().__init__(
+            database_split=make_split(split_filename),
+            raw_data_loader=RawDataLoader(),
+        )
diff --git a/src/ptbench/data/hivtb/fold_0.py b/src/ptbench/data/hivtb/fold_0.py
index e8caee65..ba9e9150 100644
--- a/src/ptbench/data/hivtb/fold_0.py
+++ b/src/ptbench/data/hivtb/fold_0.py
@@ -1,45 +1,21 @@
 # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
+
 """HIV-TB dataset for TB detection (cross validation fold 0)
 
 * Split reference: none (stratified kfolding)
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.hivtb` for dataset details
-"""
-
-from clapper.logging import setup
 
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
+* Stratified kfold protocol:
+    * Training samples: 72% of TB and healthy CXR (including labels)
+    * Validation samples: 18% of TB and healthy CXR (including labels)
+    * Test samples: 10% of TB and healthy CXR (including labels)
 
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
+* This configuration resolution: 2048 x 2048 (default)
+* See :py:class:`ptbench.data.hivtb.datamodule.DataModule` for dataset details
+"""
 
-    def setup(self, stage: str):
-        self.dataset = _maker("fold-0")
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
+from .datamodule import DataModule
 
+datamodule = DataModule("fold-0.json")
 
-datamodule = DefaultModule
diff --git a/src/ptbench/data/hivtb/fold_1.py b/src/ptbench/data/hivtb/fold_1.py
index bb12b311..84fb7581 100644
--- a/src/ptbench/data/hivtb/fold_1.py
+++ b/src/ptbench/data/hivtb/fold_1.py
@@ -1,45 +1,21 @@
 # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
+
 """HIV-TB dataset for TB detection (cross validation fold 1)
 
 * Split reference: none (stratified kfolding)
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.hivtb` for dataset details
-"""
-
-from clapper.logging import setup
 
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
+* Stratified kfold protocol:
+    * Training samples: 72% of TB and healthy CXR (including labels)
+    * Validation samples: 18% of TB and healthy CXR (including labels)
+    * Test samples: 10% of TB and healthy CXR (including labels)
 
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
+* This configuration resolution: 2048 x 2048 (default)
+* See :py:class:`ptbench.data.hivtb.datamodule.DataModule` for dataset details
+"""
 
-    def setup(self, stage: str):
-        self.dataset = _maker("fold-1")
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
+from .datamodule import DataModule
 
+datamodule = DataModule("fold-1.json")
 
-datamodule = DefaultModule
diff --git a/src/ptbench/data/hivtb/fold_2.py b/src/ptbench/data/hivtb/fold_2.py
index 7bd3703e..a5f5e97a 100644
--- a/src/ptbench/data/hivtb/fold_2.py
+++ b/src/ptbench/data/hivtb/fold_2.py
@@ -1,45 +1,21 @@
 # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
+
 """HIV-TB dataset for TB detection (cross validation fold 2)
 
 * Split reference: none (stratified kfolding)
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.hivtb` for dataset details
-"""
-
-from clapper.logging import setup
 
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
+* Stratified kfold protocol:
+    * Training samples: 72% of TB and healthy CXR (including labels)
+    * Validation samples: 18% of TB and healthy CXR (including labels)
+    * Test samples: 10% of TB and healthy CXR (including labels)
 
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
+* This configuration resolution: 2048 x 2048 (default)
+* See :py:class:`ptbench.data.hivtb.datamodule.DataModule` for dataset details
+"""
 
-    def setup(self, stage: str):
-        self.dataset = _maker("fold-2")
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
+from .datamodule import DataModule
 
+datamodule = DataModule("fold-2.json")
 
-datamodule = DefaultModule
diff --git a/src/ptbench/data/hivtb/fold_3.py b/src/ptbench/data/hivtb/fold_3.py
index cac94f67..1b643ae4 100644
--- a/src/ptbench/data/hivtb/fold_3.py
+++ b/src/ptbench/data/hivtb/fold_3.py
@@ -1,45 +1,21 @@
 # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
+
 """HIV-TB dataset for TB detection (cross validation fold 3)
 
 * Split reference: none (stratified kfolding)
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.hivtb` for dataset details
-"""
-
-from clapper.logging import setup
 
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
+* Stratified kfold protocol:
+    * Training samples: 72% of TB and healthy CXR (including labels)
+    * Validation samples: 18% of TB and healthy CXR (including labels)
+    * Test samples: 10% of TB and healthy CXR (including labels)
 
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
+* This configuration resolution: 2048 x 2048 (default)
+* See :py:class:`ptbench.data.hivtb.datamodule.DataModule` for dataset details
+"""
 
-    def setup(self, stage: str):
-        self.dataset = _maker("fold-3")
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
+from .datamodule import DataModule
 
+datamodule = DataModule("fold-3.json")
 
-datamodule = DefaultModule
diff --git a/src/ptbench/data/hivtb/fold_4.py b/src/ptbench/data/hivtb/fold_4.py
index c5952356..581eb85c 100644
--- a/src/ptbench/data/hivtb/fold_4.py
+++ b/src/ptbench/data/hivtb/fold_4.py
@@ -1,45 +1,21 @@
 # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
+
 """HIV-TB dataset for TB detection (cross validation fold 4)
 
 * Split reference: none (stratified kfolding)
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.hivtb` for dataset details
-"""
-
-from clapper.logging import setup
 
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
+* Stratified kfold protocol:
+    * Training samples: 72% of TB and healthy CXR (including labels)
+    * Validation samples: 18% of TB and healthy CXR (including labels)
+    * Test samples: 10% of TB and healthy CXR (including labels)
 
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
+* This configuration resolution: 2048 x 2048 (default)
+* See :py:class:`ptbench.data.hivtb.datamodule.DataModule` for dataset details
+"""
 
-    def setup(self, stage: str):
-        self.dataset = _maker("fold-4")
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
+from .datamodule import DataModule
 
+datamodule = DataModule("fold-4.json")
 
-datamodule = DefaultModule
diff --git a/src/ptbench/data/hivtb/fold_5.py b/src/ptbench/data/hivtb/fold_5.py
index bc80b9ff..47ae66d1 100644
--- a/src/ptbench/data/hivtb/fold_5.py
+++ b/src/ptbench/data/hivtb/fold_5.py
@@ -1,45 +1,21 @@
 # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
+
 """HIV-TB dataset for TB detection (cross validation fold 5)
 
 * Split reference: none (stratified kfolding)
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.hivtb` for dataset details
-"""
-
-from clapper.logging import setup
 
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
+* Stratified kfold protocol:
+    * Training samples: 72% of TB and healthy CXR (including labels)
+    * Validation samples: 18% of TB and healthy CXR (including labels)
+    * Test samples: 10% of TB and healthy CXR (including labels)
 
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
+* This configuration resolution: 2048 x 2048 (default)
+* See :py:class:`ptbench.data.hivtb.datamodule.DataModule` for dataset details
+"""
 
-    def setup(self, stage: str):
-        self.dataset = _maker("fold-5")
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
+from .datamodule import DataModule
 
+datamodule = DataModule("fold-5.json")
 
-datamodule = DefaultModule
diff --git a/src/ptbench/data/hivtb/fold_6.py b/src/ptbench/data/hivtb/fold_6.py
index d1a646dc..c93232f4 100644
--- a/src/ptbench/data/hivtb/fold_6.py
+++ b/src/ptbench/data/hivtb/fold_6.py
@@ -1,45 +1,21 @@
 # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
+
 """HIV-TB dataset for TB detection (cross validation fold 6)
 
 * Split reference: none (stratified kfolding)
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.hivtb` for dataset details
-"""
-
-from clapper.logging import setup
 
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
+* Stratified kfold protocol:
+    * Training samples: 72% of TB and healthy CXR (including labels)
+    * Validation samples: 18% of TB and healthy CXR (including labels)
+    * Test samples: 10% of TB and healthy CXR (including labels)
 
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
+* This configuration resolution: 2048 x 2048 (default)
+* See :py:class:`ptbench.data.hivtb.datamodule.DataModule` for dataset details
+"""
 
-    def setup(self, stage: str):
-        self.dataset = _maker("fold-6")
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
+from .datamodule import DataModule
 
+datamodule = DataModule("fold-6.json")
 
-datamodule = DefaultModule
diff --git a/src/ptbench/data/hivtb/fold_7.py b/src/ptbench/data/hivtb/fold_7.py
index de29f234..33d5cc83 100644
--- a/src/ptbench/data/hivtb/fold_7.py
+++ b/src/ptbench/data/hivtb/fold_7.py
@@ -1,45 +1,21 @@
 # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
+
 """HIV-TB dataset for TB detection (cross validation fold 7)
 
 * Split reference: none (stratified kfolding)
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.hivtb` for dataset details
-"""
-
-from clapper.logging import setup
 
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
+* Stratified kfold protocol:
+    * Training samples: 72% of TB and healthy CXR (including labels)
+    * Validation samples: 18% of TB and healthy CXR (including labels)
+    * Test samples: 10% of TB and healthy CXR (including labels)
 
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
+* This configuration resolution: 2048 x 2048 (default)
+* See :py:class:`ptbench.data.hivtb.datamodule.DataModule` for dataset details
+"""
 
-    def setup(self, stage: str):
-        self.dataset = _maker("fold-7")
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
+from .datamodule import DataModule
 
+datamodule = DataModule("fold-7.json")
 
-datamodule = DefaultModule
diff --git a/src/ptbench/data/hivtb/fold_8.py b/src/ptbench/data/hivtb/fold_8.py
index 9370dcea..91d89557 100644
--- a/src/ptbench/data/hivtb/fold_8.py
+++ b/src/ptbench/data/hivtb/fold_8.py
@@ -1,45 +1,21 @@
 # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
+
 """HIV-TB dataset for TB detection (cross validation fold 8)
 
 * Split reference: none (stratified kfolding)
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.hivtb` for dataset details
-"""
-
-from clapper.logging import setup
 
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
+* Stratified kfold protocol:
+    * Training samples: 72% of TB and healthy CXR (including labels)
+    * Validation samples: 18% of TB and healthy CXR (including labels)
+    * Test samples: 10% of TB and healthy CXR (including labels)
 
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
+* This configuration resolution: 2048 x 2048 (default)
+* See :py:class:`ptbench.data.hivtb.datamodule.DataModule` for dataset details
+"""
 
-    def setup(self, stage: str):
-        self.dataset = _maker("fold-8")
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
+from .datamodule import DataModule
 
+datamodule = DataModule("fold-8.json")
 
-datamodule = DefaultModule
diff --git a/src/ptbench/data/hivtb/fold_9.py b/src/ptbench/data/hivtb/fold_9.py
index 70605f8d..0e0063e8 100644
--- a/src/ptbench/data/hivtb/fold_9.py
+++ b/src/ptbench/data/hivtb/fold_9.py
@@ -1,45 +1,21 @@
 # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
+
 """HIV-TB dataset for TB detection (cross validation fold 9)
 
 * Split reference: none (stratified kfolding)
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.hivtb` for dataset details
-"""
-
-from clapper.logging import setup
 
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
+* Stratified kfold protocol:
+    * Training samples: 72% of TB and healthy CXR (including labels)
+    * Validation samples: 18% of TB and healthy CXR (including labels)
+    * Test samples: 10% of TB and healthy CXR (including labels)
 
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
+* This configuration resolution: 2048 x 2048 (default)
+* See :py:class:`ptbench.data.hivtb.datamodule.DataModule` for dataset details
+"""
 
-    def setup(self, stage: str):
-        self.dataset = _maker("fold-9")
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
+from .datamodule import DataModule
 
+datamodule = DataModule("fold-9.json")
 
-datamodule = DefaultModule
diff --git a/tests/test_hivtb.py b/tests/test_hivtb.py
index 37876051..9e814138 100644
--- a/tests/test_hivtb.py
+++ b/tests/test_hivtb.py
@@ -4,106 +4,126 @@
 """Tests for HIV-TB dataset."""
 
 import pytest
+import torch
 
-dataset = None
+from ptbench.data.hivtb.datamodule import make_split
 
 
-@pytest.mark.skip(reason="Test need to be updated")
-def test_protocol_consistency():
-    # Cross-validation fold 0-2
-    for f in range(3):
-        subset = dataset.subsets("fold_" + str(f))
-        assert len(subset) == 3
+def _check_split(
+    split_filename: str,
+    lengths: dict[str, int],
+    prefix: str = "HIV-TB_Algorithm_study_X-rays/",
+    extension: str = ".BMP",
+    possible_labels: list[int] = [0, 1],
+):
+    """Runs a simple consistency check on the data split.
 
-        assert "train" in subset
-        assert len(subset["train"]) == 174
-        for s in subset["train"]:
-            assert s.key.startswith("HIV-TB_Algorithm_study_X-rays/")
+    Parameters
+    ----------
 
-        assert "validation" in subset
-        assert len(subset["validation"]) == 44
-        for s in subset["validation"]:
-            assert s.key.startswith("HIV-TB_Algorithm_study_X-rays/")
+    split_filename
+        This is the split we will check
 
-        assert "test" in subset
-        assert len(subset["test"]) == 25
-        for s in subset["test"]:
-            assert s.key.startswith("HIV-TB_Algorithm_study_X-rays/")
+    lengths
+        A dictionary that contains keys matching those of the split (this will
+        be checked).  The values of the dictionary should correspond to the
+        sizes of each of the datasets in the split.
 
-        # Check labels
-        for s in subset["train"]:
-            assert s.label in [0.0, 1.0]
+    prefix
+        Each file named in a split should start with this prefix.
 
-        for s in subset["validation"]:
-            assert s.label in [0.0, 1.0]
+    extension
+        Each file named in a split should end with this extension.
 
-        for s in subset["test"]:
-            assert s.label in [0.0, 1.0]
+    possible_labels
+        This is the list of possible labels contained in any split.
+    """
 
-    # Cross-validation fold 3-9
-    for f in range(3, 10):
-        subset = dataset.subsets("fold_" + str(f))
-        assert len(subset) == 3
+    split = make_split(split_filename)
 
-        assert "train" in subset
-        assert len(subset["train"]) == 175
-        for s in subset["train"]:
-            assert s.key.startswith("HIV-TB_Algorithm_study_X-rays/")
+    assert len(split) == len(lengths)
 
-        assert "validation" in subset
-        assert len(subset["validation"]) == 44
-        for s in subset["validation"]:
-            assert s.key.startswith("HIV-TB_Algorithm_study_X-rays/")
+    for k in lengths.keys():
+        # dataset must have been declared
+        assert k in split
 
-        assert "test" in subset
-        assert len(subset["test"]) == 24
-        for s in subset["test"]:
-            assert s.key.startswith("HIV-TB_Algorithm_study_X-rays/")
+        assert len(split[k]) == lengths[k]
+        for s in split[k]:
+            assert s[0].startswith(prefix)
+            assert s[0].endswith(extension)
+            assert s[1] in possible_labels
 
-        # Check labels
-        for s in subset["train"]:
-            assert s.label in [0.0, 1.0]
 
-        for s in subset["validation"]:
-            assert s.label in [0.0, 1.0]
+def _check_loaded_batch(
+    batch,
+    size: int = 1,
+    prefix: str = "HIV-TB_Algorithm_study_X-rays/",
+    extension: str = ".BMP",
+    possible_labels: list[int] = [0, 1],
+):
+    """Checks the consistency of an individual (loaded) batch.
 
-        for s in subset["test"]:
-            assert s.label in [0.0, 1.0]
+    Parameters
+    ----------
 
+    batch
+        The loaded batch to be checked.
 
-@pytest.mark.skip(reason="Test need to be updated")
-@pytest.mark.skip_if_rc_var_not_set("datadir.hivtb")
-def test_loading():
-    image_size_portrait = (2048, 2500)
-    image_size_landscape = (2500, 2048)
+    prefix
+        Each file named in a split should start with this prefix.
+
+    extension
+        Each file named in a split should end with this extension.
 
-    def _check_size(size):
-        if size == image_size_portrait:
-            return True
-        elif size == image_size_landscape:
-            return True
-        return False
+    possible_labels
+        This is the list of possible labels contained in any split.
+    """
 
-    def _check_sample(s):
-        data = s.data
-        assert isinstance(data, dict)
-        assert len(data) == 2
+    assert len(batch) == 2  # data, metadata
 
-        assert "data" in data
-        assert _check_size(data["data"].size)  # Check size
-        assert data["data"].mode == "L"  # Check colors
+    assert isinstance(batch[0], torch.Tensor)
+    assert batch[0].shape[0] == size  # mini-batch size
+    assert batch[0].shape[1] == 1  # grayscale images
+    assert batch[0].shape[2] == batch[0].shape[3]  # image is square
 
-        assert "label" in data
-        assert data["label"] in [0, 1]  # Check labels
+    assert isinstance(batch[1], dict)  # metadata
+    assert len(batch[1]) == 2  # label and name
 
-    limit = 30  # use this to limit testing to first images only, else None
+    assert "label" in batch[1]
+    assert all([k in possible_labels for k in batch[1]["label"]])
 
-    subset = dataset.subsets("fold_0")
-    for s in subset["train"][:limit]:
-        _check_sample(s)
+    assert "name" in batch[1]
+    assert all([k.startswith(prefix) for k in batch[1]["name"]])
+    assert all([k.endswith(extension) for k in batch[1]["name"]])
+
+
+def test_protocol_consistency():
+    # Cross-validation fold 0-2
+    for k in range(3):
+        _check_split(
+            f"fold-{k}.json",
+            lengths=dict(train=174, validation=44, test=25),
+        )
+
+    # Cross-validation fold 3-9
+    for k in range(3, 10):
+        _check_split(
+            f"fold-{k}.json",
+            lengths=dict(train=175, validation=44, test=24),
+        )
 
 
-@pytest.mark.skip(reason="Test need to be updated")
 @pytest.mark.skip_if_rc_var_not_set("datadir.hivtb")
-def test_check():
-    assert dataset.check() == 0
+def test_loading():
+    from ptbench.data.hivtb.fold_0 import datamodule
+
+    datamodule.model_transforms = []  # should be done before setup()
+    datamodule.setup("predict")  # sets up all datasets
+
+    for loader in datamodule.predict_dataloader().values():
+        limit = 5  # limit load checking
+        for batch in loader:
+            if limit == 0:
+                break
+            _check_loaded_batch(batch)
+            limit -= 1
\ No newline at end of file
-- 
GitLab