From f3f94d64a57209dde50a50278c6b4c15452945a3 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Fri, 16 Feb 2024 18:17:36 +0100
Subject: [PATCH] [test] Handle RGB histograms comparison + select comparison
 method

---
 tests/conftest.py        | 41 ++++++++++++++++++++++++----------------
 tests/test_montgomery.py | 14 ++++++++++----
 2 files changed, 35 insertions(+), 20 deletions(-)

diff --git a/tests/conftest.py b/tests/conftest.py
index 183f238f..b51eafa6 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -203,7 +203,10 @@ class DatabaseCheckers:
 
     @staticmethod
     def check_image_quality(
-        datamodule, reference_histogram_file, pearson_coeff_threshold=0.005
+        datamodule,
+        reference_histogram_file,
+        compare_type="equal",
+        pearson_coeff_threshold=0.005,
     ):
         ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file)
 
@@ -226,21 +229,27 @@ class DatabaseCheckers:
                     dataset_sample_index
                 ][0]
 
-                image_tensor = numpy.multiply(image_tensor.numpy(), 255).astype(
-                    int
-                )
-                histogram = numpy.histogram(
-                    image_tensor, bins=256, range=(0, 256)
-                )[0].tolist()
-
-                # We cannot test if histograms are exactly equal because
-                # the torch.resize transform is inconsistent depending on the environment.
-                # assert histogram == ref_hist_data
-
-                # Compute pearson coefficients between histogram and reference
-                # and check the similarity within a certain threshold
-                pearson_coeffs = numpy.corrcoef(histogram, ref_hist_data)
-                assert 1 - pearson_coeff_threshold <= pearson_coeffs[0][1] <= 1
+                histogram = []
+                for color_channel in image_tensor:
+                    color_channel = numpy.multiply(
+                        color_channel.numpy(), 255
+                    ).astype(int)
+                    histogram.extend(
+                        numpy.histogram(
+                            color_channel, bins=256, range=(0, 256)
+                        )[0].tolist()
+                    )
+
+                if compare_type == "statistical":
+                    # Compute pearson coefficients between histogram and reference
+                    # and check the similarity within a certain threshold
+                    pearson_coeffs = numpy.corrcoef(histogram, ref_hist_data)
+                    assert (
+                        1 - pearson_coeff_threshold <= pearson_coeffs[0][1] <= 1
+                    )
+
+                else:
+                    assert histogram == ref_hist_data
 
 
 @pytest.fixture
diff --git a/tests/test_montgomery.py b/tests/test_montgomery.py
index 2600829f..4ce62580 100644
--- a/tests/test_montgomery.py
+++ b/tests/test_montgomery.py
@@ -95,7 +95,7 @@ def test_loading(database_checkers, name: str, dataset: str):
 
 
 @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
-def test_loaded_image_quality(database_checkers, datadir):
+def test_raw_transforms_image_quality(database_checkers, datadir):
     reference_histogram_file = str(
         datadir / "histograms/raw_data/histograms_montgomery_default.json"
     )
@@ -115,8 +115,8 @@ def test_loaded_image_quality(database_checkers, datadir):
     "model_name",
     [
         "alexnet",
-        # "densenet",
-        # "pasa",
+        "densenet",
+        "pasa",
     ],
 )
 def test_model_transforms_image_quality(database_checkers, datadir, model_name):
@@ -142,4 +142,10 @@ def test_model_transforms_image_quality(database_checkers, datadir, model_name):
 
     datamodule.model_transforms = model.model_transforms
     datamodule.setup("predict")
-    database_checkers.check_image_quality(datamodule, reference_histogram_file)
+
+    database_checkers.check_image_quality(
+        datamodule,
+        reference_histogram_file,
+        compare_type="statistical",
+        pearson_coeff_threshold=0.005,
+    )
-- 
GitLab