Skip to content
Snippets Groups Projects
Commit fb46d403 authored by Daniel CARRON's avatar Daniel CARRON :b:
Browse files

[test] Add checks for specific image shapes

parent 47239f3d
No related branches found
No related tags found
2 merge requests!18Update tests,!16Make square centre-padding a model transform
...@@ -160,9 +160,19 @@ class DatabaseCheckers: ...@@ -160,9 +160,19 @@ class DatabaseCheckers:
Parameters Parameters
---------- ----------
<<<<<<< HEAD
split split
An instance of DatabaseSplit. An instance of DatabaseSplit.
lengths lengths
=======
make_split
A database specific function that takes a split name and returns
the loaded database split.
split_filename
This is the split we will check.
lenghts
>>>>>>> 91bcad6 ([test] Add checks for specific image shapes)
A dictionary that contains keys matching those of the split (this will A dictionary that contains keys matching those of the split (this will
be checked). The values of the dictionary should correspond to the be checked). The values of the dictionary should correspond to the
sizes of each of the datasets in the split. sizes of each of the datasets in the split.
...@@ -197,13 +207,13 @@ class DatabaseCheckers: ...@@ -197,13 +207,13 @@ class DatabaseCheckers:
color_planes: int, color_planes: int,
prefixes: typing.Sequence[str], prefixes: typing.Sequence[str],
possible_labels: typing.Sequence[int], possible_labels: typing.Sequence[int],
expected_num_labels: typing.Optional[int] = None, expected_num_labels: int,
expected_image_shape: typing.Optional[tuple[int, ...]] = None,
): ):
"""Check the consistency of an individual (loaded) batch. """Check the consistency of an individual (loaded) batch.
Parameters Parameters
---------- ----------
batch batch
The loaded batch to be checked. The loaded batch to be checked.
batch_size batch_size
...@@ -215,15 +225,24 @@ class DatabaseCheckers: ...@@ -215,15 +225,24 @@ class DatabaseCheckers:
prefixes. prefixes.
possible_labels possible_labels
These are the list of possible labels contained in any split. These are the list of possible labels contained in any split.
expected_num_labels
The expected number of labels each sample should have.
expected_image_shape
The expected shape of the image (num_channels, width, height).
""" """
assert len(batch) == 2 # data, metadata assert len(batch) == 2 # data, metadata
assert isinstance(batch[0], torch.Tensor) assert isinstance(batch[0], torch.Tensor)
assert batch[0].shape[0] == batch_size # mini-batch size assert batch[0].shape[0] == batch_size # mini-batch size
assert batch[0].shape[1] == color_planes # grayscale images assert batch[0].shape[1] == color_planes
assert batch[0].shape[2] == batch[0].shape[3] # image is square assert batch[0].shape[2] == batch[0].shape[3] # image is square
if expected_image_shape:
assert all(
[data.shape == expected_image_shape for data in batch[0]]
)
assert isinstance(batch[1], dict) # metadata assert isinstance(batch[1], dict) # metadata
assert len(batch[1]) == 2 # label and name assert len(batch[1]) == 2 # label and name
......
...@@ -35,22 +35,18 @@ def test_protocol_consistency( ...@@ -35,22 +35,18 @@ def test_protocol_consistency(
) )
@pytest.mark.skip_if_rc_var_not_set("datadir.nih_cxr14") testdata = [
@pytest.mark.parametrize( ("default", "train", 14),
"dataset", ("default", "validation", 14),
[ ("default", "test", 14),
"train", ("cardiomegaly", "train", 14),
"validation", ("cardiomegaly", "validation", 14),
"test", ]
],
)
@pytest.mark.parametrize( @pytest.mark.skip_if_rc_var_not_set("datadir.padchest")
"name", @pytest.mark.parametrize("name,dataset,num_labels", testdata)
[ def test_loading(database_checkers, name: str, dataset: str, num_labels: int):
"default",
],
)
def test_loading(database_checkers, name: str, dataset: str):
datamodule = importlib.import_module( datamodule = importlib.import_module(
f".{name}", "mednet.config.data.nih_cxr14" f".{name}", "mednet.config.data.nih_cxr14"
).datamodule ).datamodule
...@@ -70,9 +66,10 @@ def test_loading(database_checkers, name: str, dataset: str): ...@@ -70,9 +66,10 @@ def test_loading(database_checkers, name: str, dataset: str):
color_planes=1, color_planes=1,
prefixes=("images/000",), prefixes=("images/000",),
possible_labels=(0, 1), possible_labels=(0, 1),
expected_num_labels=num_labels,
expected_image_shape=(1, 1024, 1024),
) )
limit -= 1 limit -= 1
# TODO: check size 1024x1024 # TODO: check size 1024x1024
# TODO: check there are 14 binary labels (0, 1)
...@@ -151,14 +151,16 @@ def test_protocol_consistency( ...@@ -151,14 +151,16 @@ def test_protocol_consistency(
def check_loaded_batch( def check_loaded_batch(
batch, batch,
batch_size: int, batch_size: int,
color_planes: int,
prefixes: typing.Sequence[str], prefixes: typing.Sequence[str],
expected_num_labels: typing.Optional[int] = None, possible_labels: typing.Sequence[int],
expected_num_labels: int,
expected_image_shape: typing.Optional[tuple[int, ...]] = None,
): ):
"""Check the consistency of an individual (loaded) batch. """Check the consistency of an individual (loaded) batch.
Parameters Parameters
---------- ----------
batch batch
The loaded batch to be checked. The loaded batch to be checked.
batch_size batch_size
...@@ -172,9 +174,11 @@ def check_loaded_batch( ...@@ -172,9 +174,11 @@ def check_loaded_batch(
assert isinstance(batch[0], torch.Tensor) assert isinstance(batch[0], torch.Tensor)
assert batch[0].shape[0] == batch_size # mini-batch size assert batch[0].shape[0] == batch_size # mini-batch size
assert batch[0].shape[1] == 3 # grayscale images assert batch[0].shape[1] == color_planes
assert batch[0].shape[2] == batch[0].shape[3] # image is square assert batch[0].shape[2] == batch[0].shape[3] # image is square
assert batch[0].shape[2] == 512 # image is 512 pixels large
if expected_image_shape:
assert all([data.shape == expected_image_shape for data in batch[0]])
assert isinstance(batch[1], dict) # metadata assert isinstance(batch[1], dict) # metadata
assert ( assert (
...@@ -182,7 +186,7 @@ def check_loaded_batch( ...@@ -182,7 +186,7 @@ def check_loaded_batch(
) # label, name and radiological sign bounding-boxes ) # label, name and radiological sign bounding-boxes
assert "label" in batch[1] assert "label" in batch[1]
assert all([k in (0, 1) for k in batch[1]["label"]]) assert all([k in possible_labels for k in batch[1]["label"]])
if expected_num_labels: if expected_num_labels:
assert len(batch[1]["label"]) == expected_num_labels assert len(batch[1]["label"]) == expected_num_labels
...@@ -272,7 +276,10 @@ def test_loading(name: str, dataset: str, prefixes: typing.Sequence[str]): ...@@ -272,7 +276,10 @@ def test_loading(name: str, dataset: str, prefixes: typing.Sequence[str]):
check_loaded_batch( check_loaded_batch(
batch, batch,
batch_size=1, batch_size=1,
color_planes=3,
prefixes=prefixes, prefixes=prefixes,
possible_labels=(0, 1),
expected_num_labels=1, expected_num_labels=1,
expected_image_shape=(3, 512, 512),
) )
limit -= 1 limit -= 1
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment