diff --git a/tests/conftest.py b/tests/conftest.py index 98d102e7a4721bb9c87531049956f0d195645597..fad88f4651927d047d4beb01348d6f8a9eef1f38 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -185,30 +185,65 @@ class DatabaseCheckers: assert len(batch) == 2 # data, metadata - assert isinstance(batch[0], torch.Tensor) - assert batch[0].shape[0] == batch_size # mini-batch size - assert batch[0].shape[1] == color_planes - - if expected_image_shape: - assert all( - [data.shape == expected_image_shape for data in batch[0]], - ) - - assert isinstance(batch[1], dict) # metadata - assert len(batch[1]) == expected_meta_size - - assert "target" in batch[1] - if possible_labels: - assert all([k in possible_labels for k in batch[1]["target"]]) - - if expected_num_labels: - assert len(batch[1]["target"]) == expected_num_labels - - assert "name" in batch[1] - if prefixes: - assert all( - [any([k.startswith(j) for j in prefixes]) for k in batch[1]["name"]], - ) + if isinstance(batch[0], dict): + assert isinstance(batch[0]["image"], torch.Tensor) + + assert batch[0]["image"].shape[0] == batch_size # mini-batch size + assert batch[0]["image"].shape[1] == color_planes + + if expected_image_shape: + assert all( + [data.shape == expected_image_shape for data in batch[0]["image"]], + ) + + assert isinstance(batch[1], dict) # metadata + assert len(batch[1]) == expected_meta_size + + assert "target" in batch[0] + if possible_labels: + assert all([k in possible_labels for k in batch[0]["target"]]) + + if expected_num_labels: + assert len(batch[0]["target"]) == expected_num_labels + + assert "name" in batch[1] + if prefixes: + assert all( + [ + any([k.startswith(j) for j in prefixes]) + for k in batch[1]["name"] + ], + ) + + else: + assert isinstance(batch[0], torch.Tensor) + + assert batch[0].shape[0] == batch_size # mini-batch size + assert batch[0].shape[1] == color_planes + + if expected_image_shape: + assert all( + [data.shape == expected_image_shape for data in batch[0]], + ) + + assert isinstance(batch[1], dict) # metadata + assert len(batch[1]) == expected_meta_size + + assert "target" in batch[1] + if possible_labels: + assert all([k in possible_labels for k in batch[1]["target"]]) + + if expected_num_labels: + assert len(batch[1]["target"]) == expected_num_labels + + assert "name" in batch[1] + if prefixes: + assert all( + [ + any([k.startswith(j) for j in prefixes]) + for k in batch[1]["name"] + ], + ) # use the code below to view generated images # from torchvision.transforms.functional import to_pil_image @@ -243,6 +278,9 @@ class DatabaseCheckers: dataset_sample_index ][0] + if isinstance(image_tensor, dict): + image_tensor = image_tensor["image"] + histogram = [] for color_channel in image_tensor: color_channel = numpy.multiply( diff --git a/tests/segmentation/test_chasedb1.py b/tests/segmentation/test_chasedb1.py index 04983002ab7fe205e73de4badfaceedaf3a4e96b..c9bdd8984bbd5635e26ea1aaa03afe6380bee1b3 100644 --- a/tests/segmentation/test_chasedb1.py +++ b/tests/segmentation/test_chasedb1.py @@ -84,7 +84,7 @@ def test_loading(database_checkers, name: str, dataset: str): batch_size=1, color_planes=3, expected_num_labels=1, - expected_meta_size=3, + expected_meta_size=1, prefixes=["Image_"], possible_labels=[], ) diff --git a/tests/segmentation/test_cxr8.py b/tests/segmentation/test_cxr8.py index eb6de32eb7f1c73c5bc171c7baf2a6961bd62db0..756afb5059012d836cf9bd19a757b39da733be69 100644 --- a/tests/segmentation/test_cxr8.py +++ b/tests/segmentation/test_cxr8.py @@ -84,7 +84,7 @@ def test_loading(database_checkers, name: str, dataset: str): prefixes=[], possible_labels=[], expected_num_labels=1, - expected_meta_size=3, + expected_meta_size=1, ) limit -= 1 diff --git a/tests/segmentation/test_drhagis.py b/tests/segmentation/test_drhagis.py index 8eaf73e00e215804521ebab6274c93212dfe2c34..b18d1571aa0c1515eeef376aec25d5b6db089cbb 100644 --- a/tests/segmentation/test_drhagis.py +++ b/tests/segmentation/test_drhagis.py @@ -82,7 +82,7 @@ def test_loading(database_checkers, name: str, dataset: str): batch_size=1, color_planes=3, expected_num_labels=1, - expected_meta_size=3, + expected_meta_size=1, prefixes=["Fundus_Images/"], possible_labels=[], ) diff --git a/tests/segmentation/test_drionsdb.py b/tests/segmentation/test_drionsdb.py index 6c04191ae1afdbac19389d2a463f9288e4770a23..81c7731442d2444edab431eae6d3f3606392d8c1 100644 --- a/tests/segmentation/test_drionsdb.py +++ b/tests/segmentation/test_drionsdb.py @@ -84,7 +84,7 @@ def test_loading(database_checkers, name: str, dataset: str): batch_size=1, color_planes=3, expected_num_labels=1, - expected_meta_size=3, + expected_meta_size=1, prefixes=["images/image_"], possible_labels=[], ) diff --git a/tests/segmentation/test_drishtigs1.py b/tests/segmentation/test_drishtigs1.py index 8f704979ef308d40fac7d970f6fdea81c4c3fad9..5e4eb28922e200b5361474c5ed2e1f3e98c37283 100644 --- a/tests/segmentation/test_drishtigs1.py +++ b/tests/segmentation/test_drishtigs1.py @@ -89,7 +89,7 @@ def test_loading(database_checkers, name: str, dataset: str): batch_size=1, color_planes=3, expected_num_labels=1, - expected_meta_size=3, + expected_meta_size=1, prefixes=[ "Drishti-GS1_files/Training/Images/drishtiGS_", "Drishti-GS1_files/Test/Images/drishtiGS_", diff --git a/tests/segmentation/test_drive.py b/tests/segmentation/test_drive.py index 33fdd65ac6d95515898f372fb5d353868bc5bc0c..59239042e43bcb117c4e33ad6fab37915f47872f 100644 --- a/tests/segmentation/test_drive.py +++ b/tests/segmentation/test_drive.py @@ -83,7 +83,7 @@ def test_loading(database_checkers, name: str, dataset: str): batch_size=1, color_planes=3, expected_num_labels=1, - expected_meta_size=3, + expected_meta_size=1, possible_labels=[], prefixes=["training/", "test/"], ) diff --git a/tests/segmentation/test_hrf.py b/tests/segmentation/test_hrf.py index 5f7a29bde9b30e80a35f41db00d3b1fa850024cb..fafd1ba165efc33e45cabe322d5a2ed75d75bd96 100644 --- a/tests/segmentation/test_hrf.py +++ b/tests/segmentation/test_hrf.py @@ -82,7 +82,7 @@ def test_loading(database_checkers, name: str, dataset: str): batch_size=1, color_planes=3, expected_num_labels=1, - expected_meta_size=3, + expected_meta_size=1, prefixes=["images/"], possible_labels=[], ) diff --git a/tests/segmentation/test_iostar.py b/tests/segmentation/test_iostar.py index 082131f27c46956cc297bbe2891df98002b9756f..3776513b2bddb597cc41cf12c809bc58532b2b21 100644 --- a/tests/segmentation/test_iostar.py +++ b/tests/segmentation/test_iostar.py @@ -81,7 +81,7 @@ def test_loading(database_checkers, name: str, dataset: str): batch_size=1, color_planes=3, expected_num_labels=1, - expected_meta_size=3, + expected_meta_size=1, prefixes=["image/STAR "], possible_labels=[], ) diff --git a/tests/segmentation/test_jsrt.py b/tests/segmentation/test_jsrt.py index e40b6d18c62ad50e0aae1eea3a01fcf6eff14c6e..a792994f2bb87f06283dd484329fdfb962315407 100644 --- a/tests/segmentation/test_jsrt.py +++ b/tests/segmentation/test_jsrt.py @@ -83,7 +83,7 @@ def test_loading(database_checkers, name: str, dataset: str): batch_size=1, color_planes=3, expected_num_labels=1, - expected_meta_size=3, + expected_meta_size=1, prefixes=["All247images/JPC"], possible_labels=[], ) diff --git a/tests/segmentation/test_montgomery.py b/tests/segmentation/test_montgomery.py index 74043eb866ac19952de7e03d7f95dd8727058710..cb6223ac0187f1bec0c33dd0b8e757cbbb0e93cb 100644 --- a/tests/segmentation/test_montgomery.py +++ b/tests/segmentation/test_montgomery.py @@ -83,7 +83,7 @@ def test_loading(database_checkers, name: str, dataset: str): batch_size=1, color_planes=3, expected_num_labels=1, - expected_meta_size=3, + expected_meta_size=1, prefixes=["CXR_png"], possible_labels=[], ) diff --git a/tests/segmentation/test_refuge.py b/tests/segmentation/test_refuge.py index df8548403abe01ad2fca2af54f5a8b0bcf0b2949..14b04b693ccc4d7841e47e16714daa6565f39621 100644 --- a/tests/segmentation/test_refuge.py +++ b/tests/segmentation/test_refuge.py @@ -84,7 +84,7 @@ def test_loading(database_checkers, name: str, dataset: str): batch_size=1, color_planes=3, expected_num_labels=1, - expected_meta_size=3, + expected_meta_size=1, prefixes=["Training400/", "REFUGE-Validation400/V", "Test400/T0"], possible_labels=[], ) diff --git a/tests/segmentation/test_rimoner3.py b/tests/segmentation/test_rimoner3.py index 895dc3c0869e1e00a3732ed5122ab01a5691b829..2a5bde7085341355741ef450efe00d0a27db85ed 100644 --- a/tests/segmentation/test_rimoner3.py +++ b/tests/segmentation/test_rimoner3.py @@ -88,7 +88,7 @@ def test_loading(database_checkers, name: str, dataset: str): batch_size=1, color_planes=3, expected_num_labels=1, - expected_meta_size=3, + expected_meta_size=1, prefixes=[ "Healthy/Stereo Images/N-", "Glaucoma and suspects/Stereo Images/", diff --git a/tests/segmentation/test_shenzhen.py b/tests/segmentation/test_shenzhen.py index 8de3a36ec5ec139c43eccc5e745b280f5607b408..71ea78f25aad3f051b4e14f90b6d8322b9023573 100644 --- a/tests/segmentation/test_shenzhen.py +++ b/tests/segmentation/test_shenzhen.py @@ -83,7 +83,7 @@ def test_loading(database_checkers, name: str, dataset: str): batch_size=1, color_planes=3, expected_num_labels=1, - expected_meta_size=3, + expected_meta_size=1, prefixes=["CXR_png/CHNCXR_"], possible_labels=[], ) diff --git a/tests/segmentation/test_stare.py b/tests/segmentation/test_stare.py index 086970897a1ddf5436368f51c7a79a0919ac11da..6b6307d8e88d43c89ac6575a85e0ce5fe26e22b0 100644 --- a/tests/segmentation/test_stare.py +++ b/tests/segmentation/test_stare.py @@ -84,7 +84,7 @@ def test_loading(database_checkers, name: str, dataset: str): batch_size=1, color_planes=3, expected_num_labels=1, - expected_meta_size=3, + expected_meta_size=1, prefixes=["stare-images/im0"], possible_labels=[], )