Skip to content
Snippets Groups Projects
Commit f3f94d64 authored by Daniel CARRON's avatar Daniel CARRON :b:
Browse files

[test] Handle RGB histograms comparison + select comparison method

parent ab524a4e
No related branches found
No related tags found
3 merge requests!19Fix Issues when running tests on the CI,!18Update tests,!16Make square centre-padding a model transform
Pipeline #84276 failed
......@@ -203,7 +203,10 @@ class DatabaseCheckers:
@staticmethod
def check_image_quality(
datamodule, reference_histogram_file, pearson_coeff_threshold=0.005
datamodule,
reference_histogram_file,
compare_type="equal",
pearson_coeff_threshold=0.005,
):
ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file)
......@@ -226,21 +229,27 @@ class DatabaseCheckers:
dataset_sample_index
][0]
image_tensor = numpy.multiply(image_tensor.numpy(), 255).astype(
int
)
histogram = numpy.histogram(
image_tensor, bins=256, range=(0, 256)
)[0].tolist()
# We cannot test if histograms are exactly equal because
# the torch.resize transform is inconsistent depending on the environment.
# assert histogram == ref_hist_data
# Compute pearson coefficients between histogram and reference
# and check the similarity within a certain threshold
pearson_coeffs = numpy.corrcoef(histogram, ref_hist_data)
assert 1 - pearson_coeff_threshold <= pearson_coeffs[0][1] <= 1
histogram = []
for color_channel in image_tensor:
color_channel = numpy.multiply(
color_channel.numpy(), 255
).astype(int)
histogram.extend(
numpy.histogram(
color_channel, bins=256, range=(0, 256)
)[0].tolist()
)
if compare_type == "statistical":
# Compute pearson coefficients between histogram and reference
# and check the similarity within a certain threshold
pearson_coeffs = numpy.corrcoef(histogram, ref_hist_data)
assert (
1 - pearson_coeff_threshold <= pearson_coeffs[0][1] <= 1
)
else:
assert histogram == ref_hist_data
@pytest.fixture
......
......@@ -95,7 +95,7 @@ def test_loading(database_checkers, name: str, dataset: str):
@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
def test_loaded_image_quality(database_checkers, datadir):
def test_raw_transforms_image_quality(database_checkers, datadir):
reference_histogram_file = str(
datadir / "histograms/raw_data/histograms_montgomery_default.json"
)
......@@ -115,8 +115,8 @@ def test_loaded_image_quality(database_checkers, datadir):
"model_name",
[
"alexnet",
# "densenet",
# "pasa",
"densenet",
"pasa",
],
)
def test_model_transforms_image_quality(database_checkers, datadir, model_name):
......@@ -142,4 +142,10 @@ def test_model_transforms_image_quality(database_checkers, datadir, model_name):
datamodule.model_transforms = model.model_transforms
datamodule.setup("predict")
database_checkers.check_image_quality(datamodule, reference_histogram_file)
database_checkers.check_image_quality(
datamodule,
reference_histogram_file,
compare_type="statistical",
pearson_coeff_threshold=0.005,
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment