Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • medai/software/mednet
1 result
Show changes
......@@ -151,29 +151,42 @@ def test_protocol_consistency(
def check_loaded_batch(
batch,
batch_size: int,
color_planes: int,
prefixes: typing.Sequence[str],
possible_labels: typing.Sequence[int],
expected_num_labels: int,
expected_image_shape: typing.Optional[tuple[int, ...]] = None,
):
"""Check the consistence of an individual (loaded) batch.
"""Check the consistency of an individual (loaded) batch.
Parameters
----------
batch
The loaded batch to be checked.
batch_size
The mini-batch size.
color_planes
The number of color planes in the images.
prefixes
Each file named in a split should start with at least one of these
prefixes.
possible_labels
These are the list of possible labels contained in any split.
expected_num_labels
The expected number of labels each sample should have.
expected_image_shape
The expected shape of the image (num_channels, width, height).
"""
assert len(batch) == 2 # data, metadata
assert isinstance(batch[0], torch.Tensor)
assert batch[0].shape[0] == batch_size # mini-batch size
assert batch[0].shape[1] == 3 # grayscale images
assert batch[0].shape[1] == color_planes
assert batch[0].shape[2] == batch[0].shape[3] # image is square
assert batch[0].shape[2] == 512 # image is 512 pixels large
if expected_image_shape:
assert all([data.shape == expected_image_shape for data in batch[0]])
assert isinstance(batch[1], dict) # metadata
assert (
......@@ -181,7 +194,10 @@ def check_loaded_batch(
) # label, name and radiological sign bounding-boxes
assert "label" in batch[1]
assert all([k in (0, 1) for k in batch[1]["label"]])
assert all([k in possible_labels for k in batch[1]["label"]])
if expected_num_labels:
assert len(batch[1]["label"]) == expected_num_labels
assert "name" in batch[1]
assert all(
......@@ -268,6 +284,33 @@ def test_loading(name: str, dataset: str, prefixes: typing.Sequence[str]):
check_loaded_batch(
batch,
batch_size=1,
color_planes=3,
prefixes=prefixes,
possible_labels=(0, 1),
expected_num_labels=1,
expected_image_shape=(3, 512, 512),
)
limit -= 1
@pytest.mark.parametrize(
"split",
[
"v1_fold_0",
"v2_fold_0",
],
)
@pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k")
def test_loaded_image_quality(database_checkers, datadir, split):
reference_histogram_file = str(
datadir / f"histograms/raw_data/histograms_tbx11k_{split}.json"
)
datamodule = importlib.import_module(
f".{split}", "mednet.config.data.tbx11k"
).datamodule
datamodule.model_transforms = []
datamodule.setup("predict")
database_checkers.check_image_quality(datamodule, reference_histogram_file)