From 45f95973245bacd8703d49d56fab12dabbebeec7 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Fri, 2 Feb 2024 10:52:32 +0100
Subject: [PATCH] [test] Test loaded images against reference histograms

---
 tests/test_indian.py     | 28 ++++++++++++++++++++++++++++
 tests/test_montgomery.py | 28 ++++++++++++++++++++++++++++
 tests/test_shenzhen.py   | 28 ++++++++++++++++++++++++++++
 3 files changed, 84 insertions(+)

diff --git a/tests/test_indian.py b/tests/test_indian.py
index 8dd07159..a35cfdb3 100644
--- a/tests/test_indian.py
+++ b/tests/test_indian.py
@@ -7,9 +7,12 @@ dataset A/dataset B) dataset.
 """
 
 import importlib
+import json
 
 import pytest
 
+from PIL import Image
+
 
 def id_function(val):
     if isinstance(val, dict):
@@ -95,3 +98,28 @@ def test_loading(database_checkers, name: str, dataset: str):
             expected_num_labels=1,
         )
         limit -= 1
+
+
+@pytest.mark.skip_if_rc_var_not_set("datadir.indian")
+def test_loaded_image_quality(datadir):
+    datamodule = importlib.import_module(
+        ".default", "mednet.config.data.indian"
+    ).datamodule
+
+    datamodule.model_transforms = []
+    datamodule.setup("predict")
+
+    loader = datamodule.splits["train"][0][1]
+    first_sample = datamodule.splits["train"][0][0][0]
+    image_data = loader.sample(first_sample)[0].numpy()[
+        0, :, :
+    ]  # PIL expects grayscale to not have any leading dim
+    img = Image.fromarray(image_data, mode="L")
+
+    histogram = img.histogram()
+
+    reference_histogram_file = str(datadir / "histogram_indian.json")
+    with open(reference_histogram_file) as i_f:
+        ref_histogram = json.load(i_f)["histogram"]
+
+    assert histogram == ref_histogram
diff --git a/tests/test_montgomery.py b/tests/test_montgomery.py
index 91f8cc10..349898d9 100644
--- a/tests/test_montgomery.py
+++ b/tests/test_montgomery.py
@@ -4,9 +4,12 @@
 """Tests for Montgomery dataset."""
 
 import importlib
+import json
 
 import pytest
 
+from PIL import Image
+
 
 def id_function(val):
     if isinstance(val, dict):
@@ -92,3 +95,28 @@ def test_loading(database_checkers, name: str, dataset: str):
             expected_num_labels=1,
         )
         limit -= 1
+
+
+@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
+def test_loaded_image_quality(datadir):
+    datamodule = importlib.import_module(
+        ".default", "mednet.config.data.montgomery"
+    ).datamodule
+
+    datamodule.model_transforms = []
+    datamodule.setup("predict")
+
+    loader = datamodule.splits["train"][0][1]
+    first_sample = datamodule.splits["train"][0][0][0]
+    image_data = loader.sample(first_sample)[0].numpy()[
+        0, :, :
+    ]  # PIL expects grayscale to not have any leading dim
+    img = Image.fromarray(image_data, mode="L")
+
+    histogram = img.histogram()
+
+    reference_histogram_file = str(datadir / "histogram_montgomery.json")
+    with open(reference_histogram_file) as i_f:
+        ref_histogram = json.load(i_f)["histogram"]
+
+    assert histogram == ref_histogram
diff --git a/tests/test_shenzhen.py b/tests/test_shenzhen.py
index aeb9da85..ff0a8af2 100644
--- a/tests/test_shenzhen.py
+++ b/tests/test_shenzhen.py
@@ -4,9 +4,12 @@
 """Tests for Shenzhen dataset."""
 
 import importlib
+import json
 
 import pytest
 
+from PIL import Image
+
 
 def id_function(val):
     if isinstance(val, dict):
@@ -92,3 +95,28 @@ def test_loading(database_checkers, name: str, dataset: str):
             expected_num_labels=1,
         )
         limit -= 1
+
+
+@pytest.mark.skip_if_rc_var_not_set("datadir.shenzhen")
+def test_loaded_image_quality(datadir):
+    datamodule = importlib.import_module(
+        ".default", "mednet.config.data.shenzhen"
+    ).datamodule
+
+    datamodule.model_transforms = []
+    datamodule.setup("predict")
+
+    loader = datamodule.splits["train"][0][1]
+    first_sample = datamodule.splits["train"][0][0][0]
+    image_data = loader.sample(first_sample)[0].numpy()[
+        0, :, :
+    ]  # PIL expects grayscale to not have any leading dim
+    img = Image.fromarray(image_data, mode="L")
+
+    histogram = img.histogram()
+
+    reference_histogram_file = str(datadir / "histogram_shenzhen.json")
+    with open(reference_histogram_file) as i_f:
+        ref_histogram = json.load(i_f)["histogram"]
+
+    assert histogram == ref_histogram
-- 
GitLab