diff --git a/tests/conftest.py b/tests/conftest.py
index 75b802a252994df5e2b5793bcd2c3422f5b9ad15..6e992cdd04092d04ff40a0efdb6293c8cff7dd7c 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -162,7 +162,9 @@ class DatabaseCheckers:
 
             assert len(split[k]) == lengths[k]
             for s in split[k]:
-                assert any([s[0].startswith(k) for k in prefixes])
+                assert any(
+                    [s[0].startswith(k) for k in prefixes]
+                ), f"Sample with name {s[0]} does not start with any of the prefixes in {prefixes}"
                 assert s[1] in possible_labels
 
     @staticmethod
diff --git a/tests/test_hivtb.py b/tests/test_hivtb.py
index 920f1574ccabe924c9609d1fd1533a0524fb83e4..eb6513f0046d70a7846e310a7e47a837c7222d8e 100644
--- a/tests/test_hivtb.py
+++ b/tests/test_hivtb.py
@@ -3,127 +3,89 @@
 # SPDX-License-Identifier: GPL-3.0-or-later
 """Tests for HIV-TB dataset."""
 
-import pytest
-import torch
-
-from ptbench.data.hivtb.datamodule import make_split
-
-
-def _check_split(
-    split_filename: str,
-    lengths: dict[str, int],
-    prefix: str = "HIV-TB_Algorithm_study_X-rays/",
-    extension: str = ".BMP",
-    possible_labels: list[int] = [0, 1],
-):
-    """Runs a simple consistence check on the data split.
-
-    Parameters
-    ----------
-
-    split_filename
-        This is the split we will check
-
-    lenghts
-        A dictionary that contains keys matching those of the split (this will
-        be checked).  The values of the dictionary should correspond to the
-        sizes of each of the datasets in the split.
+import importlib
 
-    prefix
-        Each file named in a split should start with this prefix.
-
-    extension
-        Each file named in a split should end with this extension.
-
-    possible_labels
-        These are the list of possible labels contained in any split.
-    """
-
-    split = make_split(split_filename)
-
-    assert len(split) == len(lengths)
-
-    for k in lengths.keys():
-        # dataset must have been declared
-        assert k in split
-
-        assert len(split[k]) == lengths[k]
-        for s in split[k]:
-            assert s[0].startswith(prefix)
-            assert s[0].endswith(extension)
-            assert s[1] in possible_labels
+import pytest
 
 
-def _check_loaded_batch(
-    batch,
-    size: int = 1,
-    prefix: str = "HIV-TB_Algorithm_study_X-rays/",
-    extension: str = ".BMP",
-    possible_labels: list[int] = [0, 1],
+def id_function(val):
+    if isinstance(val, dict):
+        return str(val)
+    return repr(val)
+
+
+@pytest.mark.parametrize(
+    "split,lenghts",
+    [
+        ("fold-0", dict(train=174, validation=44, test=25)),
+        ("fold-1", dict(train=174, validation=44, test=25)),
+        ("fold-2", dict(train=174, validation=44, test=25)),
+        ("fold-3", dict(train=175, validation=44, test=24)),
+        ("fold-4", dict(train=175, validation=44, test=24)),
+        ("fold-5", dict(train=175, validation=44, test=24)),
+        ("fold-6", dict(train=175, validation=44, test=24)),
+        ("fold-7", dict(train=175, validation=44, test=24)),
+        ("fold-8", dict(train=175, validation=44, test=24)),
+        ("fold-9", dict(train=175, validation=44, test=24)),
+    ],
+    ids=id_function,  # just changes how pytest prints it
+)
+def test_protocol_consistency(
+    database_checkers, split: str, lenghts: dict[str, int]
 ):
-    """Checks the consistence of an individual (loaded) batch.
+    from ptbench.data.hivtb.datamodule import make_split
 
-    Parameters
-    ----------
-
-    batch
-        The loaded batch to be checked.
-
-    prefix
-        Each file named in a split should start with this prefix.
-
-    extension
-        Each file named in a split should end with this extension.
-
-    possible_labels
-        These are the list of possible labels contained in any split.
-    """
-
-    assert len(batch) == 2  # data, metadata
-
-    assert isinstance(batch[0], torch.Tensor)
-    assert batch[0].shape[0] == size  # mini-batch size
-    assert batch[0].shape[1] == 1  # grayscale images
-    assert batch[0].shape[2] == batch[0].shape[3]  # image is square
-
-    assert isinstance(batch[1], dict)  # metadata
-    assert len(batch[1]) == 2  # label and name
-
-    assert "label" in batch[1]
-    assert all([k in possible_labels for k in batch[1]["label"]])
-
-    assert "name" in batch[1]
-    assert all([k.startswith(prefix) for k in batch[1]["name"]])
-    assert all([k.endswith(extension) for k in batch[1]["name"]])
-
-
-def test_protocol_consistency():
-    # Cross-validation fold 0-2
-    for k in range(3):
-        _check_split(
-            f"fold-{k}.json",
-            lengths=dict(train=174, validation=44, test=25),
-        )
-
-    # Cross-validation fold 3-9
-    for k in range(3, 10):
-        _check_split(
-            f"fold-{k}.json",
-            lengths=dict(train=175, validation=44, test=24),
-        )
+    database_checkers.check_split(
+        make_split(f"{split}.json"),
+        lengths=lenghts,
+        prefixes=("HIV-TB_Algorithm_study_X-rays",),
+        possible_labels=(0, 1),
+    )
 
 
 @pytest.mark.skip_if_rc_var_not_set("datadir.hivtb")
-def test_loading():
-    from ptbench.data.hivtb.fold_0 import datamodule
+@pytest.mark.parametrize(
+    "dataset",
+    [
+        "train",
+        "validation",
+        "test",
+    ],
+)
+@pytest.mark.parametrize(
+    "name",
+    [
+        "fold_0",
+        "fold_1",
+        "fold_2",
+        "fold_3",
+        "fold_4",
+        "fold_5",
+        "fold_6",
+        "fold_7",
+        "fold_8",
+        "fold_9",
+    ],
+)
+def test_loading(database_checkers, name: str, dataset: str):
+    datamodule = importlib.import_module(
+        f".{name}", "ptbench.data.hivtb"
+    ).datamodule
 
     datamodule.model_transforms = []  # should be done before setup()
     datamodule.setup("predict")  # sets up all datasets
 
-    for loader in datamodule.predict_dataloader().values():
-        limit = 5  # limit load checking
-        for batch in loader:
-            if limit == 0:
-                break
-            _check_loaded_batch(batch)
-            limit -= 1
+    loader = datamodule.predict_dataloader()[dataset]
+
+    limit = 3  # limit load checking
+    for batch in loader:
+        if limit == 0:
+            break
+        database_checkers.check_loaded_batch(
+            batch,
+            batch_size=1,
+            color_planes=1,
+            prefixes=("HIV-TB_Algorithm_study_X-rays",),
+            possible_labels=(0, 1),
+        )
+        limit -= 1
diff --git a/tests/test_tbpoc.py b/tests/test_tbpoc.py
index ee34d8d09d14379dfc08dcc4adebc33ffedd18b9..7125efb195a058f36c785268c110c765f0890d65 100644
--- a/tests/test_tbpoc.py
+++ b/tests/test_tbpoc.py
@@ -3,127 +3,95 @@
 # SPDX-License-Identifier: GPL-3.0-or-later
 """Tests for TB-POC dataset."""
 
-import pytest
-import torch
-
-from ptbench.data.tbpoc.datamodule import make_split
-
-
-def _check_split(
-    split_filename: str,
-    lengths: dict[str, int],
-    prefix: str = "TBPOC_CXR/",
-    extension: str = ".jpeg",
-    possible_labels: list[int] = [0, 1],
-):
-    """Runs a simple consistence check on the data split.
-
-    Parameters
-    ----------
-
-    split_filename
-        This is the split we will check
-
-    lenghts
-        A dictionary that contains keys matching those of the split (this will
-        be checked).  The values of the dictionary should correspond to the
-        sizes of each of the datasets in the split.
-
-    prefix
-        Each file named in a split should start with this prefix.
+import importlib
 
-    extension
-        Each file named in a split should end with this extension.
-
-    possible_labels
-        These are the list of possible labels contained in any split.
-    """
-
-    split = make_split(split_filename)
-
-    assert len(split) == len(lengths)
-
-    for k in lengths.keys():
-        # dataset must have been declared
-        assert k in split
-
-        assert len(split[k]) == lengths[k]
-        for s in split[k]:
-            # assert s[0].startswith(prefix)
-            assert s[0].endswith(extension)
-            assert s[1] in possible_labels
+import pytest
 
 
-def _check_loaded_batch(
-    batch,
-    size: int = 1,
-    prefix: str = "TBPOC_CXR/",
-    extension: str = ".jpeg",
-    possible_labels: list[int] = [0, 1],
+def id_function(val):
+    if isinstance(val, dict):
+        return str(val)
+    return repr(val)
+
+
+@pytest.mark.parametrize(
+    "split,lenghts",
+    [
+        ("fold-0", dict(train=292, validation=74, test=41)),
+        ("fold-1", dict(train=292, validation=74, test=41)),
+        ("fold-2", dict(train=292, validation=74, test=41)),
+        ("fold-3", dict(train=292, validation=74, test=41)),
+        ("fold-4", dict(train=292, validation=74, test=41)),
+        ("fold-5", dict(train=292, validation=74, test=41)),
+        ("fold-6", dict(train=292, validation=74, test=41)),
+        ("fold-7", dict(train=293, validation=74, test=40)),
+        ("fold-8", dict(train=293, validation=74, test=40)),
+        ("fold-9", dict(train=293, validation=74, test=40)),
+    ],
+    ids=id_function,  # just changes how pytest prints it
+)
+def test_protocol_consistency(
+    database_checkers, split: str, lenghts: dict[str, int]
 ):
-    """Checks the consistence of an individual (loaded) batch.
-
-    Parameters
-    ----------
-
-    batch
-        The loaded batch to be checked.
-
-    prefix
-        Each file named in a split should start with this prefix.
-
-    extension
-        Each file named in a split should end with this extension.
-
-    possible_labels
-        These are the list of possible labels contained in any split.
-    """
-
-    assert len(batch) == 2  # data, metadata
-
-    assert isinstance(batch[0], torch.Tensor)
-    assert batch[0].shape[0] == size  # mini-batch size
-    assert batch[0].shape[1] == 1  # grayscale images
-    assert batch[0].shape[2] == batch[0].shape[3]  # image is square
-
-    assert isinstance(batch[1], dict)  # metadata
-    assert len(batch[1]) == 2  # label and name
-
-    assert "label" in batch[1]
-    assert all([k in possible_labels for k in batch[1]["label"]])
-
-    assert "name" in batch[1]
-    # assert all([k.startswith(prefix) for k in batch[1]["name"]])
-    assert all([k.endswith(extension) for k in batch[1]["name"]])
-
-
-def test_protocol_consistency():
-    # Cross-validation fold 0-6
-    for k in range(7):
-        _check_split(
-            f"fold-{k}.json",
-            lengths=dict(train=292, validation=74, test=41),
-        )
-
-    # Cross-validation fold 7-9
-    for k in range(7, 10):
-        _check_split(
-            f"fold-{k}.json",
-            lengths=dict(train=293, validation=74, test=40),
-        )
-
-
-@pytest.mark.skip_if_rc_var_not_set("datadir.hivtb")
-def test_loading():
-    from ptbench.data.tbpoc.fold_0 import datamodule
+    from ptbench.data.tbpoc.datamodule import make_split
+
+    database_checkers.check_split(
+        make_split(f"{split}.json"),
+        lengths=lenghts,
+        prefixes=(
+            "TBPOC_CXR/TBPOC-",
+            "TBPOC_CXR/tbpoc-",
+        ),
+        possible_labels=(0, 1),
+    )
+
+
+@pytest.mark.skip_if_rc_var_not_set("datadir.tbpoc")
+@pytest.mark.parametrize(
+    "dataset",
+    [
+        "train",
+        "validation",
+        "test",
+    ],
+)
+@pytest.mark.parametrize(
+    "name",
+    [
+        "fold_0",
+        "fold_1",
+        "fold_2",
+        "fold_3",
+        "fold_4",
+        "fold_5",
+        "fold_6",
+        "fold_7",
+        "fold_8",
+        "fold_9",
+    ],
+)
+def test_loading(database_checkers, name: str, dataset: str):
+    datamodule = importlib.import_module(
+        f".{name}", "ptbench.data.tbpoc"
+    ).datamodule
 
     datamodule.model_transforms = []  # should be done before setup()
     datamodule.setup("predict")  # sets up all datasets
 
-    for loader in datamodule.predict_dataloader().values():
-        limit = 5  # limit load checking
-        for batch in loader:
-            if limit == 0:
-                break
-            _check_loaded_batch(batch)
-            limit -= 1
+    loader = datamodule.predict_dataloader()[dataset]
+
+    limit = 3  # limit load checking
+    for batch in loader:
+        if limit == 0:
+            break
+        database_checkers.check_loaded_batch(
+            batch,
+            batch_size=1,
+            color_planes=1,
+            prefixes=(
+                "TBPOC_CXR/TBPOC-",
+                "TBPOC_CXR/tbpoc-",
+            ),
+            possible_labels=(0, 1),
+        )
+        limit -= 1