From f3f94d64a57209dde50a50278c6b4c15452945a3 Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Fri, 16 Feb 2024 18:17:36 +0100 Subject: [PATCH] [test] Handle RGB histograms comparison + select comparison method --- tests/conftest.py | 41 ++++++++++++++++++++++++---------------- tests/test_montgomery.py | 14 ++++++++++---- 2 files changed, 35 insertions(+), 20 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 183f238f..b51eafa6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -203,7 +203,10 @@ class DatabaseCheckers: @staticmethod def check_image_quality( - datamodule, reference_histogram_file, pearson_coeff_threshold=0.005 + datamodule, + reference_histogram_file, + compare_type="equal", + pearson_coeff_threshold=0.005, ): ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file) @@ -226,21 +229,27 @@ class DatabaseCheckers: dataset_sample_index ][0] - image_tensor = numpy.multiply(image_tensor.numpy(), 255).astype( - int - ) - histogram = numpy.histogram( - image_tensor, bins=256, range=(0, 256) - )[0].tolist() - - # We cannot test if histograms are exactly equal because - # the torch.resize transform is inconsistent depending on the environment. - # assert histogram == ref_hist_data - - # Compute pearson coefficients between histogram and reference - # and check the similarity within a certain threshold - pearson_coeffs = numpy.corrcoef(histogram, ref_hist_data) - assert 1 - pearson_coeff_threshold <= pearson_coeffs[0][1] <= 1 + histogram = [] + for color_channel in image_tensor: + color_channel = numpy.multiply( + color_channel.numpy(), 255 + ).astype(int) + histogram.extend( + numpy.histogram( + color_channel, bins=256, range=(0, 256) + )[0].tolist() + ) + + if compare_type == "statistical": + # Compute pearson coefficients between histogram and reference + # and check the similarity within a certain threshold + pearson_coeffs = numpy.corrcoef(histogram, ref_hist_data) + assert ( + 1 - pearson_coeff_threshold <= pearson_coeffs[0][1] <= 1 + ) + + else: + assert histogram == ref_hist_data @pytest.fixture diff --git a/tests/test_montgomery.py b/tests/test_montgomery.py index 2600829f..4ce62580 100644 --- a/tests/test_montgomery.py +++ b/tests/test_montgomery.py @@ -95,7 +95,7 @@ def test_loading(database_checkers, name: str, dataset: str): @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") -def test_loaded_image_quality(database_checkers, datadir): +def test_raw_transforms_image_quality(database_checkers, datadir): reference_histogram_file = str( datadir / "histograms/raw_data/histograms_montgomery_default.json" ) @@ -115,8 +115,8 @@ def test_loaded_image_quality(database_checkers, datadir): "model_name", [ "alexnet", - # "densenet", - # "pasa", + "densenet", + "pasa", ], ) def test_model_transforms_image_quality(database_checkers, datadir, model_name): @@ -142,4 +142,10 @@ def test_model_transforms_image_quality(database_checkers, datadir, model_name): datamodule.model_transforms = model.model_transforms datamodule.setup("predict") - database_checkers.check_image_quality(datamodule, reference_histogram_file) + + database_checkers.check_image_quality( + datamodule, + reference_histogram_file, + compare_type="statistical", + pearson_coeff_threshold=0.005, + ) -- GitLab