Skip to content
Snippets Groups Projects
Commit 21fc141d authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[tbx11k] Loads RS bounding-boxes with sample; Add tests for bounding-boxes

parent 0203f7c6
No related branches found
No related tags found
1 merge request!6Making use of LightningDataModule and simplification of data loading
Pipeline #78056 failed
......@@ -87,7 +87,11 @@ class RawDataLoader(_BaseRawDataLoader):
# to_pil_image(tensor).show()
# __import__("pdb").set_trace()
return tensor, dict(label=sample[1], name=sample[0]) # type: ignore[arg-type]
return tensor, dict(
label=sample[1],
name=sample[0],
radsign_bboxes=self.bbox_annotations(sample),
)
def label(self, sample: DatabaseSample) -> int:
"""Loads a single image sample label from the disk.
......
......@@ -7,6 +7,7 @@ import importlib
import typing
import pytest
import torch
def id_function(val):
......@@ -147,6 +148,71 @@ def test_protocol_consistency(
)
def check_loaded_batch(
batch,
batch_size: int,
prefixes: typing.Sequence[str],
):
"""Checks the consistence of an individual (loaded) batch.
Parameters
----------
batch
The loaded batch to be checked.
size
The mini-batch size
"""
assert len(batch) == 2 # data, metadata
assert isinstance(batch[0], torch.Tensor)
assert batch[0].shape[0] == batch_size # mini-batch size
assert batch[0].shape[1] == 3 # grayscale images
assert batch[0].shape[2] == batch[0].shape[3] # image is square
assert batch[0].shape[2] == 512 # image is 512 pixels large
assert isinstance(batch[1], dict) # metadata
assert (
len(batch[1]) == 3
) # label, name and radiological sign bounding-boxes
assert "label" in batch[1]
assert all([k in (0, 1) for k in batch[1]["label"]])
assert "name" in batch[1]
assert all(
[any([k.startswith(j) for j in prefixes]) for k in batch[1]["name"]]
)
assert "radsign_bboxes" in batch[1]
for sample, label, bboxes in zip(
batch[0], batch[1]["label"], batch[1]["radsign_bboxes"]
):
# there must be a sign indicated on the image, if active TB is detected
if label == 1:
assert len(bboxes[0]) != 0
# eif label == 0: # not true, may have TBI!
# assert len(bboxes) == 0
# asserts all bounding boxes are within the raw image width and height
for bbox_label, xmin, ymin, width, height in zip(*bboxes):
if label == 1:
assert bbox_label == 1
else:
assert bbox_label == 0
assert (xmin + width) < sample.shape[2]
assert (ymin + height) < sample.shape[1]
# use the code below to view generated images
# from torchvision.transforms.functional import to_pil_image
# to_pil_image(batch[0][0]).show()
# __import__("pdb").set_trace()
@pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k")
@pytest.mark.parametrize(
"dataset",
......@@ -183,9 +249,7 @@ def test_protocol_consistency(
("v2_fold_9", ("imgs/health", "imgs/sick", "imgs/tb")),
],
)
def test_loading(
database_checkers, name: str, dataset: str, prefixes: typing.Sequence[str]
):
def test_loading(name: str, dataset: str, prefixes: typing.Sequence[str]):
datamodule = importlib.import_module(
f".{name}", "ptbench.config.data.tbx11k"
).datamodule
......@@ -195,21 +259,13 @@ def test_loading(
loader = datamodule.predict_dataloader()[dataset]
limit = 3 # limit load checking
limit = 50 # limit load checking
for batch in loader:
if limit == 0:
break
database_checkers.check_loaded_batch(
check_loaded_batch(
batch,
batch_size=1,
color_planes=3,
prefixes=prefixes,
possible_labels=(0, 1),
)
limit -= 1
# TODO: Tests for loading bounding boxes:
# if patient has active tb, then has to have 1 or more bounding boxes
# if patient does not have active tb, there should be no bounding boxes
# bounding boxes must be within image (512 x 512 pixels)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment