# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later """Tests for TBX11K dataset.""" import importlib import typing import pytest import torch from click.testing import CliRunner def id_function(val): if isinstance(val, dict | tuple): return repr(val) return repr(val) @pytest.mark.parametrize( "split,lenghts,prefixes", [ ( "v1-healthy-vs-atb", dict(train=2767, validation=706, test=957), ("imgs/health", "imgs/tb"), ), ( "v1-fold-0", dict(train=3177, validation=810, test=443), ("imgs/health", "imgs/tb"), ), ( "v1-fold-1", dict(train=3177, validation=810, test=443), ("imgs/health", "imgs/tb"), ), ( "v1-fold-2", dict(train=3177, validation=810, test=443), ("imgs/health", "imgs/tb"), ), ( "v1-fold-3", dict(train=3177, validation=810, test=443), ("imgs/health", "imgs/tb"), ), ( "v1-fold-4", dict(train=3177, validation=810, test=443), ("imgs/health", "imgs/tb"), ), ( "v1-fold-5", dict(train=3177, validation=810, test=443), ("imgs/health", "imgs/tb"), ), ( "v1-fold-6", dict(train=3177, validation=810, test=443), ("imgs/health", "imgs/tb"), ), ( "v1-fold-7", dict(train=3177, validation=810, test=443), ("imgs/health", "imgs/tb"), ), ( "v1-fold-8", dict(train=3177, validation=810, test=443), ("imgs/health", "imgs/tb"), ), ( "v1-fold-9", dict(train=3177, validation=810, test=443), ("imgs/health", "imgs/tb"), ), ( "v2-others-vs-atb", dict(train=5241, validation=1335, test=1793), ("imgs/health", "imgs/sick", "imgs/tb"), ), ( "v2-fold-0", dict(train=6003, validation=1529, test=837), ("imgs/health", "imgs/sick", "imgs/tb"), ), ( "v2-fold-1", dict(train=6003, validation=1529, test=837), ("imgs/health", "imgs/sick", "imgs/tb"), ), ( "v2-fold-2", dict(train=6003, validation=1529, test=837), ("imgs/health", "imgs/sick", "imgs/tb"), ), ( "v2-fold-3", dict(train=6003, validation=1529, test=837), ("imgs/health", "imgs/sick", "imgs/tb"), ), ( "v2-fold-4", dict(train=6003, validation=1529, test=837), ("imgs/health", "imgs/sick", "imgs/tb"), ), ( "v2-fold-5", dict(train=6003, validation=1529, test=837), ("imgs/health", "imgs/sick", "imgs/tb"), ), ( "v2-fold-6", dict(train=6003, validation=1529, test=837), ("imgs/health", "imgs/sick", "imgs/tb"), ), ( "v2-fold-7", dict(train=6003, validation=1529, test=837), ("imgs/health", "imgs/sick", "imgs/tb"), ), ( "v2-fold-8", dict(train=6003, validation=1529, test=837), ("imgs/health", "imgs/sick", "imgs/tb"), ), ( "v2-fold-9", dict(train=6003, validation=1530, test=836), ("imgs/health", "imgs/sick", "imgs/tb"), ), ], ids=id_function, # just changes how pytest prints it ) def test_protocol_consistency( database_checkers, split: str, lenghts: dict[str, int], prefixes: typing.Sequence[str], ): from mednet.libs.common.data.split import make_split database_checkers.check_split( make_split("mednet.config.data.tbx11k", f"{split}.json"), lengths=lenghts, prefixes=prefixes, possible_labels=(0, 1), ) def check_loaded_batch( batch, batch_size: int, color_planes: int, prefixes: typing.Sequence[str], possible_labels: typing.Sequence[int], expected_num_labels: int, expected_image_shape: tuple[int, ...] | None = None, ): """Check the consistency of an individual (loaded) batch. Parameters ---------- batch The loaded batch to be checked. batch_size The mini-batch size. color_planes The number of color planes in the images. prefixes Each file named in a split should start with at least one of these prefixes. possible_labels These are the list of possible labels contained in any split. expected_num_labels The expected number of labels each sample should have. expected_image_shape The expected shape of the image (num_channels, width, height). """ assert len(batch) == 2 # data, metadata assert isinstance(batch[0], torch.Tensor) assert batch[0].shape[0] == batch_size # mini-batch size assert batch[0].shape[1] == color_planes assert batch[0].shape[2] == batch[0].shape[3] # image is square if expected_image_shape: assert all([data.shape == expected_image_shape for data in batch[0]]) assert isinstance(batch[1], dict) # metadata assert len(batch[1]) == 3 # label, name and radiological sign bounding-boxes assert "label" in batch[1] assert all([k in possible_labels for k in batch[1]["label"]]) if expected_num_labels: assert len(batch[1]["label"]) == expected_num_labels assert "name" in batch[1] assert all( [any([k.startswith(j) for j in prefixes]) for k in batch[1]["name"]], ) assert "bounding_boxes" in batch[1] for sample, label, bboxes in zip( batch[0], batch[1]["label"], batch[1]["bounding_boxes"], ): # there must be a sign indicated on the image, if active TB is detected if label == 1: assert len(bboxes) != 0 # eif label == 0: # not true, may have TBI! # assert len(bboxes) == 0 # asserts all bounding boxes are within the raw image width and height for bbox in bboxes: if label == 1: assert bbox.label == 1 else: assert bbox.label == 0 assert bbox.xmax < sample.shape[2] assert bbox.ymax < sample.shape[1] # use the code below to view generated images # from torchvision.transforms.functional import to_pil_image # to_pil_image(batch[0][0]).show() # __import__("pdb").set_trace() @pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k") def test_database_check(): from mednet.libs.classification.scripts.database import check runner = CliRunner() result = runner.invoke(check, ["--limit=10", "tbx11k-v1-f0"]) assert ( result.exit_code == 0 ), f"Exit code {result.exit_code} != 0 -- Output:\n{result.output}" result = runner.invoke(check, ["--limit=10", "tbx11k-v2-f0"]) assert ( result.exit_code == 0 ), f"Exit code {result.exit_code} != 0 -- Output:\n{result.output}" @pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k") @pytest.mark.parametrize( "dataset", [ "train", "validation", "test", ], ) @pytest.mark.parametrize( "name,prefixes", [ ("v1_healthy_vs_atb", ("imgs/health", "imgs/tb")), ("v1_fold_0", ("imgs/health", "imgs/tb")), ("v1_fold_1", ("imgs/health", "imgs/tb")), ("v1_fold_2", ("imgs/health", "imgs/tb")), ("v1_fold_3", ("imgs/health", "imgs/tb")), ("v1_fold_4", ("imgs/health", "imgs/tb")), ("v1_fold_5", ("imgs/health", "imgs/tb")), ("v1_fold_6", ("imgs/health", "imgs/tb")), ("v1_fold_7", ("imgs/health", "imgs/tb")), ("v1_fold_8", ("imgs/health", "imgs/tb")), ("v1_fold_9", ("imgs/health", "imgs/tb")), ("v2_others_vs_atb", ("imgs/health", "imgs/sick", "imgs/tb")), ("v2_fold_0", ("imgs/health", "imgs/sick", "imgs/tb")), ("v2_fold_1", ("imgs/health", "imgs/sick", "imgs/tb")), ("v2_fold_2", ("imgs/health", "imgs/sick", "imgs/tb")), ("v2_fold_3", ("imgs/health", "imgs/sick", "imgs/tb")), ("v2_fold_4", ("imgs/health", "imgs/sick", "imgs/tb")), ("v2_fold_5", ("imgs/health", "imgs/sick", "imgs/tb")), ("v2_fold_6", ("imgs/health", "imgs/sick", "imgs/tb")), ("v2_fold_7", ("imgs/health", "imgs/sick", "imgs/tb")), ("v2_fold_8", ("imgs/health", "imgs/sick", "imgs/tb")), ("v2_fold_9", ("imgs/health", "imgs/sick", "imgs/tb")), ], ) def test_loading(name: str, dataset: str, prefixes: typing.Sequence[str]): datamodule = importlib.import_module( f".{name}", "mednet.libs.classification.config.data.tbx11k", ).datamodule datamodule.model_transforms = [] # should be done before setup() datamodule.setup("predict") # sets up all datasets loader = datamodule.predict_dataloader()[dataset] limit = 50 # limit load checking for batch in loader: if limit == 0: break check_loaded_batch( batch, batch_size=1, color_planes=3, prefixes=prefixes, possible_labels=(0, 1), expected_num_labels=1, expected_image_shape=(3, 512, 512), ) limit -= 1 @pytest.mark.parametrize( "split", [ "v1_fold_0", "v2_fold_0", ], ) @pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k") def test_loaded_image_quality(database_checkers, datadir, split): reference_histogram_file = str( datadir / f"histograms/raw_data/histograms_tbx11k_{split}.json", ) datamodule = importlib.import_module( f".{split}", "mednet.libs.classification.config.data.tbx11k", ).datamodule datamodule.model_transforms = [] datamodule.setup("predict") database_checkers.check_image_quality(datamodule, reference_histogram_file)