From ab96b3aaed27d5c27e2a18ab8dd5d8bf59cdc970 Mon Sep 17 00:00:00 2001 From: Andre Anjos <andre.dos.anjos@gmail.com> Date: Thu, 26 Jan 2023 18:12:49 +0100 Subject: [PATCH] [tests] Fix some of the tests after last commit --- .../datasets/nih_cxr14_re_pc/__init__.py | 4 ++- tests/test_nih_cxr14.py | 32 ------------------- 2 files changed, 3 insertions(+), 33 deletions(-) diff --git a/src/ptbench/configs/datasets/nih_cxr14_re_pc/__init__.py b/src/ptbench/configs/datasets/nih_cxr14_re_pc/__init__.py index 152c5be7..7b6d0df2 100644 --- a/src/ptbench/configs/datasets/nih_cxr14_re_pc/__init__.py +++ b/src/ptbench/configs/datasets/nih_cxr14_re_pc/__init__.py @@ -7,8 +7,10 @@ from torch.utils.data.dataset import ConcatDataset def _maker(protocol): if protocol == "idiap": - from ..nih_cxr14_re import idiap as nih_cxr14_re + from ..nih_cxr14_re import default as nih_cxr14_re from ..padchest import no_tb_idiap as padchest_no_tb + else: + raise RuntimeError(f"Unsupported protocol: {protocol}") nih_cxr14_re = nih_cxr14_re.dataset padchest_no_tb = padchest_no_tb.dataset diff --git a/tests/test_nih_cxr14.py b/tests/test_nih_cxr14.py index d47d4dc5..b2c7f5aa 100644 --- a/tests/test_nih_cxr14.py +++ b/tests/test_nih_cxr14.py @@ -42,38 +42,6 @@ def test_protocol_consistency(): for element in list(set(s.label)): assert element in [0.0, 1.0] - # Idiap protocol - subset = dataset.subsets("idiap") - assert len(subset) == 3 - - assert "train" in subset - assert len(subset["train"]) == 98637 - for s in subset["train"]: - assert s.key.startswith("images/000") - - assert "validation" in subset - assert len(subset["validation"]) == 6350 - for s in subset["validation"]: - assert s.key.startswith("images/000") - - assert "test" in subset - assert len(subset["test"]) == 4054 - for s in subset["test"]: - assert s.key.startswith("images/000") - - # Check labels - for s in subset["train"]: - for element in list(set(s.label)): - assert element in [0.0, 1.0] - - for s in subset["validation"]: - for element in list(set(s.label)): - assert element in [0.0, 1.0] - - for s in subset["test"]: - for element in list(set(s.label)): - assert element in [0.0, 1.0] - @pytest.mark.skip_if_rc_var_not_set("datadir.nih_cxr14_re") def test_loading(): -- GitLab