diff --git a/src/ptbench/configs/datasets/nih_cxr14_re_pc/__init__.py b/src/ptbench/configs/datasets/nih_cxr14_re_pc/__init__.py index 152c5be7920d202a718af5dfbd0d99c583c83768..7b6d0df211ab0b5bd52db6f3887470373603ed8f 100644 --- a/src/ptbench/configs/datasets/nih_cxr14_re_pc/__init__.py +++ b/src/ptbench/configs/datasets/nih_cxr14_re_pc/__init__.py @@ -7,8 +7,10 @@ from torch.utils.data.dataset import ConcatDataset def _maker(protocol): if protocol == "idiap": - from ..nih_cxr14_re import idiap as nih_cxr14_re + from ..nih_cxr14_re import default as nih_cxr14_re from ..padchest import no_tb_idiap as padchest_no_tb + else: + raise RuntimeError(f"Unsupported protocol: {protocol}") nih_cxr14_re = nih_cxr14_re.dataset padchest_no_tb = padchest_no_tb.dataset diff --git a/tests/test_nih_cxr14.py b/tests/test_nih_cxr14.py index d47d4dc5b1aee3c73975346abe0a2b1ff3519c30..b2c7f5aa5f812ce6dc4bdd161ee48dcf05a3c92c 100644 --- a/tests/test_nih_cxr14.py +++ b/tests/test_nih_cxr14.py @@ -42,38 +42,6 @@ def test_protocol_consistency(): for element in list(set(s.label)): assert element in [0.0, 1.0] - # Idiap protocol - subset = dataset.subsets("idiap") - assert len(subset) == 3 - - assert "train" in subset - assert len(subset["train"]) == 98637 - for s in subset["train"]: - assert s.key.startswith("images/000") - - assert "validation" in subset - assert len(subset["validation"]) == 6350 - for s in subset["validation"]: - assert s.key.startswith("images/000") - - assert "test" in subset - assert len(subset["test"]) == 4054 - for s in subset["test"]: - assert s.key.startswith("images/000") - - # Check labels - for s in subset["train"]: - for element in list(set(s.label)): - assert element in [0.0, 1.0] - - for s in subset["validation"]: - for element in list(set(s.label)): - assert element in [0.0, 1.0] - - for s in subset["test"]: - for element in list(set(s.label)): - assert element in [0.0, 1.0] - @pytest.mark.skip_if_rc_var_not_set("datadir.nih_cxr14_re") def test_loading():