# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Tests for Montgomery dataset."""

import importlib
import itertools
import json

import pytest

from PIL import Image


def id_function(val):
    """Build the label pytest shows for a parametrized value.

    Parameters
    ----------
    val
        One value taken from a ``pytest.mark.parametrize`` argument list.

    Returns
    -------
    str
        ``str(val)`` when ``val`` is a dictionary, ``repr(val)`` otherwise.
    """
    if isinstance(val, dict):
        return str(val)
    return repr(val)


@pytest.mark.parametrize(
    "split,lengths",
    [
        ("default", dict(train=88, validation=22, test=28)),
        ("fold-0", dict(train=99, validation=25, test=14)),
        ("fold-1", dict(train=99, validation=25, test=14)),
        ("fold-2", dict(train=99, validation=25, test=14)),
        ("fold-3", dict(train=99, validation=25, test=14)),
        ("fold-4", dict(train=99, validation=25, test=14)),
        ("fold-5", dict(train=99, validation=25, test=14)),
        ("fold-6", dict(train=99, validation=25, test=14)),
        ("fold-7", dict(train=99, validation=25, test=14)),
        ("fold-8", dict(train=100, validation=25, test=13)),
        ("fold-9", dict(train=100, validation=25, test=13)),
    ],
    ids=id_function,  # just changes how pytest prints it
)
def test_protocol_consistency(
    database_checkers, split: str, lengths: dict[str, int]
):
    """Check each split definition against its expected subset sizes.

    The split named ``<split>.json`` is built with ``make_split`` and handed
    to ``database_checkers.check_split`` together with the expected lengths,
    the accepted sample-name prefix and the allowed labels.

    Parameters
    ----------
    database_checkers
        Fixture exposing the ``check_split`` helper.
    split
        Name of the split file to load, without the ``.json`` extension.
    lengths
        Expected number of samples, keyed by subset name.
    """
    # Imported lazily so that collecting this module does not require the
    # datamodule to be importable.
    from mednet.config.data.montgomery.datamodule import make_split

    database_checkers.check_split(
        make_split(f"{split}.json"),
        lengths=lengths,
        prefixes=("CXR_png/MCUCXR_0",),
        possible_labels=(0, 1),
    )


@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
@pytest.mark.parametrize(
    "dataset",
    [
        "train",
        "validation",
        "test",
    ],
)
@pytest.mark.parametrize(
    "name",
    [
        "default",
        "fold_0",
        "fold_1",
        "fold_2",
        "fold_3",
        "fold_4",
        "fold_5",
        "fold_6",
        "fold_7",
        "fold_8",
        "fold_9",
    ],
)
def test_loading(database_checkers, name: str, dataset: str):
    """Load a few batches from one subset of a configuration and check them.

    Parameters
    ----------
    database_checkers
        Fixture exposing the ``check_loaded_batch`` helper.
    name
        Name of the configuration module to import from
        ``mednet.config.data.montgomery``.
    dataset
        Key of the prediction data loader to iterate over.
    """
    datamodule = importlib.import_module(
        f".{name}", "mednet.config.data.montgomery"
    ).datamodule

    datamodule.model_transforms = []  # should be done before setup()
    datamodule.setup("predict")  # sets up all datasets

    loader = datamodule.predict_dataloader()[dataset]

    limit = 3  # limit load checking
    # islice stops after ``limit`` batches without fetching an extra one.
    for batch in itertools.islice(loader, limit):
        database_checkers.check_loaded_batch(
            batch,
            batch_size=1,
            color_planes=1,
            prefixes=("CXR_png/MCUCXR_0",),
            possible_labels=(0, 1),
            expected_num_labels=1,
        )


@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
def test_loaded_image_quality(datadir):
    """Compare the first training image's histogram with a stored reference.

    Parameters
    ----------
    datadir
        Fixture pointing at the directory that holds
        ``histogram_montgomery.json``.
    """
    datamodule = importlib.import_module(
        ".default", "mednet.config.data.montgomery"
    ).datamodule

    datamodule.model_transforms = []
    datamodule.setup("predict")

    loader = datamodule.splits["train"][0][1]
    first_sample = datamodule.splits["train"][0][0][0]

    image_data = loader.sample(first_sample)[0].numpy()[
        0, :, :
    ]  # PIL expects grayscale to not have any leading dim
    img = Image.fromarray(image_data, mode="L")

    histogram = img.histogram()

    reference_histogram_file = str(datadir / "histogram_montgomery.json")
    # Explicit encoding keeps JSON decoding independent of the locale.
    with open(reference_histogram_file, encoding="utf-8") as i_f:
        ref_histogram = json.load(i_f)["histogram"]

    assert histogram == ref_histogram