# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Tests for TBX11K dataset."""
import importlib
import typing
import pytest
import torch
from click.testing import CliRunner
def id_function(val):
if isinstance(val, dict | tuple):
return repr(val)
return repr(val)
@pytest.mark.parametrize(
"split,lenghts,prefixes",
[
(
"v1-healthy-vs-atb",
dict(train=2767, validation=706, test=957),
("imgs/health", "imgs/tb"),
),
(
"v1-fold-0",
dict(train=3177, validation=810, test=443),
("imgs/health", "imgs/tb"),
),
(
"v1-fold-1",
dict(train=3177, validation=810, test=443),
("imgs/health", "imgs/tb"),
),
(
"v1-fold-2",
dict(train=3177, validation=810, test=443),
("imgs/health", "imgs/tb"),
),
(
"v1-fold-3",
dict(train=3177, validation=810, test=443),
("imgs/health", "imgs/tb"),
),
(
"v1-fold-4",
dict(train=3177, validation=810, test=443),
("imgs/health", "imgs/tb"),
),
(
"v1-fold-5",
dict(train=3177, validation=810, test=443),
("imgs/health", "imgs/tb"),
),
(
"v1-fold-6",
dict(train=3177, validation=810, test=443),
("imgs/health", "imgs/tb"),
),
(
"v1-fold-7",
dict(train=3177, validation=810, test=443),
("imgs/health", "imgs/tb"),
),
(
"v1-fold-8",
dict(train=3177, validation=810, test=443),
("imgs/health", "imgs/tb"),
),
(
"v1-fold-9",
dict(train=3177, validation=810, test=443),
("imgs/health", "imgs/tb"),
),
(
"v2-others-vs-atb",
dict(train=5241, validation=1335, test=1793),
("imgs/health", "imgs/sick", "imgs/tb"),
),
(
"v2-fold-0",
dict(train=6003, validation=1529, test=837),
("imgs/health", "imgs/sick", "imgs/tb"),
),
(
"v2-fold-1",
dict(train=6003, validation=1529, test=837),
("imgs/health", "imgs/sick", "imgs/tb"),
),
(
"v2-fold-2",
dict(train=6003, validation=1529, test=837),
("imgs/health", "imgs/sick", "imgs/tb"),
),
(
"v2-fold-3",
dict(train=6003, validation=1529, test=837),
("imgs/health", "imgs/sick", "imgs/tb"),
),
(
"v2-fold-4",
dict(train=6003, validation=1529, test=837),
("imgs/health", "imgs/sick", "imgs/tb"),
),
(
"v2-fold-5",
dict(train=6003, validation=1529, test=837),
("imgs/health", "imgs/sick", "imgs/tb"),
),
(
"v2-fold-6",
dict(train=6003, validation=1529, test=837),
("imgs/health", "imgs/sick", "imgs/tb"),
),
(
"v2-fold-7",
dict(train=6003, validation=1529, test=837),
("imgs/health", "imgs/sick", "imgs/tb"),
),
(
"v2-fold-8",
dict(train=6003, validation=1529, test=837),
("imgs/health", "imgs/sick", "imgs/tb"),
),
(
"v2-fold-9",
dict(train=6003, validation=1530, test=836),
("imgs/health", "imgs/sick", "imgs/tb"),
),
],
ids=id_function, # just changes how pytest prints it
)
def test_protocol_consistency(
database_checkers,
split: str,
    lengths: dict[str, int],
prefixes: typing.Sequence[str],
):
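    """Check that each split has the expected partition sizes, file prefixes and labels."""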
    from mednet.libs.common.data.split import make_split

    database_checkers.check_split(
        make_split("mednet.libs.classification.config.data.tbx11k", f"{split}.json"),
        lengths=lengths,
        prefixes=prefixes,
        possible_labels=(0, 1),
    )


def check_loaded_batch(
batch,
batch_size: int,
color_planes: int,
prefixes: typing.Sequence[str],
possible_labels: typing.Sequence[int],
expected_num_labels: int,
expected_image_shape: tuple[int, ...] | None = None,
):
"""Check the consistency of an individual (loaded) batch.
Parameters
----------
batch
The loaded batch to be checked.
batch_size
The mini-batch size.
color_planes
The number of color planes in the images.
prefixes
Each file named in a split should start with at least one of these
prefixes.
possible_labels
These are the list of possible labels contained in any split.
expected_num_labels
The expected number of labels each sample should have.
expected_image_shape
        The expected shape of the image (num_channels, height, width).
"""
assert len(batch) == 2 # data, metadata
assert isinstance(batch[0], torch.Tensor)
assert batch[0].shape[0] == batch_size # mini-batch size
assert batch[0].shape[1] == color_planes
assert batch[0].shape[2] == batch[0].shape[3] # image is square
if expected_image_shape:
assert all([data.shape == expected_image_shape for data in batch[0]])
assert isinstance(batch[1], dict) # metadata
assert len(batch[1]) == 3 # label, name and radiological sign bounding-boxes
assert "label" in batch[1]
assert all([k in possible_labels for k in batch[1]["label"]])
if expected_num_labels:
assert len(batch[1]["label"]) == expected_num_labels
assert "name" in batch[1]
assert all(
[any([k.startswith(j) for j in prefixes]) for k in batch[1]["name"]],
)
assert "bounding_boxes" in batch[1]
for sample, label, bboxes in zip(
batch[0],
batch[1]["label"],
batch[1]["bounding_boxes"],
):
        # if active TB is detected, there must be at least one radiological sign
        # annotated on the image
        if label == 1:
            assert len(bboxes) != 0
        # elif label == 0:  # not true, may have TBI!
        #     assert len(bboxes) == 0
        # assert that all bounding boxes fall within the raw image width and height
for bbox in bboxes:
if label == 1:
assert bbox.label == 1
else:
assert bbox.label == 0
assert bbox.xmax < sample.shape[2]
assert bbox.ymax < sample.shape[1]
# use the code below to view generated images
# from torchvision.transforms.functional import to_pil_image
# to_pil_image(batch[0][0]).show()
# __import__("pdb").set_trace()
@pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k")
def test_database_check():
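    """Check that the database ``check`` command runs cleanly on TBX11K configurations."""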
    from mednet.libs.classification.scripts.database import check

    runner = CliRunner()
result = runner.invoke(check, ["--limit=10", "tbx11k-v1-f0"])
assert (
result.exit_code == 0
), f"Exit code {result.exit_code} != 0 -- Output:\n{result.output}"
result = runner.invoke(check, ["--limit=10", "tbx11k-v2-f0"])
assert (
result.exit_code == 0
), f"Exit code {result.exit_code} != 0 -- Output:\n{result.output}"
@pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k")
@pytest.mark.parametrize(
"dataset",
[
"train",
"validation",
"test",
],
)
@pytest.mark.parametrize(
"name,prefixes",
[
("v1_healthy_vs_atb", ("imgs/health", "imgs/tb")),
("v1_fold_0", ("imgs/health", "imgs/tb")),
("v1_fold_1", ("imgs/health", "imgs/tb")),
("v1_fold_2", ("imgs/health", "imgs/tb")),
("v1_fold_3", ("imgs/health", "imgs/tb")),
("v1_fold_4", ("imgs/health", "imgs/tb")),
("v1_fold_5", ("imgs/health", "imgs/tb")),
("v1_fold_6", ("imgs/health", "imgs/tb")),
("v1_fold_7", ("imgs/health", "imgs/tb")),
("v1_fold_8", ("imgs/health", "imgs/tb")),
("v1_fold_9", ("imgs/health", "imgs/tb")),
("v2_others_vs_atb", ("imgs/health", "imgs/sick", "imgs/tb")),
("v2_fold_0", ("imgs/health", "imgs/sick", "imgs/tb")),
("v2_fold_1", ("imgs/health", "imgs/sick", "imgs/tb")),
("v2_fold_2", ("imgs/health", "imgs/sick", "imgs/tb")),
("v2_fold_3", ("imgs/health", "imgs/sick", "imgs/tb")),
("v2_fold_4", ("imgs/health", "imgs/sick", "imgs/tb")),
("v2_fold_5", ("imgs/health", "imgs/sick", "imgs/tb")),
("v2_fold_6", ("imgs/health", "imgs/sick", "imgs/tb")),
("v2_fold_7", ("imgs/health", "imgs/sick", "imgs/tb")),
("v2_fold_8", ("imgs/health", "imgs/sick", "imgs/tb")),
("v2_fold_9", ("imgs/health", "imgs/sick", "imgs/tb")),
],
)
def test_loading(name: str, dataset: str, prefixes: typing.Sequence[str]):
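    """Load a limited number of samples from each split and check batch structure."""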
datamodule = importlib.import_module(
f".{name}",
"mednet.libs.classification.config.data.tbx11k",
).datamodule
datamodule.model_transforms = [] # should be done before setup()
datamodule.setup("predict") # sets up all datasets
loader = datamodule.predict_dataloader()[dataset]
limit = 50 # limit load checking
for batch in loader:
if limit == 0:
break
check_loaded_batch(
batch,
batch_size=1,
color_planes=3,
prefixes=prefixes,
possible_labels=(0, 1),
expected_num_labels=1,
expected_image_shape=(3, 512, 512),
)
        limit -= 1


@pytest.mark.parametrize(
"split",
[
"v1_fold_0",
"v2_fold_0",
],
)
@pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k")
def test_loaded_image_quality(database_checkers, datadir, split):
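    """Compare loaded images against pre-computed reference histograms."""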
reference_histogram_file = str(
datadir / f"histograms/raw_data/histograms_tbx11k_{split}.json",
)
datamodule = importlib.import_module(
f".{split}",
"mednet.libs.classification.config.data.tbx11k",
).datamodule
datamodule.model_transforms = []
datamodule.setup("predict")
database_checkers.check_image_quality(datamodule, reference_histogram_file)