From 47239f3d4f3a1408772d4547e4a291d595246840 Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Wed, 31 Jan 2024 12:44:16 +0100 Subject: [PATCH] [test] Check number of labels per sample in DataLoader --- tests/conftest.py | 8 ++++++-- tests/test_hivtb.py | 1 + tests/test_indian.py | 1 + tests/test_montgomery.py | 1 + tests/test_padchest.py | 31 ++++++++++++------------------- tests/test_shenzhen.py | 1 + tests/test_tbpoc.py | 1 + tests/test_tbx11k.py | 7 ++++++- 8 files changed, 29 insertions(+), 22 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index b4f92331..febcc24f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -156,7 +156,7 @@ class DatabaseCheckers: prefixes: typing.Sequence[str], possible_labels: typing.Sequence[int], ): - """Run a simple consistence check on the data split. + """Run a simple consistency check on the data split. Parameters ---------- @@ -197,8 +197,9 @@ class DatabaseCheckers: color_planes: int, prefixes: typing.Sequence[str], possible_labels: typing.Sequence[int], + expected_num_labels: typing.Optional[int] = None, ): - """Check the consistence of an individual (loaded) batch. + """Check the consistency of an individual (loaded) batch. Parameters ---------- @@ -229,6 +230,9 @@ class DatabaseCheckers: assert "label" in batch[1] assert all([k in possible_labels for k in batch[1]["label"]]) + if expected_num_labels: + assert len(batch[1]["label"]) == expected_num_labels + assert "name" in batch[1] assert all( [any([k.startswith(j) for j in prefixes]) for k in batch[1]["name"]] diff --git a/tests/test_hivtb.py b/tests/test_hivtb.py index 00b7a484..03066c2f 100644 --- a/tests/test_hivtb.py +++ b/tests/test_hivtb.py @@ -87,5 +87,6 @@ def test_loading(database_checkers, name: str, dataset: str): color_planes=1, prefixes=("HIV-TB_Algorithm_study_X-rays",), possible_labels=(0, 1), + expected_num_labels=1, ) limit -= 1 diff --git a/tests/test_indian.py b/tests/test_indian.py index 66ffdef3..8dd07159 100644 --- a/tests/test_indian.py +++ b/tests/test_indian.py @@ -92,5 +92,6 @@ def test_loading(database_checkers, name: str, dataset: str): color_planes=1, prefixes=("DatasetA/Training", "DatasetA/Testing"), possible_labels=(0, 1), + expected_num_labels=1, ) limit -= 1 diff --git a/tests/test_montgomery.py b/tests/test_montgomery.py index 6933e982..91f8cc10 100644 --- a/tests/test_montgomery.py +++ b/tests/test_montgomery.py @@ -89,5 +89,6 @@ def test_loading(database_checkers, name: str, dataset: str): color_planes=1, prefixes=("CXR_png/MCUCXR_0",), possible_labels=(0, 1), + expected_num_labels=1, ) limit -= 1 diff --git a/tests/test_padchest.py b/tests/test_padchest.py index 97bf8a9c..5fc342e1 100644 --- a/tests/test_padchest.py +++ b/tests/test_padchest.py @@ -40,21 +40,18 @@ def test_protocol_consistency( ) +testdata = [ + ("idiap", "train", 193), + ("idiap", "test", 1), + ("tb_idiap", "train", 1), + ("no_tb_idiap", "train", 14), + ("cardiomegaly_idiap", "train", 14), +] + + @pytest.mark.skip_if_rc_var_not_set("datadir.padchest") -@pytest.mark.parametrize( - "dataset", - ["train", "test"], -) -@pytest.mark.parametrize( - "name", - [ - "idiap", - "tb_idiap", - "no_tb_idiap", - "cardiomegaly_idiap", - ], -) -def test_loading(database_checkers, name: str, dataset: str): +@pytest.mark.parametrize("name,dataset,num_labels", testdata) +def test_loading(database_checkers, name: str, dataset: str, num_labels: int): datamodule = importlib.import_module( f".{name}", "mednet.config.data.padchest" ).datamodule @@ -86,10 +83,6 @@ def test_loading(database_checkers, name: str, dataset: str): color_planes=1, prefixes=("",), possible_labels=(0, 1), + expected_num_labels=num_labels, ) limit -= 1 - - -# TODO: check size 1024x1024 -# TODO: check there are 14 binary labels (0, 1) (in some cases, in others much -# more)... diff --git a/tests/test_shenzhen.py b/tests/test_shenzhen.py index 6658831b..aeb9da85 100644 --- a/tests/test_shenzhen.py +++ b/tests/test_shenzhen.py @@ -89,5 +89,6 @@ def test_loading(database_checkers, name: str, dataset: str): color_planes=1, prefixes=("CXR_png/CHNCXR_0",), possible_labels=(0, 1), + expected_num_labels=1, ) limit -= 1 diff --git a/tests/test_tbpoc.py b/tests/test_tbpoc.py index a2704fa7..58f762d5 100644 --- a/tests/test_tbpoc.py +++ b/tests/test_tbpoc.py @@ -93,5 +93,6 @@ def test_loading(database_checkers, name: str, dataset: str): "TBPOC_CXR/tbpoc-", ), possible_labels=(0, 1), + expected_num_labels=1, ) limit -= 1 diff --git a/tests/test_tbx11k.py b/tests/test_tbx11k.py index b1bbcc1f..231982e6 100644 --- a/tests/test_tbx11k.py +++ b/tests/test_tbx11k.py @@ -152,8 +152,9 @@ def check_loaded_batch( batch, batch_size: int, prefixes: typing.Sequence[str], + expected_num_labels: typing.Optional[int] = None, ): - """Check the consistence of an individual (loaded) batch. + """Check the consistency of an individual (loaded) batch. Parameters ---------- @@ -183,6 +184,9 @@ def check_loaded_batch( assert "label" in batch[1] assert all([k in (0, 1) for k in batch[1]["label"]]) + if expected_num_labels: + assert len(batch[1]["label"]) == expected_num_labels + assert "name" in batch[1] assert all( [any([k.startswith(j) for j in prefixes]) for k in batch[1]["name"]] @@ -269,5 +273,6 @@ def test_loading(name: str, dataset: str, prefixes: typing.Sequence[str]): batch, batch_size=1, prefixes=prefixes, + expected_num_labels=1, ) limit -= 1 -- GitLab