Skip to content
Snippets Groups Projects
test_tbx11k.py 10.02 KiB
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Tests for TBX11K dataset."""

import importlib
import typing

import pytest
import torch
from click.testing import CliRunner


def id_function(val):
    """Return the pytest test-id fragment for one parametrized value.

    Every value (strings, dictionaries, tuples, numbers, ...) is rendered
    with :func:`repr`, so a split name such as ``"v1-fold-0"`` shows up
    quoted in the generated test id.

    Parameters
    ----------
    val
        One parametrized argument value, as handed over by pytest to the
        ``ids`` callable.

    Returns
    -------
        The :func:`repr` of ``val``.
    """
    return repr(val)


@pytest.mark.parametrize(
    "split,lengths,prefixes",
    [
        (
            "v1-healthy-vs-atb",
            dict(train=2767, validation=706, test=957),
            ("imgs/health", "imgs/tb"),
        ),
        (
            "v1-fold-0",
            dict(train=3177, validation=810, test=443),
            ("imgs/health", "imgs/tb"),
        ),
        (
            "v1-fold-1",
            dict(train=3177, validation=810, test=443),
            ("imgs/health", "imgs/tb"),
        ),
        (
            "v1-fold-2",
            dict(train=3177, validation=810, test=443),
            ("imgs/health", "imgs/tb"),
        ),
        (
            "v1-fold-3",
            dict(train=3177, validation=810, test=443),
            ("imgs/health", "imgs/tb"),
        ),
        (
            "v1-fold-4",
            dict(train=3177, validation=810, test=443),
            ("imgs/health", "imgs/tb"),
        ),
        (
            "v1-fold-5",
            dict(train=3177, validation=810, test=443),
            ("imgs/health", "imgs/tb"),
        ),
        (
            "v1-fold-6",
            dict(train=3177, validation=810, test=443),
            ("imgs/health", "imgs/tb"),
        ),
        (
            "v1-fold-7",
            dict(train=3177, validation=810, test=443),
            ("imgs/health", "imgs/tb"),
        ),
        (
            "v1-fold-8",
            dict(train=3177, validation=810, test=443),
            ("imgs/health", "imgs/tb"),
        ),
        (
            "v1-fold-9",
            dict(train=3177, validation=810, test=443),
            ("imgs/health", "imgs/tb"),
        ),
        (
            "v2-others-vs-atb",
            dict(train=5241, validation=1335, test=1793),
            ("imgs/health", "imgs/sick", "imgs/tb"),
        ),
        (
            "v2-fold-0",
            dict(train=6003, validation=1529, test=837),
            ("imgs/health", "imgs/sick", "imgs/tb"),
        ),
        (
            "v2-fold-1",
            dict(train=6003, validation=1529, test=837),
            ("imgs/health", "imgs/sick", "imgs/tb"),
        ),
        (
            "v2-fold-2",
            dict(train=6003, validation=1529, test=837),
            ("imgs/health", "imgs/sick", "imgs/tb"),
        ),
        (
            "v2-fold-3",
            dict(train=6003, validation=1529, test=837),
            ("imgs/health", "imgs/sick", "imgs/tb"),
        ),
        (
            "v2-fold-4",
            dict(train=6003, validation=1529, test=837),
            ("imgs/health", "imgs/sick", "imgs/tb"),
        ),
        (
            "v2-fold-5",
            dict(train=6003, validation=1529, test=837),
            ("imgs/health", "imgs/sick", "imgs/tb"),
        ),
        (
            "v2-fold-6",
            dict(train=6003, validation=1529, test=837),
            ("imgs/health", "imgs/sick", "imgs/tb"),
        ),
        (
            "v2-fold-7",
            dict(train=6003, validation=1529, test=837),
            ("imgs/health", "imgs/sick", "imgs/tb"),
        ),
        (
            "v2-fold-8",
            dict(train=6003, validation=1529, test=837),
            ("imgs/health", "imgs/sick", "imgs/tb"),
        ),
        (
            "v2-fold-9",
            dict(train=6003, validation=1530, test=836),
            ("imgs/health", "imgs/sick", "imgs/tb"),
        ),
    ],
    ids=id_function,  # just changes how pytest prints it
)
def test_protocol_consistency(
    database_checkers,
    split: str,
    lengths: dict[str, int],
    prefixes: typing.Sequence[str],
):
    """Check that a TBX11K split file has the expected sizes and contents.

    Parameters
    ----------
    database_checkers
        Fixture providing the ``check_split`` helper.
    split
        Base name (without the ``.json`` extension) of the split file to
        load.
    lengths
        Expected number of samples in each subset (``train``,
        ``validation`` and ``test``).
    prefixes
        Every sample name in the split must start with one of these.
    """
    from mednet.libs.common.data.split import make_split

    # Split JSON files live next to the TBX11K data configurations, the same
    # package the loading tests below import from.
    database_checkers.check_split(
        make_split(
            "mednet.libs.classification.config.data.tbx11k",
            f"{split}.json",
        ),
        lengths=lengths,
        prefixes=prefixes,
        possible_labels=(0, 1),
    )


def check_loaded_batch(
    batch,
    batch_size: int,
    color_planes: int,
    prefixes: typing.Sequence[str],
    possible_labels: typing.Sequence[int],
    expected_num_labels: int,
    expected_image_shape: tuple[int, ...] | None = None,
):
    """Check the consistency of an individual (loaded) batch.

    Parameters
    ----------
    batch
        The loaded batch to be checked.
    batch_size
        The mini-batch size.
    color_planes
        The number of color planes in the images.
    prefixes
        Each file named in a split should start with at least one of these
        prefixes.
    possible_labels
        These are the list of possible labels contained in any split.
    expected_num_labels
        The expected number of labels each sample should have.
    expected_image_shape
        The expected shape of the image (num_channels, width, height).
    """

    assert len(batch) == 2  # data, metadata

    assert isinstance(batch[0], torch.Tensor)
    assert batch[0].shape[0] == batch_size  # mini-batch size
    assert batch[0].shape[1] == color_planes
    assert batch[0].shape[2] == batch[0].shape[3]  # image is square

    if expected_image_shape:
        assert all([data.shape == expected_image_shape for data in batch[0]])

    assert isinstance(batch[1], dict)  # metadata
    assert len(batch[1]) == 3  # label, name and radiological sign bounding-boxes

    assert "label" in batch[1]
    assert all([k in possible_labels for k in batch[1]["label"]])

    if expected_num_labels:
        assert len(batch[1]["label"]) == expected_num_labels

    assert "name" in batch[1]
    assert all(
        [any([k.startswith(j) for j in prefixes]) for k in batch[1]["name"]],
    )

    assert "bounding_boxes" in batch[1]

    for sample, label, bboxes in zip(
        batch[0],
        batch[1]["label"],
        batch[1]["bounding_boxes"],
    ):
        # there must be a sign indicated on the image, if active TB is detected
        if label == 1:
            assert len(bboxes) != 0

        # eif label == 0:  # not true, may have TBI!
        #    assert len(bboxes) == 0

        # asserts all bounding boxes are within the raw image width and height
        for bbox in bboxes:
            if label == 1:
                assert bbox.label == 1
            else:
                assert bbox.label == 0
            assert bbox.xmax < sample.shape[2]
            assert bbox.ymax < sample.shape[1]

    # use the code below to view generated images
    # from torchvision.transforms.functional import to_pil_image
    # to_pil_image(batch[0][0]).show()
    # __import__("pdb").set_trace()


@pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k")
def test_database_check():
    """Run the ``database check`` command on two TBX11K configurations.

    The command is invoked with ``--limit=10`` on ``tbx11k-v1-f0`` and then
    on ``tbx11k-v2-f0``; each invocation must exit with status 0.
    """
    from mednet.libs.classification.scripts.database import check

    cli = CliRunner()
    for database_name in ("tbx11k-v1-f0", "tbx11k-v2-f0"):
        outcome = cli.invoke(check, ["--limit=10", database_name])
        assert (
            outcome.exit_code == 0
        ), f"Exit code {outcome.exit_code} != 0 -- Output:\n{outcome.output}"


@pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k")
@pytest.mark.parametrize(
    "dataset",
    [
        "train",
        "validation",
        "test",
    ],
)
@pytest.mark.parametrize(
    "name,prefixes",
    [
        ("v1_healthy_vs_atb", ("imgs/health", "imgs/tb")),
        ("v1_fold_0", ("imgs/health", "imgs/tb")),
        ("v1_fold_1", ("imgs/health", "imgs/tb")),
        ("v1_fold_2", ("imgs/health", "imgs/tb")),
        ("v1_fold_3", ("imgs/health", "imgs/tb")),
        ("v1_fold_4", ("imgs/health", "imgs/tb")),
        ("v1_fold_5", ("imgs/health", "imgs/tb")),
        ("v1_fold_6", ("imgs/health", "imgs/tb")),
        ("v1_fold_7", ("imgs/health", "imgs/tb")),
        ("v1_fold_8", ("imgs/health", "imgs/tb")),
        ("v1_fold_9", ("imgs/health", "imgs/tb")),
        ("v2_others_vs_atb", ("imgs/health", "imgs/sick", "imgs/tb")),
        ("v2_fold_0", ("imgs/health", "imgs/sick", "imgs/tb")),
        ("v2_fold_1", ("imgs/health", "imgs/sick", "imgs/tb")),
        ("v2_fold_2", ("imgs/health", "imgs/sick", "imgs/tb")),
        ("v2_fold_3", ("imgs/health", "imgs/sick", "imgs/tb")),
        ("v2_fold_4", ("imgs/health", "imgs/sick", "imgs/tb")),
        ("v2_fold_5", ("imgs/health", "imgs/sick", "imgs/tb")),
        ("v2_fold_6", ("imgs/health", "imgs/sick", "imgs/tb")),
        ("v2_fold_7", ("imgs/health", "imgs/sick", "imgs/tb")),
        ("v2_fold_8", ("imgs/health", "imgs/sick", "imgs/tb")),
        ("v2_fold_9", ("imgs/health", "imgs/sick", "imgs/tb")),
    ],
)
def test_loading(name: str, dataset: str, prefixes: typing.Sequence[str]):
    """Load the first batches of one TBX11K subset and check each of them.

    Parameters
    ----------
    name
        Name of the data-configuration module under
        ``mednet.libs.classification.config.data.tbx11k``.
    dataset
        Which prediction dataloader to iterate (``train``, ``validation`` or
        ``test``).
    prefixes
        Every sample name must start with one of these.
    """
    import itertools

    datamodule = importlib.import_module(
        f".{name}",
        "mednet.libs.classification.config.data.tbx11k",
    ).datamodule

    datamodule.model_transforms = []  # should be done before setup()
    datamodule.setup("predict")  # sets up all datasets

    loader = datamodule.predict_dataloader()[dataset]

    limit = 50  # limit load checking
    # islice stops after exactly ``limit`` batches, without fetching (and
    # decoding) one extra batch only to discard it.
    for batch in itertools.islice(loader, limit):
        check_loaded_batch(
            batch,
            batch_size=1,
            color_planes=3,
            prefixes=prefixes,
            possible_labels=(0, 1),
            expected_num_labels=1,
            expected_image_shape=(3, 512, 512),
        )


@pytest.mark.parametrize(
    "split",
    [
        "v1_fold_0",
        "v2_fold_0",
    ],
)
@pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k")
def test_loaded_image_quality(database_checkers, datadir, split):
    """Check loaded TBX11K images against a stored reference histogram file.

    The datamodule for ``split`` is set up for prediction with no model
    transforms. It is then passed to ``database_checkers.check_image_quality``
    together with the matching
    ``histograms/raw_data/histograms_tbx11k_<split>.json`` under ``datadir``.
    """
    histogram_path = datadir / f"histograms/raw_data/histograms_tbx11k_{split}.json"
    reference_histogram_file = str(histogram_path)

    config_module = importlib.import_module(
        f".{split}",
        "mednet.libs.classification.config.data.tbx11k",
    )
    datamodule = config_module.datamodule
    datamodule.model_transforms = []
    datamodule.setup("predict")

    database_checkers.check_image_quality(datamodule, reference_histogram_file)