Skip to content
Snippets Groups Projects
Commit ab96b3aa authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[tests] Fix some of the tests after last commit

parent d53a621f
No related branches found
No related tags found
No related merge requests found
Pipeline #69128 canceled
......@@ -7,8 +7,10 @@ from torch.utils.data.dataset import ConcatDataset
def _maker(protocol):
if protocol == "idiap":
from ..nih_cxr14_re import idiap as nih_cxr14_re
from ..nih_cxr14_re import default as nih_cxr14_re
from ..padchest import no_tb_idiap as padchest_no_tb
else:
raise RuntimeError(f"Unsupported protocol: {protocol}")
nih_cxr14_re = nih_cxr14_re.dataset
padchest_no_tb = padchest_no_tb.dataset
......
......@@ -42,38 +42,6 @@ def test_protocol_consistency():
for element in list(set(s.label)):
assert element in [0.0, 1.0]
# Idiap protocol
subset = dataset.subsets("idiap")
assert len(subset) == 3
assert "train" in subset
assert len(subset["train"]) == 98637
for s in subset["train"]:
assert s.key.startswith("images/000")
assert "validation" in subset
assert len(subset["validation"]) == 6350
for s in subset["validation"]:
assert s.key.startswith("images/000")
assert "test" in subset
assert len(subset["test"]) == 4054
for s in subset["test"]:
assert s.key.startswith("images/000")
# Check labels
for s in subset["train"]:
for element in list(set(s.label)):
assert element in [0.0, 1.0]
for s in subset["validation"]:
for element in list(set(s.label)):
assert element in [0.0, 1.0]
for s in subset["test"]:
for element in list(set(s.label)):
assert element in [0.0, 1.0]
@pytest.mark.skip_if_rc_var_not_set("datadir.nih_cxr14_re")
def test_loading():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment