diff --git a/src/ptbench/config/data/tbx11k/datamodule.py b/src/ptbench/config/data/tbx11k/datamodule.py index a5959e820b18b9a04d53f15b88e6001b8fdaa7b3..37ac5546fa865c16866940c3eada2641101e4c95 100644 --- a/src/ptbench/config/data/tbx11k/datamodule.py +++ b/src/ptbench/config/data/tbx11k/datamodule.py @@ -87,7 +87,11 @@ class RawDataLoader(_BaseRawDataLoader): # to_pil_image(tensor).show() # __import__("pdb").set_trace() - return tensor, dict(label=sample[1], name=sample[0]) # type: ignore[arg-type] + return tensor, dict( + label=sample[1], + name=sample[0], + radsign_bboxes=self.bbox_annotations(sample), + ) def label(self, sample: DatabaseSample) -> int: """Loads a single image sample label from the disk. diff --git a/tests/test_tbx11k.py b/tests/test_tbx11k.py index 0c44a85a3e31d36f5764cd46acb6e0a987b9854f..1b3a72225522420d8089a54108dcd978ab1a2c44 100644 --- a/tests/test_tbx11k.py +++ b/tests/test_tbx11k.py @@ -7,6 +7,7 @@ import importlib import typing import pytest +import torch def id_function(val): @@ -147,6 +148,71 @@ def test_protocol_consistency( ) +def check_loaded_batch( + batch, + batch_size: int, + prefixes: typing.Sequence[str], +): + """Checks the consistence of an individual (loaded) batch. + + Parameters + ---------- + + batch + The loaded batch to be checked. + + size + The mini-batch size + """ + + assert len(batch) == 2 # data, metadata + + assert isinstance(batch[0], torch.Tensor) + assert batch[0].shape[0] == batch_size # mini-batch size + assert batch[0].shape[1] == 3 # grayscale images + assert batch[0].shape[2] == batch[0].shape[3] # image is square + assert batch[0].shape[2] == 512 # image is 512 pixels large + + assert isinstance(batch[1], dict) # metadata + assert ( + len(batch[1]) == 3 + ) # label, name and radiological sign bounding-boxes + + assert "label" in batch[1] + assert all([k in (0, 1) for k in batch[1]["label"]]) + + assert "name" in batch[1] + assert all( + [any([k.startswith(j) for j in prefixes]) for k in batch[1]["name"]] + ) + + assert "radsign_bboxes" in batch[1] + + for sample, label, bboxes in zip( + batch[0], batch[1]["label"], batch[1]["radsign_bboxes"] + ): + # there must be a sign indicated on the image, if active TB is detected + if label == 1: + assert len(bboxes[0]) != 0 + + # eif label == 0: # not true, may have TBI! + # assert len(bboxes) == 0 + + # asserts all bounding boxes are within the raw image width and height + for bbox_label, xmin, ymin, width, height in zip(*bboxes): + if label == 1: + assert bbox_label == 1 + else: + assert bbox_label == 0 + assert (xmin + width) < sample.shape[2] + assert (ymin + height) < sample.shape[1] + + # use the code below to view generated images + # from torchvision.transforms.functional import to_pil_image + # to_pil_image(batch[0][0]).show() + # __import__("pdb").set_trace() + + @pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k") @pytest.mark.parametrize( "dataset", @@ -183,9 +249,7 @@ def test_protocol_consistency( ("v2_fold_9", ("imgs/health", "imgs/sick", "imgs/tb")), ], ) -def test_loading( - database_checkers, name: str, dataset: str, prefixes: typing.Sequence[str] -): +def test_loading(name: str, dataset: str, prefixes: typing.Sequence[str]): datamodule = importlib.import_module( f".{name}", "ptbench.config.data.tbx11k" ).datamodule @@ -195,21 +259,13 @@ def test_loading( loader = datamodule.predict_dataloader()[dataset] - limit = 3 # limit load checking + limit = 50 # limit load checking for batch in loader: if limit == 0: break - database_checkers.check_loaded_batch( + check_loaded_batch( batch, batch_size=1, - color_planes=3, prefixes=prefixes, - possible_labels=(0, 1), ) limit -= 1 - - -# TODO: Tests for loading bounding boxes: -# if patient has active tb, then has to have 1 or more bounding boxes -# if patient does not have active tb, there should be no bounding boxes -# bounding boxes must be within image (512 x 512 pixels)