diff --git a/src/ptbench/data/indian/datamodule.py b/src/ptbench/data/indian/datamodule.py
index f68d184e4d10b0830bb3ef50d83527b4f2d88ac2..f6017cad4e7deac52c74c8478c3697f64dffa9c7 100644
--- a/src/ptbench/data/indian/datamodule.py
+++ b/src/ptbench/data/indian/datamodule.py
@@ -55,5 +55,5 @@ class DataModule(CachingDataModule):
     def __init__(self, split_filename: str):
         super().__init__(
             database_split=make_split(split_filename),
-            raw_data_loader=RawDataLoader(),
+            raw_data_loader=RawDataLoader(config_variable="datadir.indian"),
         )
diff --git a/src/ptbench/data/shenzhen/datamodule.py b/src/ptbench/data/shenzhen/datamodule.py
index 7cf1833b27ca1e79382cb7e45fa41b5b4ee27292..0596007eaae5050b5691f1ebe1563f78e1507910 100644
--- a/src/ptbench/data/shenzhen/datamodule.py
+++ b/src/ptbench/data/shenzhen/datamodule.py
@@ -34,9 +34,9 @@ class RawDataLoader(_BaseRawDataLoader):
 
     datadir: str
 
-    def __init__(self):
+    def __init__(self, config_variable: str = "datadir.shenzhen"):
         self.datadir = load_rc().get(
-            "datadir.shenzhen", os.path.realpath(os.curdir)
+            config_variable, os.path.realpath(os.curdir)
         )
 
     def sample(self, sample: tuple[str, int]) -> Sample:
diff --git a/tests/test_indian.py b/tests/test_indian.py
index 87660c1af4073fde09cc13d5a6f616d9180a3e4e..91adf0d42083cfa45869e41ccbe9fc1923edd039 100644
--- a/tests/test_indian.py
+++ b/tests/test_indian.py
@@ -1,142 +1,127 @@
 # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
-"""Tests for Indian dataset."""
-
-import pytest
+"""Tests for Indian (a.k.a.
 
+database A/database B) dataset.
+"""
 
-@pytest.mark.skip(reason="Test need to be updated")
-def test_protocol_consistency():
-    from ptbench.data.indian import dataset
+import pytest
+import torch
 
-    # Default protocol
-    subset = dataset.subsets("default")
-    assert len(subset) == 3
+from ptbench.data.indian.datamodule import make_split
 
-    assert "train" in subset
-    assert len(subset["train"]) == 83
-    for s in subset["train"]:
-        assert s.key.startswith("DatasetA/Training/")
 
-    assert "validation" in subset
-    assert len(subset["validation"]) == 20
-    for s in subset["validation"]:
-        assert s.key.startswith("DatasetA/Training/")
+def _check_split(
+    split_filename: str,
+    lengths: dict[str, int],
+    prefix: str = "Dataset",
+    possible_labels: list[int] = [0, 1],
+):
+    """Runs a simple consistence check on the data split.
 
-    assert "test" in subset
-    assert len(subset["test"]) == 52
-    for s in subset["test"]:
-        assert s.key.startswith("DatasetA/Testing/")
+    Parameters
+    ----------
 
-    # Check labels
-    for s in subset["train"]:
-        assert s.label in [0.0, 1.0]
+    split_filename
+        This is the split we will check
 
-    for s in subset["validation"]:
-        assert s.label in [0.0, 1.0]
+    lenghts
+        A dictionary that contains keys matching those of the split (this will
+        be checked).  The values of the dictionary should correspond to the
+        sizes of each of the datasets in the split.
 
-    for s in subset["test"]:
-        assert s.label in [0.0, 1.0]
+    prefix
+        Each file named in a split should start with this prefix.
 
-    # Cross-validation fold 0-4
-    for f in range(5):
-        subset = dataset.subsets("fold_" + str(f))
-        assert len(subset) == 3
+    possible_labels
+        These are the list of possible labels contained in any split.
+    """
 
-        assert "train" in subset
-        assert len(subset["train"]) == 111
-        for s in subset["train"]:
-            assert s.key.startswith("DatasetA")
+    split = make_split(split_filename)
 
-        assert "validation" in subset
-        assert len(subset["validation"]) == 28
-        for s in subset["validation"]:
-            assert s.key.startswith("DatasetA")
+    assert len(split) == len(lengths)
 
-        assert "test" in subset
-        assert len(subset["test"]) == 16
-        for s in subset["test"]:
-            assert s.key.startswith("DatasetA")
+    for k in lengths.keys():
+        # dataset must have been declared
+        assert k in split
 
-        # Check labels
-        for s in subset["train"]:
-            assert s.label in [0.0, 1.0]
+        assert len(split[k]) == lengths[k]
+        for s in split[k]:
+            assert s[0].startswith(prefix)
+            assert s[1] in possible_labels
 
-        for s in subset["validation"]:
-            assert s.label in [0.0, 1.0]
 
-        for s in subset["test"]:
-            assert s.label in [0.0, 1.0]
+def _check_loaded_batch(
+    batch,
+    size: int = 1,
+    prefix: str = "Dataset",
+    possible_labels: list[int] = [0, 1],
+):
+    """Checks the consistence of an individual (loaded) batch.
 
-    # Cross-validation fold 5-9
-    for f in range(5, 10):
-        subset = dataset.subsets("fold_" + str(f))
-        assert len(subset) == 3
+    Parameters
+    ----------
 
-        assert "train" in subset
-        assert len(subset["train"]) == 112
-        for s in subset["train"]:
-            assert s.key.startswith("DatasetA")
+    batch
+        The loaded batch to be checked.
 
-        assert "validation" in subset
-        assert len(subset["validation"]) == 28
-        for s in subset["validation"]:
-            assert s.key.startswith("DatasetA")
+    prefix
+        Each file named in a split should start with this prefix.
 
-        assert "test" in subset
-        assert len(subset["test"]) == 15
-        for s in subset["test"]:
-            assert s.key.startswith("DatasetA")
+    possible_labels
+        These are the list of possible labels contained in any split.
+    """
 
-        # Check labels
-        for s in subset["train"]:
-            assert s.label in [0.0, 1.0]
+    assert len(batch) == 2  # data, metadata
 
-        for s in subset["validation"]:
-            assert s.label in [0.0, 1.0]
+    assert isinstance(batch[0], torch.Tensor)
+    assert batch[0].shape[0] == size  # mini-batch size
+    assert batch[0].shape[1] == 1  # grayscale images
+    assert batch[0].shape[2] == batch[0].shape[3]  # image is square
 
-        for s in subset["test"]:
-            assert s.label in [0.0, 1.0]
+    assert isinstance(batch[1], dict)  # metadata
+    assert len(batch[1]) == 2  # label and name
 
+    assert "label" in batch[1]
+    assert all([k in possible_labels for k in batch[1]["label"]])
 
-@pytest.mark.skip(reason="Test need to be updated")
-@pytest.mark.skip_if_rc_var_not_set("datadir.indian")
-def test_loading():
-    from ptbench.data.indian import dataset
+    assert "name" in batch[1]
+    assert all([k.startswith(prefix) for k in batch[1]["name"]])
 
-    def _check_size(size):
-        if (
-            size[0] >= 1024
-            and size[0] <= 2320
-            and size[1] >= 1024
-            and size[1] <= 2828
-        ):
-            return True
-        return False
 
-    def _check_sample(s):
-        data = s.data
-        assert isinstance(data, dict)
-        assert len(data) == 2
-
-        assert "data" in data
-        assert _check_size(data["data"].size)  # Check size
-        assert data["data"].mode == "L"  # Check colors
-
-        assert "label" in data
-        assert data["label"] in [0, 1]  # Check labels
+def test_protocol_consistency():
+    _check_split(
+        "default.json",
+        lengths=dict(train=83, validation=20, test=52),
+    )
 
-    limit = 30  # use this to limit testing to first images only, else None
+    # Cross-validation fold 0-4
+    for k in range(5):
+        _check_split(
+            f"fold-{k}.json",
+            lengths=dict(train=111, validation=28, test=16),
+        )
 
-    subset = dataset.subsets("default")
-    for s in subset["train"][:limit]:
-        _check_sample(s)
+    # Cross-validation fold 5-9
+    for k in range(5, 10):
+        _check_split(
+            f"fold-{k}.json",
+            lengths=dict(train=112, validation=28, test=15),
+        )
 
 
-@pytest.mark.skip(reason="Test need to be updated")
 @pytest.mark.skip_if_rc_var_not_set("datadir.indian")
-def test_check():
-    from ptbench.data.indian import dataset
-
-    assert dataset.check() == 0
+def test_loading():
+    from ptbench.data.indian.default import datamodule
+
+    datamodule.model_transforms = []  # should be done before setup()
+    datamodule.setup("predict")  # sets up all datasets
+
+    for loader in datamodule.predict_dataloader().values():
+        limit = 5  # limit load checking
+        for batch in loader:
+            if limit == 0:
+                break
+            _check_loaded_batch(batch)
+            limit -= 1
diff --git a/tests/test_shenzhen.py b/tests/test_shenzhen.py
index 6e577081a83a075c92a4808bbfb1c876d247fdde..30c69543a77d24715f2a815192d2e0abac41b9c5 100644
--- a/tests/test_shenzhen.py
+++ b/tests/test_shenzhen.py
@@ -3,202 +3,122 @@
 # SPDX-License-Identifier: GPL-3.0-or-later
 """Tests for Shenzhen dataset."""
 
-import importlib
-
 import pytest
+import torch
 
+from ptbench.data.shenzhen.datamodule import make_split
 
-def test_protocol_consistency():
-    # Default protocol
-
-    datamodule = getattr(
-        importlib.import_module("ptbench.data.shenzhen.datamodules"), "default"
-    )
-
-    subset = datamodule.splits
-
-    assert len(subset) == 3
-
-    assert "train" in subset
-    train_samples = subset["train"][0][0]
-    assert len(train_samples) == 422
-    for s in train_samples:
-        assert s[0].startswith("CXR_png/CHNCXR_0")
-
-    assert "validation" in subset
-    validation_samples = subset["validation"][0][0]
-    assert len(validation_samples) == 107
-    for s in validation_samples:
-        assert s[0].startswith("CXR_png/CHNCXR_0")
-
-    assert "test" in subset
-    test_samples = subset["test"][0][0]
-    assert len(test_samples) == 133
-    for s in test_samples:
-        assert s[0].startswith("CXR_png/CHNCXR_0")
 
-    # Check labels
-    for s in train_samples:
-        assert s[1] in [0.0, 1.0]
+def _check_split(
+    split_filename: str,
+    lengths: dict[str, int],
+    prefix: str = "CXR_png/CHNCXR_0",
+    possible_labels: list[int] = [0, 1],
+):
+    """Runs a simple consistence check on the data split.
 
-    for s in validation_samples:
-        assert s[1] in [0.0, 1.0]
+    Parameters
+    ----------
 
-    for s in test_samples:
-        assert s[1] in [0.0, 1.0]
+    split_filename
+        This is the split we will check
 
-    # Cross-validation folds 0-1
-    for f in range(2):
-        datamodule = getattr(
-            importlib.import_module("ptbench.data.shenzhen.datamodules"),
-            f"fold_{str(f)}",
-        )
-
-        subset = datamodule.splits
-
-        assert len(subset) == 3
+    lenghts
+        A dictionary that contains keys matching those of the split (this will
+        be checked).  The values of the dictionary should correspond to the
+        sizes of each of the datasets in the split.
 
-        assert "train" in subset
-        train_samples = subset["train"][0][0]
-        assert len(train_samples) == 476
-        for s in train_samples:
-            assert s[0].startswith("CXR_png/CHNCXR_0")
+    prefix
+        Each file named in a split should start with this prefix.
 
-        assert "validation" in subset
-        validation_samples = subset["validation"][0][0]
-        assert len(validation_samples) == 119
-        for s in validation_samples:
-            assert s[0].startswith("CXR_png/CHNCXR_0")
+    possible_labels
+        These are the list of possible labels contained in any split.
+    """
 
-        assert "test" in subset
-        test_samples = subset["test"][0][0]
-        assert len(test_samples) == 67
-        for s in test_samples:
-            assert s[0].startswith("CXR_png/CHNCXR_0")
+    split = make_split(split_filename)
 
-        # Check labels
-        for s in train_samples:
-            assert s[1] in [0.0, 1.0]
+    assert len(split) == len(lengths)
 
-        for s in validation_samples:
-            assert s[1] in [0.0, 1.0]
+    for k in lengths.keys():
+        # dataset must have been declared
+        assert k in split
 
-        for s in test_samples:
-            assert s[1] in [0.0, 1.0]
-
-    # Cross-validation folds 2-9
-    for f in range(2, 10):
-        datamodule = getattr(
-            importlib.import_module("ptbench.data.shenzhen.datamodules"),
-            f"fold_{str(f)}",
-        )
+        assert len(split[k]) == lengths[k]
+        for s in split[k]:
+            assert s[0].startswith(prefix)
+            assert s[1] in possible_labels
 
-        subset = datamodule.splits
 
-        assert len(subset) == 3
+def _check_loaded_batch(
+    batch,
+    size: int = 1,
+    prefix: str = "CXR_png/CHNCXR_0",
+    possible_labels: list[int] = [0, 1],
+):
+    """Checks the consistence of an individual (loaded) batch.
 
-        assert "train" in subset
-        train_samples = subset["train"][0][0]
-        assert len(train_samples) == 476
-        for s in train_samples:
-            assert s[0].startswith("CXR_png/CHNCXR_0")
+    Parameters
+    ----------
 
-        assert "validation" in subset
-        validation_samples = subset["validation"][0][0]
-        assert len(validation_samples) == 120
-        for s in validation_samples:
-            assert s[0].startswith("CXR_png/CHNCXR_0")
+    batch
+        The loaded batch to be checked.
 
-        assert "test" in subset
-        test_samples = subset["test"][0][0]
-        assert len(test_samples) == 66
-        for s in test_samples:
-            assert s[0].startswith("CXR_png/CHNCXR_0")
+    prefix
+        Each file named in a split should start with this prefix.
 
-        # Check labels
-        for s in train_samples:
-            assert s[1] in [0.0, 1.0]
+    possible_labels
+        These are the list of possible labels contained in any split.
+    """
 
-        for s in validation_samples:
-            assert s[1] in [0.0, 1.0]
+    assert len(batch) == 2  # data, metadata
 
-        for s in test_samples:
-            assert s[1] in [0.0, 1.0]
-
-
-@pytest.mark.skip_if_rc_var_not_set("datadir.shenzhen")
-def test_loading():
-    import torch
-    import torchvision.transforms
+    assert isinstance(batch[0], torch.Tensor)
+    assert batch[0].shape[0] == size  # mini-batch size
+    assert batch[0].shape[1] == 1  # grayscale images
+    assert batch[0].shape[2] == batch[0].shape[3]  # image is square
 
-    from ptbench.data.datamodule import _DelayedLoadingDataset
+    assert isinstance(batch[1], dict)  # metadata
+    assert len(batch[1]) == 2  # label and name
 
-    def _check_sample(s):
-        assert len(s) == 2
+    assert "label" in batch[1]
+    assert all([k in possible_labels for k in batch[1]["label"]])
 
-        data = s[0]
-        metadata = s[1]
+    assert "name" in batch[1]
+    assert all([k.startswith(prefix) for k in batch[1]["name"]])
 
-        assert isinstance(data, torch.Tensor)
 
-        assert data.size(0) == 1  # check 1 channel
-        assert data.size(1) == data.size(2)  # check square image
-
-        assert (
-            torchvision.transforms.ToPILImage()(data).mode == "L"
-        )  # Check colors
-
-        assert "label" in metadata
-        assert metadata["label"] in [0, 1]  # Check labels
-
-    limit = 30  # use this to limit testing to first images only, else None
-
-    module = importlib.import_module("ptbench.data.shenzhen.datamodules")
-    datamodule = getattr(module, "default")
-    raw_data_loader = module.RawDataLoader()
-    subset = datamodule.splits
-
-    # Need to use private function so we can limit the number of samples to use
-    dataset = _DelayedLoadingDataset(
-        subset["train"][0][0][:limit],
-        raw_data_loader,
+def test_protocol_consistency():
+    _check_split(
+        "default.json",
+        lengths=dict(train=422, validation=107, test=133),
     )
 
-    for s in dataset:
-        _check_sample(s)
-
-
-@pytest.mark.skip(reason="Test need to be updated")
-@pytest.mark.skip_if_rc_var_not_set("datadir.shenzhen")
-def test_check():
-    from ptbench.data.split import check_database_split_loading
-
-    limit = 30  # use this to limit testing to first images only, else 0
-
-    # Default protocol
-    module = importlib.import_module("ptbench.data.shenzhen.datamodules")
-    datamodule = getattr(module, "default")
-    database_split = datamodule.splits
-    raw_data_loader = module.RawDataLoader()
-
-    assert (
-        check_database_split_loading(
-            database_split, raw_data_loader, limit=limit
+    # Cross-validation fold 0-1
+    for k in range(2):
+        _check_split(
+            f"fold-{k}.json",
+            lengths=dict(train=476, validation=119, test=67),
         )
-        == 0
-    )
 
-    # Folds
-    for f in range(10):
-        module = importlib.import_module("ptbench.data.shenzhen.datamodules")
-        datamodule = getattr(module, f"fold_{f}")
+    # Cross-validation fold 2-9
+    for k in range(2, 10):
+        _check_split(
+            f"fold-{k}.json",
+            lengths=dict(train=476, validation=120, test=66),
+        )
 
-        database_split = datamodule.splits
-        raw_data_loader = module.RawDataLoader()
 
-        assert (
-            check_database_split_loading(
-                database_split, raw_data_loader, limit=limit
-            )
-            == 0
-        )
+@pytest.mark.skip_if_rc_var_not_set("datadir.shenzhen")
+def test_loading():
+    from ptbench.data.shenzhen.default import datamodule
+
+    datamodule.model_transforms = []  # should be done before setup()
+    datamodule.setup("predict")  # sets up all datasets
+
+    for loader in datamodule.predict_dataloader().values():
+        limit = 5  # limit load checking
+        for batch in loader:
+            if limit == 0:
+                break
+            _check_loaded_batch(batch)
+            limit -= 1