diff --git a/tests/conftest.py b/tests/conftest.py
index febcc24f681f97b2c7002d05e1bf58b658f90926..40276ff1f2db47c06e3c1c6892f984fc8230cdb3 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -160,9 +160,19 @@ class DatabaseCheckers:
 
         Parameters
         ----------
+<<<<<<< HEAD
         split
             An instance of DatabaseSplit.
         lengths
+=======
+
+        make_split
+            A database specific function that takes a split name and returns
+            the loaded database split.
+        split_filename
+            This is the split we will check.
+        lenghts
+>>>>>>> 91bcad6 ([test] Add checks for specific image shapes)
             A dictionary that contains keys matching those of the split (this will
             be checked).  The values of the dictionary should correspond to the
             sizes of each of the datasets in the split.
@@ -197,13 +207,13 @@ class DatabaseCheckers:
         color_planes: int,
         prefixes: typing.Sequence[str],
         possible_labels: typing.Sequence[int],
-        expected_num_labels: typing.Optional[int] = None,
+        expected_num_labels: int,
+        expected_image_shape: typing.Optional[tuple[int, ...]] = None,
     ):
         """Check the consistency of an individual (loaded) batch.
 
         Parameters
         ----------
-
         batch
             The loaded batch to be checked.
         batch_size
@@ -215,15 +225,24 @@ class DatabaseCheckers:
             prefixes.
         possible_labels
             These are the list of possible labels contained in any split.
+        expected_num_labels
+            The expected number of labels each sample should have.
+        expected_image_shape
+            The expected shape of the image (num_channels, width, height).
         """
 
         assert len(batch) == 2  # data, metadata
 
         assert isinstance(batch[0], torch.Tensor)
         assert batch[0].shape[0] == batch_size  # mini-batch size
-        assert batch[0].shape[1] == color_planes  # grayscale images
+        assert batch[0].shape[1] == color_planes
         assert batch[0].shape[2] == batch[0].shape[3]  # image is square
 
+        if expected_image_shape:
+            assert all(
+                [data.shape == expected_image_shape for data in batch[0]]
+            )
+
         assert isinstance(batch[1], dict)  # metadata
         assert len(batch[1]) == 2  # label and name
 
diff --git a/tests/test_nih_cxr14.py b/tests/test_nih_cxr14.py
index 8188c7cb3d3650b71a445fba484b40369c283f52..e6ad485dbfc5569dcf494abc886bfdf90b89b4c0 100644
--- a/tests/test_nih_cxr14.py
+++ b/tests/test_nih_cxr14.py
@@ -35,22 +35,18 @@ def test_protocol_consistency(
     )
 
 
-@pytest.mark.skip_if_rc_var_not_set("datadir.nih_cxr14")
-@pytest.mark.parametrize(
-    "dataset",
-    [
-        "train",
-        "validation",
-        "test",
-    ],
-)
-@pytest.mark.parametrize(
-    "name",
-    [
-        "default",
-    ],
-)
-def test_loading(database_checkers, name: str, dataset: str):
+testdata = [
+    ("default", "train", 14),
+    ("default", "validation", 14),
+    ("default", "test", 14),
+    ("cardiomegaly", "train", 14),
+    ("cardiomegaly", "validation", 14),
+]
+
+
+@pytest.mark.skip_if_rc_var_not_set("datadir.padchest")
+@pytest.mark.parametrize("name,dataset,num_labels", testdata)
+def test_loading(database_checkers, name: str, dataset: str, num_labels: int):
     datamodule = importlib.import_module(
         f".{name}", "mednet.config.data.nih_cxr14"
     ).datamodule
@@ -70,9 +66,10 @@ def test_loading(database_checkers, name: str, dataset: str):
             color_planes=1,
             prefixes=("images/000",),
             possible_labels=(0, 1),
+            expected_num_labels=num_labels,
+            expected_image_shape=(1, 1024, 1024),
         )
         limit -= 1
 
 
 # TODO: check size 1024x1024
-# TODO: check there are 14 binary labels (0, 1)
diff --git a/tests/test_tbx11k.py b/tests/test_tbx11k.py
index 231982e694f3794d2e28264ab242ebb2cf19f717..6c022dbaaf99357b8acea8843a2947fc7cebc9ce 100644
--- a/tests/test_tbx11k.py
+++ b/tests/test_tbx11k.py
@@ -151,14 +151,16 @@ def test_protocol_consistency(
 def check_loaded_batch(
     batch,
     batch_size: int,
+    color_planes: int,
     prefixes: typing.Sequence[str],
-    expected_num_labels: typing.Optional[int] = None,
+    possible_labels: typing.Sequence[int],
+    expected_num_labels: int,
+    expected_image_shape: typing.Optional[tuple[int, ...]] = None,
 ):
     """Check the consistency of an individual (loaded) batch.
 
     Parameters
     ----------
-
     batch
         The loaded batch to be checked.
     batch_size
@@ -172,9 +174,11 @@ def check_loaded_batch(
 
     assert isinstance(batch[0], torch.Tensor)
     assert batch[0].shape[0] == batch_size  # mini-batch size
-    assert batch[0].shape[1] == 3  # grayscale images
+    assert batch[0].shape[1] == color_planes
     assert batch[0].shape[2] == batch[0].shape[3]  # image is square
-    assert batch[0].shape[2] == 512  # image is 512 pixels large
+
+    if expected_image_shape:
+        assert all([data.shape == expected_image_shape for data in batch[0]])
 
     assert isinstance(batch[1], dict)  # metadata
     assert (
@@ -182,7 +186,7 @@ def check_loaded_batch(
     )  # label, name and radiological sign bounding-boxes
 
     assert "label" in batch[1]
-    assert all([k in (0, 1) for k in batch[1]["label"]])
+    assert all([k in possible_labels for k in batch[1]["label"]])
 
     if expected_num_labels:
         assert len(batch[1]["label"]) == expected_num_labels
@@ -272,7 +276,10 @@ def test_loading(name: str, dataset: str, prefixes: typing.Sequence[str]):
         check_loaded_batch(
             batch,
             batch_size=1,
+            color_planes=3,
             prefixes=prefixes,
+            possible_labels=(0, 1),
             expected_num_labels=1,
+            expected_image_shape=(3, 512, 512),
         )
         limit -= 1