diff --git a/tests/test_ch.py b/tests/test_ch.py
index 853a2f7184fc56ebd805ea278a4cc21f8ecbad21..510b1171741a9a4ea0176976a7aa45dddaa427fe 100644
--- a/tests/test_ch.py
+++ b/tests/test_ch.py
@@ -4,133 +4,194 @@
 
 """Tests for Shenzhen dataset."""
 
-import pytest
+import importlib
 
-from ptbench.data.shenzhen import dataset
+import pytest
 
 
 def test_protocol_consistency():
     # Default protocol
-    subset = dataset.subsets("default")
+
+    datamodule = importlib.import_module(
+        "ptbench.data.shenzhen.default"
+    ).datamodule
+    subset = datamodule.database_split.subsets
+
     assert len(subset) == 3
 
     assert "train" in subset
     assert len(subset["train"]) == 422
     for s in subset["train"]:
-        assert s.key.startswith("CXR_png/CHNCXR_0")
+        assert s[0].startswith("CXR_png/CHNCXR_0")
 
     assert "validation" in subset
     assert len(subset["validation"]) == 107
     for s in subset["validation"]:
-        assert s.key.startswith("CXR_png/CHNCXR_0")
+        assert s[0].startswith("CXR_png/CHNCXR_0")
 
     assert "test" in subset
     assert len(subset["test"]) == 133
     for s in subset["test"]:
-        assert s.key.startswith("CXR_png/CHNCXR_0")
+        assert s[0].startswith("CXR_png/CHNCXR_0")
 
     # Check labels
     for s in subset["train"]:
-        assert s.label in [0.0, 1.0]
+        assert s[1] in [0.0, 1.0]
 
     for s in subset["validation"]:
-        assert s.label in [0.0, 1.0]
+        assert s[1] in [0.0, 1.0]
 
     for s in subset["test"]:
-        assert s.label in [0.0, 1.0]
+        assert s[1] in [0.0, 1.0]
 
     # Cross-validation folds 0-1
     for f in range(2):
-        subset = dataset.subsets("fold_" + str(f))
+        datamodule = importlib.import_module(
+            f"ptbench.data.shenzhen.fold_{f}"
+        ).datamodule
+        subset = datamodule.database_split.subsets
+
         assert len(subset) == 3
 
         assert "train" in subset
         assert len(subset["train"]) == 476
         for s in subset["train"]:
-            assert s.key.startswith("CXR_png/CHNCXR_0")
+            assert s[0].startswith("CXR_png/CHNCXR_0")
 
         assert "validation" in subset
         assert len(subset["validation"]) == 119
         for s in subset["validation"]:
-            assert s.key.startswith("CXR_png/CHNCXR_0")
+            assert s[0].startswith("CXR_png/CHNCXR_0")
 
         assert "test" in subset
         assert len(subset["test"]) == 67
         for s in subset["test"]:
-            assert s.key.startswith("CXR_png/CHNCXR_0")
+            assert s[0].startswith("CXR_png/CHNCXR_0")
 
         # Check labels
         for s in subset["train"]:
-            assert s.label in [0.0, 1.0]
+            assert s[1] in [0.0, 1.0]
 
         for s in subset["validation"]:
-            assert s.label in [0.0, 1.0]
+            assert s[1] in [0.0, 1.0]
 
         for s in subset["test"]:
-            assert s.label in [0.0, 1.0]
+            assert s[1] in [0.0, 1.0]
 
     # Cross-validation folds 2-9
     for f in range(2, 10):
-        subset = dataset.subsets("fold_" + str(f))
+        datamodule = importlib.import_module(
+            f"ptbench.data.shenzhen.fold_{f}"
+        ).datamodule
+        subset = datamodule.database_split.subsets
+
         assert len(subset) == 3
 
         assert "train" in subset
         assert len(subset["train"]) == 476
         for s in subset["train"]:
-            assert s.key.startswith("CXR_png/CHNCXR_0")
+            assert s[0].startswith("CXR_png/CHNCXR_0")
 
         assert "validation" in subset
         assert len(subset["validation"]) == 120
         for s in subset["validation"]:
-            assert s.key.startswith("CXR_png/CHNCXR_0")
+            assert s[0].startswith("CXR_png/CHNCXR_0")
 
         assert "test" in subset
         assert len(subset["test"]) == 66
         for s in subset["test"]:
-            assert s.key.startswith("CXR_png/CHNCXR_0")
+            assert s[0].startswith("CXR_png/CHNCXR_0")
 
         # Check labels
         for s in subset["train"]:
-            assert s.label in [0.0, 1.0]
+            assert s[1] in [0.0, 1.0]
 
         for s in subset["validation"]:
-            assert s.label in [0.0, 1.0]
+            assert s[1] in [0.0, 1.0]
 
         for s in subset["test"]:
-            assert s.label in [0.0, 1.0]
+            assert s[1] in [0.0, 1.0]
 
 
 @pytest.mark.skip_if_rc_var_not_set("datadir.shenzhen")
 def test_loading():
-    def _check_size(size):
-        if (
-            size[0] >= 1130
-            and size[0] <= 3001
-            and size[1] >= 948
-            and size[1] <= 3001
-        ):
+    import torch
+    import torchvision.transforms
+
+    from ptbench.data.datamodule import _DelayedLoadingDataset
+
+    def _check_size(shape):
+        if shape[0] == 1 and shape[1] == 512 and shape[2] == 512:
             return True
         return False
 
     def _check_sample(s):
-        data = s.data
-        assert isinstance(data, dict)
-        assert len(data) == 2
+        assert len(s) == 2
 
-        assert "data" in data
-        assert _check_size(data["data"].size)  # Check size
-        assert data["data"].mode == "L"  # Check colors
+        data = s[0]
+        metadata = s[1]
 
-        assert "label" in data
-        assert data["label"] in [0, 1]  # Check labels
+        assert isinstance(data, torch.Tensor)
+
+        # expected shape: (1, 512, 512) — single-channel 512x512 image
+        assert _check_size(data.shape)  # Check size
+
+        assert (
+            torchvision.transforms.ToPILImage()(data).mode == "L"
+        )  # Check colors
+
+        assert "label" in metadata
+        assert metadata["label"] in [0, 1]  # Check labels
 
     limit = 30  # use this to limit testing to first images only, else None
 
-    subset = dataset.subsets("default")
-    for s in subset["train"][:limit]:
+    datamodule = importlib.import_module(
+        "ptbench.data.shenzhen.default"
+    ).datamodule
+    subset = datamodule.database_split.subsets
+    raw_data_loader = datamodule.raw_data_loader
+
+    # Need to use private function so we can limit the number of samples to use
+    dataset = _DelayedLoadingDataset(
+        subset["train"][:limit],
+        raw_data_loader,
+    )
+
+    for s in dataset:
         _check_sample(s)
 
 
 @pytest.mark.skip_if_rc_var_not_set("datadir.shenzhen")
 def test_check():
-    assert dataset.check() == 0
+    from ptbench.data.split import check_database_split_loading
+
+    limit = 30  # use this to limit testing to first images only, else 0
+
+    # Default protocol
+    datamodule = importlib.import_module(
+        "ptbench.data.shenzhen.default"
+    ).datamodule
+    database_split = datamodule.database_split
+    raw_data_loader = datamodule.raw_data_loader
+
+    assert (
+        check_database_split_loading(
+            database_split, raw_data_loader, limit=limit
+        )
+        == 0
+    )
+
+    # Folds
+    for f in range(10):
+        datamodule = importlib.import_module(
+            f"ptbench.data.shenzhen.fold_{f}"
+        ).datamodule
+        database_split = datamodule.database_split
+        raw_data_loader = datamodule.raw_data_loader
+
+        assert (
+            check_database_split_loading(
+                database_split, raw_data_loader, limit=limit
+            )
+            == 0
+        )