From ab96b3aaed27d5c27e2a18ab8dd5d8bf59cdc970 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Thu, 26 Jan 2023 18:12:49 +0100
Subject: [PATCH] [tests] Fix some of the tests after last commit

---
 .../datasets/nih_cxr14_re_pc/__init__.py      |  4 ++-
 tests/test_nih_cxr14.py                       | 32 -------------------
 2 files changed, 3 insertions(+), 33 deletions(-)

diff --git a/src/ptbench/configs/datasets/nih_cxr14_re_pc/__init__.py b/src/ptbench/configs/datasets/nih_cxr14_re_pc/__init__.py
index 152c5be7..7b6d0df2 100644
--- a/src/ptbench/configs/datasets/nih_cxr14_re_pc/__init__.py
+++ b/src/ptbench/configs/datasets/nih_cxr14_re_pc/__init__.py
@@ -7,8 +7,10 @@ from torch.utils.data.dataset import ConcatDataset
 
 def _maker(protocol):
     if protocol == "idiap":
-        from ..nih_cxr14_re import idiap as nih_cxr14_re
+        from ..nih_cxr14_re import default as nih_cxr14_re
         from ..padchest import no_tb_idiap as padchest_no_tb
+    else:
+        raise RuntimeError(f"Unsupported protocol: {protocol}")
 
     nih_cxr14_re = nih_cxr14_re.dataset
     padchest_no_tb = padchest_no_tb.dataset
diff --git a/tests/test_nih_cxr14.py b/tests/test_nih_cxr14.py
index d47d4dc5..b2c7f5aa 100644
--- a/tests/test_nih_cxr14.py
+++ b/tests/test_nih_cxr14.py
@@ -42,38 +42,6 @@ def test_protocol_consistency():
         for element in list(set(s.label)):
             assert element in [0.0, 1.0]
 
-    # Idiap protocol
-    subset = dataset.subsets("idiap")
-    assert len(subset) == 3
-
-    assert "train" in subset
-    assert len(subset["train"]) == 98637
-    for s in subset["train"]:
-        assert s.key.startswith("images/000")
-
-    assert "validation" in subset
-    assert len(subset["validation"]) == 6350
-    for s in subset["validation"]:
-        assert s.key.startswith("images/000")
-
-    assert "test" in subset
-    assert len(subset["test"]) == 4054
-    for s in subset["test"]:
-        assert s.key.startswith("images/000")
-
-    # Check labels
-    for s in subset["train"]:
-        for element in list(set(s.label)):
-            assert element in [0.0, 1.0]
-
-    for s in subset["validation"]:
-        for element in list(set(s.label)):
-            assert element in [0.0, 1.0]
-
-    for s in subset["test"]:
-        for element in list(set(s.label)):
-            assert element in [0.0, 1.0]
-
 
 @pytest.mark.skip_if_rc_var_not_set("datadir.nih_cxr14_re")
 def test_loading():
-- 
GitLab