diff --git a/tests/test_nih_cxr14.py b/tests/test_nih_cxr14.py
index 242ac7e9c4d9b3dafeb3ee870fce27a7afdc4449..0870c99abf98a361d5dfa5048a166c2773563e91 100644
--- a/tests/test_nih_cxr14.py
+++ b/tests/test_nih_cxr14.py
@@ -17,7 +17,8 @@ def id_function(val):
 @pytest.mark.parametrize(
     "split,lenghts",
     [
-        ("default", dict(train=98637, validation=6350, test=4054)),
+        ("default.json.bz2", dict(train=98637, validation=6350, test=4054)),
+        ("cardiomegaly.json", dict(train=40, validation=40)),
     ],
     ids=id_function,  # just changes how pytest prints it
 )
@@ -27,7 +28,7 @@ def test_protocol_consistency(
     from ptbench.data.nih_cxr14.datamodule import make_split
 
     database_checkers.check_split(
-        make_split(f"{split}.json.bz2"),
+        make_split(split),
         lengths=lenghts,
         prefixes=("images/000",),
         possible_labels=(0, 1),