-
Daniel CARRON authored
test_shenzhen.py 3.00 KiB
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Tests for Shenzhen dataset."""
import importlib
import itertools

import pytest
def id_function(val):
    """Build the label pytest prints for one parametrized value.

    Dictionaries are rendered with :func:`str`; any other value is rendered
    with :func:`repr`.

    Parameters
    ----------
    val
        A single parametrized argument value.

    Returns
    -------
    str
        Human-readable identifier for ``val``.
    """
    formatter = str if isinstance(val, dict) else repr
    return formatter(val)
@pytest.mark.parametrize(
    "split,lengths",
    [
        ("default", dict(train=422, validation=107, test=133)),
        ("fold-0", dict(train=476, validation=119, test=67)),
        ("fold-1", dict(train=476, validation=119, test=67)),
        ("fold-2", dict(train=476, validation=120, test=66)),
        ("fold-3", dict(train=476, validation=120, test=66)),
        ("fold-4", dict(train=476, validation=120, test=66)),
        ("fold-5", dict(train=476, validation=120, test=66)),
        ("fold-6", dict(train=476, validation=120, test=66)),
        ("fold-7", dict(train=476, validation=120, test=66)),
        ("fold-8", dict(train=476, validation=120, test=66)),
        ("fold-9", dict(train=476, validation=120, test=66)),
    ],
    ids=id_function,  # just changes how pytest prints it
)
def test_protocol_consistency(
    database_checkers, split: str, lengths: dict[str, int]
):
    """Check that a Shenzhen split file has the expected layout.

    The split ``<split>.json`` is loaded with ``make_split`` and handed to
    ``database_checkers.check_split`` together with the expected number of
    samples per subset, the expected sample path prefix and the set of
    allowed labels.

    Parameters
    ----------
    database_checkers
        Pytest fixture exposing the ``check_split`` validator.
    split
        Name of the split file, without the ``.json`` extension.
    lengths
        Expected number of samples for each subset (``train``,
        ``validation``, ``test``) of the split.
    """
    from mednet.config.data.shenzhen.datamodule import make_split

    database_checkers.check_split(
        make_split(f"{split}.json"),
        lengths=lengths,
        prefixes=("CXR_png/CHNCXR_0",),
        possible_labels=(0, 1),
    )
@pytest.mark.skip_if_rc_var_not_set("datadir.shenzhen")
@pytest.mark.parametrize(
    "dataset",
    [
        "train",
        "validation",
        "test",
    ],
)
@pytest.mark.parametrize(
    "name",
    [
        "default",
        "fold_0",
        "fold_1",
        "fold_2",
        "fold_3",
        "fold_4",
        "fold_5",
        "fold_6",
        "fold_7",
        "fold_8",
        "fold_9",
    ],
)
def test_loading(database_checkers, name: str, dataset: str):
    """Load the first few batches of one Shenzhen subset and check them.

    The datamodule defined in ``mednet.config.data.shenzhen.<name>`` is
    imported, set up for prediction without model transforms, and the first
    batches of its ``dataset`` subset are validated by
    ``database_checkers.check_loaded_batch``. The custom marker presumably
    skips this test when the ``datadir.shenzhen`` configuration variable is
    not set.

    Parameters
    ----------
    database_checkers
        Pytest fixture exposing the ``check_loaded_batch`` validator.
    name
        Name of the configuration module (``default`` or ``fold_<k>``).
    dataset
        Key of the subset in the prediction dataloader dictionary.
    """
    datamodule = importlib.import_module(
        f".{name}", "mednet.config.data.shenzhen"
    ).datamodule

    datamodule.model_transforms = []  # should be done before setup()
    datamodule.setup("predict")  # sets up all datasets

    loader = datamodule.predict_dataloader()[dataset]

    limit = 3  # limit load checking
    # islice stops after exactly ``limit`` batches; a counter tested at the
    # top of the loop would pull (and load) one extra batch just to discard it.
    for batch in itertools.islice(loader, limit):
        database_checkers.check_loaded_batch(
            batch,
            batch_size=1,
            color_planes=1,
            prefixes=("CXR_png/CHNCXR_0",),
            possible_labels=(0, 1),
            expected_num_labels=1,
        )
@pytest.mark.skip_if_rc_var_not_set("datadir.shenzhen")
def test_loaded_image_quality(database_checkers, datadir):
    """Compare loaded Shenzhen images against stored reference histograms.

    The ``default`` Shenzhen datamodule is set up for prediction without
    model transforms. It is then passed, with the path of the reference
    histogram JSON file under ``datadir``, to
    ``database_checkers.check_image_quality``.

    Parameters
    ----------
    database_checkers
        Pytest fixture exposing the ``check_image_quality`` validator.
    datadir
        Pytest fixture pointing to the test data directory (supports the
        ``/`` path operator).
    """
    default_config = importlib.import_module(
        ".default", "mednet.config.data.shenzhen"
    )
    datamodule = default_config.datamodule
    datamodule.model_transforms = []
    datamodule.setup("predict")

    histogram_path = (
        datadir / "histograms/raw_data/histograms_shenzhen_default.json"
    )
    database_checkers.check_image_quality(datamodule, str(histogram_path))