Skip to content
Snippets Groups Projects
Commit fb46d403 authored by Daniel CARRON's avatar Daniel CARRON :b:
Browse files

[test] Add checks for specific image shapes

parent 47239f3d
No related branches found
No related tags found
2 merge requests!18Update tests,!16Make square centre-padding a model transform
......@@ -160,9 +160,19 @@ class DatabaseCheckers:
Parameters
----------
<<<<<<< HEAD
split
An instance of DatabaseSplit.
lengths
=======
make_split
A database specific function that takes a split name and returns
the loaded database split.
split_filename
This is the split we will check.
lenghts
>>>>>>> 91bcad6 ([test] Add checks for specific image shapes)
A dictionary that contains keys matching those of the split (this will
be checked). The values of the dictionary should correspond to the
sizes of each of the datasets in the split.
......@@ -197,13 +207,13 @@ class DatabaseCheckers:
color_planes: int,
prefixes: typing.Sequence[str],
possible_labels: typing.Sequence[int],
expected_num_labels: typing.Optional[int] = None,
expected_num_labels: int,
expected_image_shape: typing.Optional[tuple[int, ...]] = None,
):
"""Check the consistency of an individual (loaded) batch.
Parameters
----------
batch
The loaded batch to be checked.
batch_size
......@@ -215,15 +225,24 @@ class DatabaseCheckers:
prefixes.
possible_labels
These are the list of possible labels contained in any split.
expected_num_labels
The expected number of labels each sample should have.
expected_image_shape
The expected shape of the image (num_channels, width, height).
"""
assert len(batch) == 2 # data, metadata
assert isinstance(batch[0], torch.Tensor)
assert batch[0].shape[0] == batch_size # mini-batch size
assert batch[0].shape[1] == color_planes # grayscale images
assert batch[0].shape[1] == color_planes
assert batch[0].shape[2] == batch[0].shape[3] # image is square
if expected_image_shape:
assert all(
[data.shape == expected_image_shape for data in batch[0]]
)
assert isinstance(batch[1], dict) # metadata
assert len(batch[1]) == 2 # label and name
......
......@@ -35,22 +35,18 @@ def test_protocol_consistency(
)
@pytest.mark.skip_if_rc_var_not_set("datadir.nih_cxr14")
@pytest.mark.parametrize(
"dataset",
[
"train",
"validation",
"test",
],
)
@pytest.mark.parametrize(
"name",
[
"default",
],
)
def test_loading(database_checkers, name: str, dataset: str):
testdata = [
("default", "train", 14),
("default", "validation", 14),
("default", "test", 14),
("cardiomegaly", "train", 14),
("cardiomegaly", "validation", 14),
]
@pytest.mark.skip_if_rc_var_not_set("datadir.padchest")
@pytest.mark.parametrize("name,dataset,num_labels", testdata)
def test_loading(database_checkers, name: str, dataset: str, num_labels: int):
datamodule = importlib.import_module(
f".{name}", "mednet.config.data.nih_cxr14"
).datamodule
......@@ -70,9 +66,10 @@ def test_loading(database_checkers, name: str, dataset: str):
color_planes=1,
prefixes=("images/000",),
possible_labels=(0, 1),
expected_num_labels=num_labels,
expected_image_shape=(1, 1024, 1024),
)
limit -= 1
# TODO: check size 1024x1024
# TODO: check there are 14 binary labels (0, 1)
......@@ -151,14 +151,16 @@ def test_protocol_consistency(
def check_loaded_batch(
batch,
batch_size: int,
color_planes: int,
prefixes: typing.Sequence[str],
expected_num_labels: typing.Optional[int] = None,
possible_labels: typing.Sequence[int],
expected_num_labels: int,
expected_image_shape: typing.Optional[tuple[int, ...]] = None,
):
"""Check the consistency of an individual (loaded) batch.
Parameters
----------
batch
The loaded batch to be checked.
batch_size
......@@ -172,9 +174,11 @@ def check_loaded_batch(
assert isinstance(batch[0], torch.Tensor)
assert batch[0].shape[0] == batch_size # mini-batch size
assert batch[0].shape[1] == 3 # grayscale images
assert batch[0].shape[1] == color_planes
assert batch[0].shape[2] == batch[0].shape[3] # image is square
assert batch[0].shape[2] == 512 # image is 512 pixels large
if expected_image_shape:
assert all([data.shape == expected_image_shape for data in batch[0]])
assert isinstance(batch[1], dict) # metadata
assert (
......@@ -182,7 +186,7 @@ def check_loaded_batch(
) # label, name and radiological sign bounding-boxes
assert "label" in batch[1]
assert all([k in (0, 1) for k in batch[1]["label"]])
assert all([k in possible_labels for k in batch[1]["label"]])
if expected_num_labels:
assert len(batch[1]["label"]) == expected_num_labels
......@@ -272,7 +276,10 @@ def test_loading(name: str, dataset: str, prefixes: typing.Sequence[str]):
check_loaded_batch(
batch,
batch_size=1,
color_planes=3,
prefixes=prefixes,
possible_labels=(0, 1),
expected_num_labels=1,
expected_image_shape=(3, 512, 512),
)
limit -= 1
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment