# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Tests for PadChest dataset."""

import importlib
import itertools

import pytest
from click.testing import CliRunner


def id_function(val):
    """Return a readable pytest id for a parametrize value.

    Dictionaries are rendered with ``str()``; everything else with ``repr()``.
    """
    if isinstance(val, dict):
        return str(val)
    return repr(val)


@pytest.mark.parametrize(
    "split,lengths",
    [
        # ("idiap.json.bz2", dict(train=96269)),  ## many labels
        ("tb-idiap.json", dict(train=200, test=50)),  # 0: no-tb, 1: tb
        (
            "no-tb-idiap.json.bz2",
            dict(train=54371, validation=4052),
        ),  # 14 labels
        ("cardiomegaly-idiap.json", dict(train=40)),  # 14 labels
    ],
    ids=id_function,  # just changes how pytest prints it
)
def test_protocol_consistency(
    database_checkers, split: str, lengths: dict[str, int]
):
    """Check a PadChest split file against its expected subset sizes.

    Parameters
    ----------
    database_checkers
        Fixture providing the shared dataset-checking helpers.
    split
        Name of the split file handed to ``make_split``.
    lengths
        Expected number of samples per subset name.
    """
    from mednet.config.data.padchest.datamodule import make_split

    database_checkers.check_split(
        make_split(split),
        lengths=lengths,
        prefixes=("",),
        possible_labels=(0, 1),
    )


@pytest.mark.skip_if_rc_var_not_set("datadir.padchest")
def test_database_check():
    """Run the ``check`` CLI command on ``padchest-idiap`` and require success.

    ``--limit=10`` is passed so the command only checks a few samples.
    """
    from mednet.scripts.database import check

    runner = CliRunner()
    result = runner.invoke(check, ["--limit=10", "padchest-idiap"])
    assert (
        result.exit_code == 0
    ), f"Exit code {result.exit_code} != 0 -- Output:\n{result.output}"


# Each entry: (datamodule module name under ``mednet.config.data.padchest``,
# predict-dataloader key to load, expected number of labels per sample).
testdata = [
    ("idiap", "train", 193),
    ("idiap", "test", 1),
    ("tb_idiap", "train", 1),
    ("no_tb_idiap", "train", 14),
    ("cardiomegaly_idiap", "train", 14),
]


@pytest.mark.skip_if_rc_var_not_set("datadir.padchest")
@pytest.mark.parametrize("name,dataset,num_labels", testdata)
def test_loading(database_checkers, name: str, dataset: str, num_labels: int):
    """Load a few batches of one PadChest dataset and validate them.

    Parameters
    ----------
    database_checkers
        Fixture providing the shared dataset-checking helpers.
    name
        Module name (under ``mednet.config.data.padchest``) exposing a
        ``datamodule`` attribute.
    dataset
        Key of the predict dataloader to sample from.
    num_labels
        Expected number of labels in each loaded batch.
    """
    datamodule = importlib.import_module(
        f".{name}", "mednet.config.data.padchest"
    ).datamodule

    datamodule.model_transforms = []  # should be done before setup()
    datamodule.setup("predict")  # sets up all datasets

    # Build the dataloader mapping once and reuse it for both the membership
    # test and the lookup, instead of calling predict_dataloader() twice.
    loaders = datamodule.predict_dataloader()
    if dataset not in loaders:
        # NOTE(review): when the requested split is absent the test passes
        # without loading anything (pre-existing behaviour); consider
        # pytest.skip() so this is visible in reports.
        return

    limit = 3  # limit load checking
    # islice fetches exactly ``limit`` batches; the previous counter-based
    # loop pulled one extra batch from the loader before breaking.
    for batch in itertools.islice(loaders[dataset], limit):
        database_checkers.check_loaded_batch(
            batch,
            batch_size=1,
            color_planes=1,
            prefixes=("",),
            possible_labels=(0, 1),
            expected_num_labels=num_labels,
        )


@pytest.mark.skip_if_rc_var_not_set("datadir.padchest")
def test_loaded_image_quality(database_checkers, datadir):
    """Check image quality of the ``idiap`` datamodule against reference data.

    The reference is ``histograms/raw_data/histograms_padchest_idiap.json``
    under the ``datadir`` fixture path.
    """
    reference_histogram_file = str(
        datadir / "histograms/raw_data/histograms_padchest_idiap.json"
    )

    datamodule = importlib.import_module(
        ".idiap", "mednet.config.data.padchest"
    ).datamodule

    datamodule.model_transforms = []
    datamodule.setup("predict")

    database_checkers.check_image_quality(datamodule, reference_histogram_file)