Skip to content
Snippets Groups Projects
Commit 47239f3d authored by Daniel CARRON's avatar Daniel CARRON :b:
Browse files

[test] Check number of labels per sample in DataLoader

parent 4ccefa94
No related branches found
No related tags found
2 merge requests!18Update tests,!16Make square centre-padding a model transform
...@@ -156,7 +156,7 @@ class DatabaseCheckers: ...@@ -156,7 +156,7 @@ class DatabaseCheckers:
prefixes: typing.Sequence[str], prefixes: typing.Sequence[str],
possible_labels: typing.Sequence[int], possible_labels: typing.Sequence[int],
): ):
"""Run a simple consistence check on the data split. """Run a simple consistency check on the data split.
Parameters Parameters
---------- ----------
...@@ -197,8 +197,9 @@ class DatabaseCheckers: ...@@ -197,8 +197,9 @@ class DatabaseCheckers:
color_planes: int, color_planes: int,
prefixes: typing.Sequence[str], prefixes: typing.Sequence[str],
possible_labels: typing.Sequence[int], possible_labels: typing.Sequence[int],
expected_num_labels: typing.Optional[int] = None,
): ):
"""Check the consistence of an individual (loaded) batch. """Check the consistency of an individual (loaded) batch.
Parameters Parameters
---------- ----------
...@@ -229,6 +230,9 @@ class DatabaseCheckers: ...@@ -229,6 +230,9 @@ class DatabaseCheckers:
assert "label" in batch[1] assert "label" in batch[1]
assert all([k in possible_labels for k in batch[1]["label"]]) assert all([k in possible_labels for k in batch[1]["label"]])
if expected_num_labels:
assert len(batch[1]["label"]) == expected_num_labels
assert "name" in batch[1] assert "name" in batch[1]
assert all( assert all(
[any([k.startswith(j) for j in prefixes]) for k in batch[1]["name"]] [any([k.startswith(j) for j in prefixes]) for k in batch[1]["name"]]
......
...@@ -87,5 +87,6 @@ def test_loading(database_checkers, name: str, dataset: str): ...@@ -87,5 +87,6 @@ def test_loading(database_checkers, name: str, dataset: str):
color_planes=1, color_planes=1,
prefixes=("HIV-TB_Algorithm_study_X-rays",), prefixes=("HIV-TB_Algorithm_study_X-rays",),
possible_labels=(0, 1), possible_labels=(0, 1),
expected_num_labels=1,
) )
limit -= 1 limit -= 1
...@@ -92,5 +92,6 @@ def test_loading(database_checkers, name: str, dataset: str): ...@@ -92,5 +92,6 @@ def test_loading(database_checkers, name: str, dataset: str):
color_planes=1, color_planes=1,
prefixes=("DatasetA/Training", "DatasetA/Testing"), prefixes=("DatasetA/Training", "DatasetA/Testing"),
possible_labels=(0, 1), possible_labels=(0, 1),
expected_num_labels=1,
) )
limit -= 1 limit -= 1
...@@ -89,5 +89,6 @@ def test_loading(database_checkers, name: str, dataset: str): ...@@ -89,5 +89,6 @@ def test_loading(database_checkers, name: str, dataset: str):
color_planes=1, color_planes=1,
prefixes=("CXR_png/MCUCXR_0",), prefixes=("CXR_png/MCUCXR_0",),
possible_labels=(0, 1), possible_labels=(0, 1),
expected_num_labels=1,
) )
limit -= 1 limit -= 1
...@@ -40,21 +40,18 @@ def test_protocol_consistency( ...@@ -40,21 +40,18 @@ def test_protocol_consistency(
) )
testdata = [
("idiap", "train", 193),
("idiap", "test", 1),
("tb_idiap", "train", 1),
("no_tb_idiap", "train", 14),
("cardiomegaly_idiap", "train", 14),
]
@pytest.mark.skip_if_rc_var_not_set("datadir.padchest") @pytest.mark.skip_if_rc_var_not_set("datadir.padchest")
@pytest.mark.parametrize( @pytest.mark.parametrize("name,dataset,num_labels", testdata)
"dataset", def test_loading(database_checkers, name: str, dataset: str, num_labels: int):
["train", "test"],
)
@pytest.mark.parametrize(
"name",
[
"idiap",
"tb_idiap",
"no_tb_idiap",
"cardiomegaly_idiap",
],
)
def test_loading(database_checkers, name: str, dataset: str):
datamodule = importlib.import_module( datamodule = importlib.import_module(
f".{name}", "mednet.config.data.padchest" f".{name}", "mednet.config.data.padchest"
).datamodule ).datamodule
...@@ -86,10 +83,6 @@ def test_loading(database_checkers, name: str, dataset: str): ...@@ -86,10 +83,6 @@ def test_loading(database_checkers, name: str, dataset: str):
color_planes=1, color_planes=1,
prefixes=("",), prefixes=("",),
possible_labels=(0, 1), possible_labels=(0, 1),
expected_num_labels=num_labels,
) )
limit -= 1 limit -= 1
# TODO: check size 1024x1024
# TODO: check there are 14 binary labels (0, 1) (in some cases, in others much
# more)...
...@@ -89,5 +89,6 @@ def test_loading(database_checkers, name: str, dataset: str): ...@@ -89,5 +89,6 @@ def test_loading(database_checkers, name: str, dataset: str):
color_planes=1, color_planes=1,
prefixes=("CXR_png/CHNCXR_0",), prefixes=("CXR_png/CHNCXR_0",),
possible_labels=(0, 1), possible_labels=(0, 1),
expected_num_labels=1,
) )
limit -= 1 limit -= 1
...@@ -93,5 +93,6 @@ def test_loading(database_checkers, name: str, dataset: str): ...@@ -93,5 +93,6 @@ def test_loading(database_checkers, name: str, dataset: str):
"TBPOC_CXR/tbpoc-", "TBPOC_CXR/tbpoc-",
), ),
possible_labels=(0, 1), possible_labels=(0, 1),
expected_num_labels=1,
) )
limit -= 1 limit -= 1
...@@ -152,8 +152,9 @@ def check_loaded_batch( ...@@ -152,8 +152,9 @@ def check_loaded_batch(
batch, batch,
batch_size: int, batch_size: int,
prefixes: typing.Sequence[str], prefixes: typing.Sequence[str],
expected_num_labels: typing.Optional[int] = None,
): ):
"""Check the consistence of an individual (loaded) batch. """Check the consistency of an individual (loaded) batch.
Parameters Parameters
---------- ----------
...@@ -183,6 +184,9 @@ def check_loaded_batch( ...@@ -183,6 +184,9 @@ def check_loaded_batch(
assert "label" in batch[1] assert "label" in batch[1]
assert all([k in (0, 1) for k in batch[1]["label"]]) assert all([k in (0, 1) for k in batch[1]["label"]])
if expected_num_labels:
assert len(batch[1]["label"]) == expected_num_labels
assert "name" in batch[1] assert "name" in batch[1]
assert all( assert all(
[any([k.startswith(j) for j in prefixes]) for k in batch[1]["name"]] [any([k.startswith(j) for j in prefixes]) for k in batch[1]["name"]]
...@@ -269,5 +273,6 @@ def test_loading(name: str, dataset: str, prefixes: typing.Sequence[str]): ...@@ -269,5 +273,6 @@ def test_loading(name: str, dataset: str, prefixes: typing.Sequence[str]):
batch, batch,
batch_size=1, batch_size=1,
prefixes=prefixes, prefixes=prefixes,
expected_num_labels=1,
) )
limit -= 1 limit -= 1
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment