diff --git a/src/ptbench/data/tbpoc/__init__.py b/src/ptbench/data/tbpoc/__init__.py
index 00f5f42c7f81b93c5ab4c037d289275e3477b00b..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 100644
--- a/src/ptbench/data/tbpoc/__init__.py
+++ b/src/ptbench/data/tbpoc/__init__.py
@@ -1,83 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-"""TB-POC dataset for computer-aided diagnosis.
-
-* Reference: [TB-POC-2018]_
-* Original resolution (height x width or width x height): 2048 x 2500
-* Split reference: none
-* Stratified kfold protocol:
-
-  * Training samples: 72% of TB and healthy CXR (including labels)
-  * Validation samples: 18% of TB and healthy CXR (including labels)
-  * Test samples: 10% of TB and healthy CXR (including labels)
-"""
-
-import importlib.resources
-import os
-
-from ...utils.rc import load_rc
-from .. import make_dataset
-from ..dataset import JSONDataset
-from ..loader import load_pil_grayscale, make_delayed
-
-_protocols = [
-    importlib.resources.files(__name__).joinpath("fold_0.json.bz2"),
-    importlib.resources.files(__name__).joinpath("fold_1.json.bz2"),
-    importlib.resources.files(__name__).joinpath("fold_2.json.bz2"),
-    importlib.resources.files(__name__).joinpath("fold_3.json.bz2"),
-    importlib.resources.files(__name__).joinpath("fold_4.json.bz2"),
-    importlib.resources.files(__name__).joinpath("fold_5.json.bz2"),
-    importlib.resources.files(__name__).joinpath("fold_6.json.bz2"),
-    importlib.resources.files(__name__).joinpath("fold_7.json.bz2"),
-    importlib.resources.files(__name__).joinpath("fold_8.json.bz2"),
-    importlib.resources.files(__name__).joinpath("fold_9.json.bz2"),
-]
-
-_datadir = load_rc().get("datadir.tbpoc", os.path.realpath(os.curdir))
-
-
-def _raw_data_loader(sample):
-    return dict(
-        data=load_pil_grayscale(os.path.join(_datadir, sample["data"])),
-        label=sample["label"],
-    )
-
-
-def _loader(context, sample):
-    # "context" is ignored in this case - database is homogeneous
-    # we returned delayed samples to avoid loading all images at once
-    return make_delayed(sample, _raw_data_loader)
-
-
-json_dataset = JSONDataset(
-    protocols=_protocols,
-    fieldnames=("data", "label"),
-    loader=_loader,
-)
-"""TB-POC dataset object."""
-
-
-def _maker(protocol, resize_size=512, cc_size=512, RGB=False):
-    from torchvision import transforms
-
-    from ..augmentations import ElasticDeformation
-    from ..image_utils import RemoveBlackBorders
-
-    post_transforms = []
-    if RGB:
-        post_transforms = [
-            transforms.Lambda(lambda x: x.convert("RGB")),
-            transforms.ToTensor(),
-        ]
-
-    return make_dataset(
-        [json_dataset.subsets(protocol)],
-        [
-            RemoveBlackBorders(),
-            transforms.Resize(resize_size),
-            transforms.CenterCrop(cc_size),
-        ],
-        [ElasticDeformation(p=0.8)],
-        post_transforms,
-    )
diff --git a/src/ptbench/data/tbpoc/datamodule.py b/src/ptbench/data/tbpoc/datamodule.py
new file mode 100644
index 0000000000000000000000000000000000000000..35465bac49ae78db05b769a8d63aee40156f1482
--- /dev/null
+++ b/src/ptbench/data/tbpoc/datamodule.py
@@ -0,0 +1,138 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import importlib.resources
+import os
+
+import PIL.Image
+
+from torchvision.transforms.functional import center_crop, to_tensor
+
+from ...utils.rc import load_rc
+from ..datamodule import CachingDataModule
+from ..image_utils import load_pil_grayscale, remove_black_borders
+from ..split import JSONDatabaseSplit
+from ..typing import DatabaseSplit
+from ..typing import RawDataLoader as _BaseRawDataLoader
+from ..typing import Sample
+
+
+class RawDataLoader(_BaseRawDataLoader):
+    """A specialized raw-data-loader for the Shenzen dataset.
+
+    Attributes
+    ----------
+
+    datadir
+        This variable contains the base directory where the database raw data
+        is stored.
+
+    transform
+        Transforms that are always applied to the loaded raw images.
+    """
+
+    datadir: str
+
+    def __init__(self, config_variable: str = "datadir.tbpoc"):
+        self.datadir = load_rc().get(
+            config_variable, os.path.realpath(os.curdir)
+        )
+
+    def sample(self, sample: tuple[str, int]) -> Sample:
+        """Loads a single image sample from the disk.
+
+        Parameters
+        ----------
+
+        sample:
+            A tuple containing the path suffix, within the dataset root folder,
+            where to find the image to be loaded, and an integer, representing the
+            sample label.
+
+
+        Returns
+        -------
+
+        sample
+            The sample representation
+        """
+        image = load_pil_grayscale(os.path.join(self.datadir, sample[0]))
+        image = remove_black_borders(image)
+        tensor = to_tensor(image)
+        tensor = center_crop(tensor, min(*tensor.shape[1:]))
+
+        # use the code below to view generated images
+        # from torchvision.transforms.functional import to_pil_image
+        # to_pil_image(tensor).show()
+        # __import__("pdb").set_trace()
+
+        return tensor, dict(label=sample[1], name=sample[0])  # type: ignore[arg-type]
+
+    def label(self, sample: tuple[str, int]) -> int:
+        """Loads a single image sample label from the disk.
+
+        Parameters
+        ----------
+
+        sample:
+            A tuple containing the path suffix, within the dataset root folder,
+            where to find the image to be loaded, and an integer, representing the
+            sample label.
+
+
+        Returns
+        -------
+
+        label
+            The integer label associated with the sample
+        """
+        return sample[1]
+
+
+def make_split(basename: str) -> DatabaseSplit:
+    """Returns a database split for the TB-POC database."""
+
+    return JSONDatabaseSplit(
+        importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(basename)
+    )
+
+
+class DataModule(CachingDataModule):
+    """TB-POC dataset for computer-aided diagnosis.
+
+    * Database reference: [TB-POC-2018]_
+    * Original resolution (height x width or width x height): 2048 x 2500 pixels 
+    or 2500 x 2048 pixels
+
+    Data specifications:
+
+    * Raw data input (on disk):
+
+        * jpeg 8-bit grayscale images
+        * resolution: fixed to one of the cases above
+
+    * Output image:
+
+        * Transforms:
+
+            * Load raw jpeg with :py:mod:`PIL`
+            * Remove black borders
+            * Convert to torch tensor
+            * Torch center cropping to get square image
+
+        * Final specifications:
+
+            * Grayscale, encoded as a single plane tensor, 32-bit floats,
+              square with varying resolutions, depending on black borders' sizes
+              on the input image
+            * Labels: 0 (healthy), 1 (active tuberculosis)
+    """
+
+    def __init__(self, split_filename: str):
+        super().__init__(
+            database_split=make_split(split_filename),
+            raw_data_loader=RawDataLoader(),
+        )
+
+
diff --git a/src/ptbench/data/tbpoc/fold_0.py b/src/ptbench/data/tbpoc/fold_0.py
index 7a423deba65dfd1bc9c1fcfc5750e0cec1bd0563..972e7188f13a0b7e67b3581eb87c0d20acd38794 100644
--- a/src/ptbench/data/tbpoc/fold_0.py
+++ b/src/ptbench/data/tbpoc/fold_0.py
@@ -1,45 +1,21 @@
 # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
+
 """TB-POC dataset for TB detection (cross validation fold 0)
 
 * Split reference: none (stratified kfolding)
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.tbpoc` for dataset details
-"""
-
-from clapper.logging import setup
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
 
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+* Stratified kfold protocol:
+    * Training samples: 72% of TB and healthy CXR (including labels)
+    * Validation samples: 18% of TB and healthy CXR (including labels)
+    * Test samples: 10% of TB and healthy CXR (including labels)
 
+* This configuration resolution: varying depending of black borders on original
+  image
+* See :py:mod:`ptbench.data.tbpoc` for dataset details
+"""
 
-class Fold0Module(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold-0")
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
+from .datamodule import DataModule
 
-datamodule = Fold0Module
+datamodule = DataModule("fold-0.json")
diff --git a/src/ptbench/data/tbpoc/fold_1.py b/src/ptbench/data/tbpoc/fold_1.py
index cb4c59baba491d1e8edf666c71e373b7dbf6a5ec..79b9bfcaec144157770c3be12705f73fcb0f5c79 100644
--- a/src/ptbench/data/tbpoc/fold_1.py
+++ b/src/ptbench/data/tbpoc/fold_1.py
@@ -1,45 +1,21 @@
 # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
+
 """TB-POC dataset for TB detection (cross validation fold 1)
 
 * Split reference: none (stratified kfolding)
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.tbpoc` for dataset details
-"""
-
-from clapper.logging import setup
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
 
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+* Stratified kfold protocol:
+    * Training samples: 72% of TB and healthy CXR (including labels)
+    * Validation samples: 18% of TB and healthy CXR (including labels)
+    * Test samples: 10% of TB and healthy CXR (including labels)
 
+* This configuration resolution: varying depending of black borders on original
+  image
+* See :py:mod:`ptbench.data.tbpoc` for dataset details
+"""
 
-class Fold0Module(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold-1")
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
+from .datamodule import DataModule
 
-datamodule = Fold0Module
+datamodule = DataModule("fold-1.json")
diff --git a/src/ptbench/data/tbpoc/fold_2.py b/src/ptbench/data/tbpoc/fold_2.py
index 1bffecea5721496cc24731147d5b0d62db6c93d2..9d41fb595637dce8f31944b4aa2eeee2bd60e58d 100644
--- a/src/ptbench/data/tbpoc/fold_2.py
+++ b/src/ptbench/data/tbpoc/fold_2.py
@@ -1,45 +1,21 @@
 # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
+
 """TB-POC dataset for TB detection (cross validation fold 2)
 
 * Split reference: none (stratified kfolding)
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.tbpoc` for dataset details
-"""
-
-from clapper.logging import setup
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
 
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+* Stratified kfold protocol:
+    * Training samples: 72% of TB and healthy CXR (including labels)
+    * Validation samples: 18% of TB and healthy CXR (including labels)
+    * Test samples: 10% of TB and healthy CXR (including labels)
 
+* This configuration resolution: varying depending of black borders on original
+  image
+* See :py:mod:`ptbench.data.tbpoc` for dataset details
+"""
 
-class Fold0Module(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold-2")
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
+from .datamodule import DataModule
 
-datamodule = Fold0Module
+datamodule = DataModule("fold-2.json")
diff --git a/src/ptbench/data/tbpoc/fold_3.py b/src/ptbench/data/tbpoc/fold_3.py
index 1263d39b123c002d065c32a5de17d753a3d14121..08672b3f325a8e60e19b8254ea89bc59fc3ad78b 100644
--- a/src/ptbench/data/tbpoc/fold_3.py
+++ b/src/ptbench/data/tbpoc/fold_3.py
@@ -1,45 +1,21 @@
 # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
+
 """TB-POC dataset for TB detection (cross validation fold 3)
 
 * Split reference: none (stratified kfolding)
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.tbpoc` for dataset details
-"""
-
-from clapper.logging import setup
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
 
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+* Stratified kfold protocol:
+    * Training samples: 72% of TB and healthy CXR (including labels)
+    * Validation samples: 18% of TB and healthy CXR (including labels)
+    * Test samples: 10% of TB and healthy CXR (including labels)
 
+* This configuration resolution: varying depending of black borders on original
+  image
+* See :py:mod:`ptbench.data.tbpoc` for dataset details
+"""
 
-class Fold0Module(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold-3")
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
+from .datamodule import DataModule
 
-datamodule = Fold0Module
+datamodule = DataModule("fold-3.json")
diff --git a/src/ptbench/data/tbpoc/fold_4.py b/src/ptbench/data/tbpoc/fold_4.py
index 119adfa98e94ed165232b73f49e7998359376cf0..8354a4c2d7038c35620a7220afa7e9a8731d44fd 100644
--- a/src/ptbench/data/tbpoc/fold_4.py
+++ b/src/ptbench/data/tbpoc/fold_4.py
@@ -1,45 +1,21 @@
 # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
+
 """TB-POC dataset for TB detection (cross validation fold 4)
 
 * Split reference: none (stratified kfolding)
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.tbpoc` for dataset details
-"""
-
-from clapper.logging import setup
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
 
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+* Stratified kfold protocol:
+    * Training samples: 72% of TB and healthy CXR (including labels)
+    * Validation samples: 18% of TB and healthy CXR (including labels)
+    * Test samples: 10% of TB and healthy CXR (including labels)
 
+* This configuration resolution: varying depending of black borders on original
+  image
+* See :py:mod:`ptbench.data.tbpoc` for dataset details
+"""
 
-class Fold0Module(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold-4")
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
+from .datamodule import DataModule
 
-datamodule = Fold0Module
+datamodule = DataModule("fold-4.json")
diff --git a/src/ptbench/data/tbpoc/fold_5.py b/src/ptbench/data/tbpoc/fold_5.py
index 2a90cbdb4a706435f8f8589d863132a07191ab5c..cb7f95612e23dca6af8d3d06dfaf6ae76319ed6f 100644
--- a/src/ptbench/data/tbpoc/fold_5.py
+++ b/src/ptbench/data/tbpoc/fold_5.py
@@ -1,45 +1,21 @@
 # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
+
 """TB-POC dataset for TB detection (cross validation fold 5)
 
 * Split reference: none (stratified kfolding)
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.tbpoc` for dataset details
-"""
-
-from clapper.logging import setup
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
 
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+* Stratified kfold protocol:
+    * Training samples: 72% of TB and healthy CXR (including labels)
+    * Validation samples: 18% of TB and healthy CXR (including labels)
+    * Test samples: 10% of TB and healthy CXR (including labels)
 
+* This configuration resolution: varying depending of black borders on original
+  image
+* See :py:mod:`ptbench.data.tbpoc` for dataset details
+"""
 
-class Fold0Module(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold-5")
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
+from .datamodule import DataModule
 
-datamodule = Fold0Module
+datamodule = DataModule("fold-5.json")
diff --git a/src/ptbench/data/tbpoc/fold_6.py b/src/ptbench/data/tbpoc/fold_6.py
index 42ed763dde698b3ce98f350068142539d07532b0..379211aad631cf9beac280d598f52beb6746eac0 100644
--- a/src/ptbench/data/tbpoc/fold_6.py
+++ b/src/ptbench/data/tbpoc/fold_6.py
@@ -1,45 +1,21 @@
 # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
+
 """TB-POC dataset for TB detection (cross validation fold 6)
 
 * Split reference: none (stratified kfolding)
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.tbpoc` for dataset details
-"""
-
-from clapper.logging import setup
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
 
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+* Stratified kfold protocol:
+    * Training samples: 72% of TB and healthy CXR (including labels)
+    * Validation samples: 18% of TB and healthy CXR (including labels)
+    * Test samples: 10% of TB and healthy CXR (including labels)
 
+* This configuration resolution: varying depending of black borders on original
+  image
+* See :py:mod:`ptbench.data.tbpoc` for dataset details
+"""
 
-class Fold0Module(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold-6")
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
+from .datamodule import DataModule
 
-datamodule = Fold0Module
+datamodule = DataModule("fold-6.json")
diff --git a/src/ptbench/data/tbpoc/fold_7.py b/src/ptbench/data/tbpoc/fold_7.py
index ad7dbe14e752f8f3015c18af775064f051515a34..b846b88af5cf7375f578ee2ffbc24055a4a3ff85 100644
--- a/src/ptbench/data/tbpoc/fold_7.py
+++ b/src/ptbench/data/tbpoc/fold_7.py
@@ -1,45 +1,21 @@
 # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
+
 """TB-POC dataset for TB detection (cross validation fold 7)
 
 * Split reference: none (stratified kfolding)
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.tbpoc` for dataset details
-"""
-
-from clapper.logging import setup
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
 
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+* Stratified kfold protocol:
+    * Training samples: 72% of TB and healthy CXR (including labels)
+    * Validation samples: 18% of TB and healthy CXR (including labels)
+    * Test samples: 10% of TB and healthy CXR (including labels)
 
+* This configuration resolution: varying depending of black borders on original
+  image
+* See :py:mod:`ptbench.data.tbpoc` for dataset details
+"""
 
-class Fold0Module(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold-7")
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
+from .datamodule import DataModule
 
-datamodule = Fold0Module
+datamodule = DataModule("fold-7.json")
diff --git a/src/ptbench/data/tbpoc/fold_8.py b/src/ptbench/data/tbpoc/fold_8.py
index 4bcea788633542d7bc9f52ef35966542891e5f91..acfd42964fe21cf15c1d47a5bc5df794fbcba961 100644
--- a/src/ptbench/data/tbpoc/fold_8.py
+++ b/src/ptbench/data/tbpoc/fold_8.py
@@ -1,45 +1,21 @@
 # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
+
 """TB-POC dataset for TB detection (cross validation fold 8)
 
 * Split reference: none (stratified kfolding)
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.tbpoc` for dataset details
-"""
-
-from clapper.logging import setup
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
 
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+* Stratified kfold protocol:
+    * Training samples: 72% of TB and healthy CXR (including labels)
+    * Validation samples: 18% of TB and healthy CXR (including labels)
+    * Test samples: 10% of TB and healthy CXR (including labels)
 
+* This configuration resolution: varying depending of black borders on original
+  image
+* See :py:mod:`ptbench.data.tbpoc` for dataset details
+"""
 
-class Fold0Module(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold-8")
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
+from .datamodule import DataModule
 
-datamodule = Fold0Module
+datamodule = DataModule("fold-8.json")
diff --git a/src/ptbench/data/tbpoc/fold_9.py b/src/ptbench/data/tbpoc/fold_9.py
index c33eb6bd3d2dbe8635e2b3cc2a2f7e81e859e180..4634068e5942bf9d7062876ee2007702083de1ed 100644
--- a/src/ptbench/data/tbpoc/fold_9.py
+++ b/src/ptbench/data/tbpoc/fold_9.py
@@ -1,45 +1,21 @@
 # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
+
 """TB-POC dataset for TB detection (cross validation fold 9)
 
 * Split reference: none (stratified kfolding)
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.tbpoc` for dataset details
-"""
-
-from clapper.logging import setup
-
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
 
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+* Stratified kfold protocol:
+    * Training samples: 72% of TB and healthy CXR (including labels)
+    * Validation samples: 18% of TB and healthy CXR (including labels)
+    * Test samples: 10% of TB and healthy CXR (including labels)
 
+* This configuration resolution: varying depending of black borders on original
+  image
+* See :py:mod:`ptbench.data.tbpoc` for dataset details
+"""
 
-class Fold0Module(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        multiproc_kwargs=None,
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-    def setup(self, stage: str):
-        self.dataset = _maker("fold-9")
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.dataset)
-
+from .datamodule import DataModule
 
-datamodule = Fold0Module
+datamodule = DataModule("fold-9.json")
diff --git a/tests/test_tbpoc.py b/tests/test_tbpoc.py
index 9609ea66be7f59a655f744cba8b7f94e8cafb382..ee34d8d09d14379dfc08dcc4adebc33ffedd18b9 100644
--- a/tests/test_tbpoc.py
+++ b/tests/test_tbpoc.py
@@ -4,106 +4,126 @@
 """Tests for TB-POC dataset."""
 
 import pytest
+import torch
 
-dataset = None
+from ptbench.data.tbpoc.datamodule import make_split
 
 
-@pytest.mark.skip(reason="Test need to be updated")
-def test_protocol_consistency():
-    # Cross-validation fold 0-6
-    for f in range(7):
-        subset = dataset.subsets("fold_" + str(f))
-        assert len(subset) == 3
+def _check_split(
+    split_filename: str,
+    lengths: dict[str, int],
+    prefix: str = "TBPOC_CXR/",
+    extension: str = ".jpeg",
+    possible_labels: list[int] = [0, 1],
+):
+    """Runs a simple consistence check on the data split.
 
-        assert "train" in subset
-        assert len(subset["train"]) == 292
-        for s in subset["train"]:
-            assert s.key.upper().startswith("TBPOC_CXR/TBPOC-")
+    Parameters
+    ----------
 
-        assert "validation" in subset
-        assert len(subset["validation"]) == 74
-        for s in subset["validation"]:
-            assert s.key.upper().startswith("TBPOC_CXR/TBPOC-")
+    split_filename
+        This is the split we will check
 
-        assert "test" in subset
-        assert len(subset["test"]) == 41
-        for s in subset["test"]:
-            assert s.key.upper().startswith("TBPOC_CXR/TBPOC-")
+    lenghts
+        A dictionary that contains keys matching those of the split (this will
+        be checked).  The values of the dictionary should correspond to the
+        sizes of each of the datasets in the split.
 
-        # Check labels
-        for s in subset["train"]:
-            assert s.label in [0.0, 1.0]
+    prefix
+        Each file named in a split should start with this prefix.
 
-        for s in subset["validation"]:
-            assert s.label in [0.0, 1.0]
+    extension
+        Each file named in a split should end with this extension.
 
-        for s in subset["test"]:
-            assert s.label in [0.0, 1.0]
+    possible_labels
+        These are the list of possible labels contained in any split.
+    """
 
-    # Cross-validation fold 7-9
-    for f in range(7, 10):
-        subset = dataset.subsets("fold_" + str(f))
-        assert len(subset) == 3
+    split = make_split(split_filename)
 
-        assert "train" in subset
-        assert len(subset["train"]) == 293
-        for s in subset["train"]:
-            assert s.key.upper().startswith("TBPOC_CXR/TBPOC-")
+    assert len(split) == len(lengths)
 
-        assert "validation" in subset
-        assert len(subset["validation"]) == 74
-        for s in subset["validation"]:
-            assert s.key.upper().startswith("TBPOC_CXR/TBPOC-")
+    for k in lengths.keys():
+        # dataset must have been declared
+        assert k in split
 
-        assert "test" in subset
-        assert len(subset["test"]) == 40
-        for s in subset["test"]:
-            assert s.key.upper().startswith("TBPOC_CXR/TBPOC-")
+        assert len(split[k]) == lengths[k]
+        for s in split[k]:
+            # assert s[0].startswith(prefix)
+            assert s[0].endswith(extension)
+            assert s[1] in possible_labels
 
-        # Check labels
-        for s in subset["train"]:
-            assert s.label in [0.0, 1.0]
 
-        for s in subset["validation"]:
-            assert s.label in [0.0, 1.0]
+def _check_loaded_batch(
+    batch,
+    size: int = 1,
+    prefix: str = "TBPOC_CXR/",
+    extension: str = ".jpeg",
+    possible_labels: list[int] = [0, 1],
+):
+    """Checks the consistence of an individual (loaded) batch.
 
-        for s in subset["test"]:
-            assert s.label in [0.0, 1.0]
+    Parameters
+    ----------
 
+    batch
+        The loaded batch to be checked.
 
-@pytest.mark.skip(reason="Test need to be updated")
-@pytest.mark.skip_if_rc_var_not_set("datadir.tbpoc")
-def test_loading():
-    image_size_portrait = (2048, 2500)
-    image_size_landscape = (2500, 2048)
+    prefix
+        Each file named in a split should start with this prefix.
+
+    extension
+        Each file named in a split should end with this extension.
+
+    possible_labels
+        These are the list of possible labels contained in any split.
+    """
 
-    def _check_size(size):
-        if size == image_size_portrait:
-            return True
-        elif size == image_size_landscape:
-            return True
-        return False
+    assert len(batch) == 2  # data, metadata
 
-    def _check_sample(s):
-        data = s.data
-        assert isinstance(data, dict)
-        assert len(data) == 2
+    assert isinstance(batch[0], torch.Tensor)
+    assert batch[0].shape[0] == size  # mini-batch size
+    assert batch[0].shape[1] == 1  # grayscale images
+    assert batch[0].shape[2] == batch[0].shape[3]  # image is square
 
-        assert "data" in data
-        assert _check_size(data["data"].size)  # Check size
-        assert data["data"].mode, "L"  # Check colors
+    assert isinstance(batch[1], dict)  # metadata
+    assert len(batch[1]) == 2  # label and name
 
-        assert "label" in data
-        assert data["label"] in [0, 1]  # Check labels
+    assert "label" in batch[1]
+    assert all([k in possible_labels for k in batch[1]["label"]])
 
-    limit = 30  # use this to limit testing to first images only, else None
+    assert "name" in batch[1]
+    # assert all([k.startswith(prefix) for k in batch[1]["name"]])
+    assert all([k.endswith(extension) for k in batch[1]["name"]])
 
-    subset = dataset.subsets("fold_0")
-    for s in subset["train"][:limit]:
-        _check_sample(s)
+
+def test_protocol_consistency():
+    # Cross-validation fold 0-6
+    for k in range(7):
+        _check_split(
+            f"fold-{k}.json",
+            lengths=dict(train=292, validation=74, test=41),
+        )
+
+    # Cross-validation fold 7-9
+    for k in range(7, 10):
+        _check_split(
+            f"fold-{k}.json",
+            lengths=dict(train=293, validation=74, test=40),
+        )
 
 
-@pytest.mark.skip(reason="Test need to be updated")
-@pytest.mark.skip_if_rc_var_not_set("datadir.tbpoc")
-def test_check():
-    assert dataset.check() == 0
+@pytest.mark.skip_if_rc_var_not_set("datadir.hivtb")
+def test_loading():
+    from ptbench.data.tbpoc.fold_0 import datamodule
+
+    datamodule.model_transforms = []  # should be done before setup()
+    datamodule.setup("predict")  # sets up all datasets
+
+    for loader in datamodule.predict_dataloader().values():
+        limit = 5  # limit load checking
+        for batch in loader:
+            if limit == 0:
+                break
+            _check_loaded_batch(batch)
+            limit -= 1