-
Daniel CARRON authored
test_padchest.py 2.13 KiB
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Tests for PadChest dataset."""
import importlib
import itertools

import pytest
def id_function(val):
    """Return the text pytest shows for one parametrized value.

    Dictionaries are rendered with ``str()``; any other value is rendered
    with ``repr()``.
    """
    return str(val) if isinstance(val, dict) else repr(val)
@pytest.mark.parametrize(
    "split,lengths",
    [
        # ("idiap.json.bz2", dict(train=96269)),  ## many labels
        ("tb-idiap.json", dict(train=200, test=50)),  # 0: no-tb, 1: tb
        (
            "no-tb-idiap.json.bz2",
            dict(train=54371, validation=4052),
        ),  # 14 labels
        ("cardiomegaly-idiap.json", dict(train=40)),  # 14 labels
    ],
    ids=id_function,  # just changes how pytest prints it
)
def test_protocol_consistency(
    database_checkers, split: str, lengths: dict[str, int]
):
    """Check a PadChest split against its expected per-subset sizes.

    Parameters
    ----------
    database_checkers
        Fixture exposing ``check_split``.
    split
        Name of the split file handed to ``make_split``.
    lengths
        Mapping forwarded to ``check_split`` as ``lengths``; presumably the
        expected number of samples for each subset name (e.g. ``train``).
    """
    # Kept as a function-scope import, as in the original test, so that
    # collecting this module does not itself import the datamodule.
    from mednet.config.data.padchest.datamodule import make_split

    database_checkers.check_split(
        make_split(split),
        lengths=lengths,
        prefixes=("",),
        possible_labels=(0, 1),
    )
# Parameter rows for ``test_loading``, one tuple per case:
#   (configuration module name under ``mednet.config.data.padchest``,
#    key looked up in the predict dataloaders,
#    expected number of labels per sample)
testdata: list[tuple[str, str, int]] = [
    ("idiap", "train", 193),
    ("idiap", "test", 1),
    ("tb_idiap", "train", 1),
    ("no_tb_idiap", "train", 14),
    ("cardiomegaly_idiap", "train", 14),
]
@pytest.mark.skip_if_rc_var_not_set("datadir.padchest")
@pytest.mark.parametrize("name,dataset,num_labels", testdata)
def test_loading(database_checkers, name: str, dataset: str, num_labels: int):
    """Load a few batches from one PadChest configuration and validate them.

    Parameters
    ----------
    database_checkers
        Fixture exposing ``check_loaded_batch``.
    name
        Module name, relative to ``mednet.config.data.padchest``, whose
        ``datamodule`` attribute is exercised.
    dataset
        Key of the loader to pick from ``datamodule.predict_dataloader()``.
    num_labels
        Value forwarded to ``check_loaded_batch`` as ``expected_num_labels``.
    """
    datamodule = importlib.import_module(
        f".{name}", "mednet.config.data.padchest"
    ).datamodule

    datamodule.model_transforms = []  # should be done before setup()
    datamodule.setup("predict")  # sets up all datasets

    # Ask for the loaders once and reuse the result for both the membership
    # test and the lookup, instead of calling predict_dataloader() twice.
    loaders = datamodule.predict_dataloader()
    if dataset not in loaders:
        # NOTE(review): the test passes without checking anything when the
        # requested subset is absent -- confirm this is intended rather than
        # a case that should fail or be reported as skipped.
        return

    limit = 3  # limit load checking
    # islice stops after ``limit`` batches without fetching an extra one.
    for batch in itertools.islice(loaders[dataset], limit):
        database_checkers.check_loaded_batch(
            batch,
            batch_size=1,
            color_planes=1,
            prefixes=("",),
            possible_labels=(0, 1),
            expected_num_labels=num_labels,
        )