From 74d90ccb6d889a76c18001555547aa1e451027d7 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Wed, 7 Feb 2024 16:08:25 +0100
Subject: [PATCH] [test] Add tests for raw data loading image quality

---
 tests/data/lfs           |  2 +-
 tests/test_hivtb.py      | 40 ++++++++++++++++++++++++++++++++++++++
 tests/test_indian.py     | 38 +++++++++++++++++++++++-------------
 tests/test_montgomery.py | 38 +++++++++++++++++++++++-------------
 tests/test_nih_cxr14.py  | 42 +++++++++++++++++++++++++++++++++++++++-
 tests/test_padchest.py   | 40 ++++++++++++++++++++++++++++++++++++++
 tests/test_shenzhen.py   | 38 +++++++++++++++++++++++-------------
 tests/test_tbpoc.py      | 40 ++++++++++++++++++++++++++++++++++++++
 tests/test_tbx11k.py     | 40 ++++++++++++++++++++++++++++++++++++++
 9 files changed, 277 insertions(+), 41 deletions(-)

diff --git a/tests/data/lfs b/tests/data/lfs
index 05344d20..e2e4a98d 160000
--- a/tests/data/lfs
+++ b/tests/data/lfs
@@ -1 +1 @@
-Subproject commit 05344d20182ad4169fc5b9c38052d629aded30ed
+Subproject commit e2e4a98d675ec61dac44c339c28d91fcb180b398
diff --git a/tests/test_hivtb.py b/tests/test_hivtb.py
index 03066c2f..6373d0ee 100644
--- a/tests/test_hivtb.py
+++ b/tests/test_hivtb.py
@@ -7,6 +7,10 @@ import importlib
 
 import pytest
 
+from torchvision.transforms.functional import to_pil_image
+
+from mednet.data.split import JSONDatabaseSplit
+
 
 def id_function(val):
     if isinstance(val, dict):
@@ -90,3 +94,39 @@ def test_loading(database_checkers, name: str, dataset: str):
             expected_num_labels=1,
         )
         limit -= 1
+
+
+@pytest.mark.skip_if_rc_var_not_set("datadir.hivtb")
+def test_loaded_image_quality(datadir):
+    reference_histogram_file = str(
+        datadir / "lfs/histograms/raw_data/histograms_hivtb_fold_0.json"
+    )
+    ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file)
+
+    datamodule = importlib.import_module(
+        ".fold_0", "mednet.config.data.hivtb"
+    ).datamodule
+    # Clear model transforms; the reference histograms are raw_data ones.
+    datamodule.model_transforms = []
+    datamodule.setup("predict")
+
+    for split_name in ref_histogram_splits:
+        datamodule_split = datamodule.splits[split_name]
+        # Use the loader attached to the first entry of this split.
+        loader = datamodule_split[0][1]
+        # Each reference entry is a (sample path, histogram) pair.
+        for ref_data in ref_histogram_splits[split_name]:
+            sample_path = ref_data[0]
+            ref_histogram = ref_data[1]
+
+            test_sample = (
+                sample_path,
+                -1,
+            )  # Need to specify a label even if not used.
+            image_data = loader.sample(test_sample)[0]
+
+            img = to_pil_image(image_data)
+
+            histogram = img.histogram()
+            # Report the split and sample so a failure is easy to trace.
+            assert histogram == ref_histogram, f"{split_name}: {sample_path}"
diff --git a/tests/test_indian.py b/tests/test_indian.py
index a35cfdb3..452aa6c6 100644
--- a/tests/test_indian.py
+++ b/tests/test_indian.py
@@ -7,11 +7,12 @@ dataset A/dataset B) dataset.
 """
 
 import importlib
-import json
 
 import pytest
 
-from PIL import Image
+from torchvision.transforms.functional import to_pil_image
+
+from mednet.data.split import JSONDatabaseSplit
 
 
 def id_function(val):
@@ -102,6 +103,11 @@ def test_loading(database_checkers, name: str, dataset: str):
 
 @pytest.mark.skip_if_rc_var_not_set("datadir.indian")
 def test_loaded_image_quality(datadir):
+    reference_histogram_file = str(
+        datadir / "lfs/histograms/raw_data/histograms_indian_default.json"
+    )
+    ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file)
+
     datamodule = importlib.import_module(
         ".default", "mednet.config.data.indian"
     ).datamodule
@@ -109,17 +115,23 @@ def test_loaded_image_quality(datadir):
     datamodule.model_transforms = []
     datamodule.setup("predict")
 
-    loader = datamodule.splits["train"][0][1]
-    first_sample = datamodule.splits["train"][0][0][0]
-    image_data = loader.sample(first_sample)[0].numpy()[
-        0, :, :
-    ]  # PIL expects grayscale to not have any leading dim
-    img = Image.fromarray(image_data, mode="L")
+    for split_name in ref_histogram_splits:
+        datamodule_split = datamodule.splits[split_name]
+        # Use the loader attached to the first entry of this split.
+        loader = datamodule_split[0][1]
+        # Each reference entry is a (sample path, histogram) pair.
+        for ref_data in ref_histogram_splits[split_name]:
+            sample_path = ref_data[0]
+            ref_histogram = ref_data[1]
+
+            test_sample = (
+                sample_path,
+                -1,
+            )  # Need to specify a label even if not used.
+            image_data = loader.sample(test_sample)[0]
 
-    histogram = img.histogram()
+            img = to_pil_image(image_data)
 
-    reference_histogram_file = str(datadir / "histogram_indian.json")
-    with open(reference_histogram_file) as i_f:
-        ref_histogram = json.load(i_f)["histogram"]
+            histogram = img.histogram()
 
-    assert histogram == ref_histogram
+            assert histogram == ref_histogram, f"{split_name}: {sample_path}"
diff --git a/tests/test_montgomery.py b/tests/test_montgomery.py
index 349898d9..91c7aa59 100644
--- a/tests/test_montgomery.py
+++ b/tests/test_montgomery.py
@@ -4,11 +4,12 @@
 """Tests for Montgomery dataset."""
 
 import importlib
-import json
 
 import pytest
 
-from PIL import Image
+from torchvision.transforms.functional import to_pil_image
+
+from mednet.data.split import JSONDatabaseSplit
 
 
 def id_function(val):
@@ -99,6 +100,11 @@ def test_loading(database_checkers, name: str, dataset: str):
 
 @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
 def test_loaded_image_quality(datadir):
+    reference_histogram_file = str(
+        datadir / "lfs/histograms/raw_data/histograms_montgomery_default.json"
+    )
+    ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file)
+
     datamodule = importlib.import_module(
         ".default", "mednet.config.data.montgomery"
     ).datamodule
@@ -106,17 +112,23 @@ def test_loaded_image_quality(datadir):
     datamodule.model_transforms = []
     datamodule.setup("predict")
 
-    loader = datamodule.splits["train"][0][1]
-    first_sample = datamodule.splits["train"][0][0][0]
-    image_data = loader.sample(first_sample)[0].numpy()[
-        0, :, :
-    ]  # PIL expects grayscale to not have any leading dim
-    img = Image.fromarray(image_data, mode="L")
+    for split_name in ref_histogram_splits:
+        datamodule_split = datamodule.splits[split_name]
+        # Use the loader attached to the first entry of this split.
+        loader = datamodule_split[0][1]
+        # Each reference entry is a (sample path, histogram) pair.
+        for ref_data in ref_histogram_splits[split_name]:
+            sample_path = ref_data[0]
+            ref_histogram = ref_data[1]
+
+            test_sample = (
+                sample_path,
+                -1,
+            )  # Need to specify a label even if not used.
+            image_data = loader.sample(test_sample)[0]
 
-    histogram = img.histogram()
+            img = to_pil_image(image_data)
 
-    reference_histogram_file = str(datadir / "histogram_montgomery.json")
-    with open(reference_histogram_file) as i_f:
-        ref_histogram = json.load(i_f)["histogram"]
+            histogram = img.histogram()
 
-    assert histogram == ref_histogram
+            assert histogram == ref_histogram, f"{split_name}: {sample_path}"
diff --git a/tests/test_nih_cxr14.py b/tests/test_nih_cxr14.py
index dc73aa29..a2e014f3 100644
--- a/tests/test_nih_cxr14.py
+++ b/tests/test_nih_cxr14.py
@@ -7,6 +7,10 @@ import importlib
 
 import pytest
 
+from torchvision.transforms.functional import to_pil_image
+
+from mednet.data.split import JSONDatabaseSplit
+
 
 def id_function(val):
     if isinstance(val, dict):
@@ -44,7 +48,7 @@ testdata = [
 ]
 
 
-@pytest.mark.skip_if_rc_var_not_set("datadir.padchest")
+@pytest.mark.skip_if_rc_var_not_set("datadir.nih_cxr14")
 @pytest.mark.parametrize("name,dataset,num_labels", testdata)
 def test_loading(database_checkers, name: str, dataset: str, num_labels: int):
     datamodule = importlib.import_module(
@@ -70,3 +74,39 @@ def test_loading(database_checkers, name: str, dataset: str, num_labels: int):
             expected_image_shape=(1, 1024, 1024),
         )
         limit -= 1
+
+
+@pytest.mark.skip_if_rc_var_not_set("datadir.nih_cxr14")
+def test_loaded_image_quality(datadir):
+    reference_histogram_file = str(
+        datadir / "lfs/histograms/raw_data/histograms_nih_cxr14_default.json"
+    )
+    ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file)
+
+    datamodule = importlib.import_module(
+        ".default", "mednet.config.data.nih_cxr14"
+    ).datamodule
+    # Clear model transforms; the reference histograms are raw_data ones.
+    datamodule.model_transforms = []
+    datamodule.setup("predict")
+
+    for split_name in ref_histogram_splits:
+        datamodule_split = datamodule.splits[split_name]
+        # Use the loader attached to the first entry of this split.
+        loader = datamodule_split[0][1]
+        # Each reference entry is a (sample path, histogram) pair.
+        for ref_data in ref_histogram_splits[split_name]:
+            sample_path = ref_data[0]
+            ref_histogram = ref_data[1]
+
+            test_sample = (
+                sample_path,
+                -1,
+            )  # Need to specify a label even if not used.
+            image_data = loader.sample(test_sample)[0]
+
+            img = to_pil_image(image_data)
+
+            histogram = img.histogram()
+            # Report the split and sample so a failure is easy to trace.
+            assert histogram == ref_histogram, f"{split_name}: {sample_path}"
diff --git a/tests/test_padchest.py b/tests/test_padchest.py
index 98eaeb2d..47494d94 100644
--- a/tests/test_padchest.py
+++ b/tests/test_padchest.py
@@ -7,6 +7,10 @@ import importlib
 
 import pytest
 
+from torchvision.transforms.functional import to_pil_image
+
+from mednet.data.split import JSONDatabaseSplit
+
 
 def id_function(val):
     if isinstance(val, dict):
@@ -75,3 +79,39 @@ def test_loading(database_checkers, name: str, dataset: str, num_labels: int):
                 expected_num_labels=num_labels,
             )
             limit -= 1
+
+
+@pytest.mark.skip_if_rc_var_not_set("datadir.padchest")
+def test_loaded_image_quality(datadir):
+    reference_histogram_file = str(
+        datadir / "lfs/histograms/raw_data/histograms_padchest_idiap.json"
+    )
+    ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file)
+
+    datamodule = importlib.import_module(
+        ".idiap", "mednet.config.data.padchest"
+    ).datamodule
+    # Clear model transforms; the reference histograms are raw_data ones.
+    datamodule.model_transforms = []
+    datamodule.setup("predict")
+
+    for split_name in ref_histogram_splits:
+        datamodule_split = datamodule.splits[split_name]
+        # Use the loader attached to the first entry of this split.
+        loader = datamodule_split[0][1]
+        # Each reference entry is a (sample path, histogram) pair.
+        for ref_data in ref_histogram_splits[split_name]:
+            sample_path = ref_data[0]
+            ref_histogram = ref_data[1]
+
+            test_sample = (
+                sample_path,
+                -1,
+            )  # Need to specify a label even if not used.
+            image_data = loader.sample(test_sample)[0]
+
+            img = to_pil_image(image_data)
+
+            histogram = img.histogram()
+            # Report the split and sample so a failure is easy to trace.
+            assert histogram == ref_histogram, f"{split_name}: {sample_path}"
diff --git a/tests/test_shenzhen.py b/tests/test_shenzhen.py
index ff0a8af2..e2283fad 100644
--- a/tests/test_shenzhen.py
+++ b/tests/test_shenzhen.py
@@ -4,11 +4,12 @@
 """Tests for Shenzhen dataset."""
 
 import importlib
-import json
 
 import pytest
 
-from PIL import Image
+from torchvision.transforms.functional import to_pil_image
+
+from mednet.data.split import JSONDatabaseSplit
 
 
 def id_function(val):
@@ -99,6 +100,11 @@ def test_loading(database_checkers, name: str, dataset: str):
 
 @pytest.mark.skip_if_rc_var_not_set("datadir.shenzhen")
 def test_loaded_image_quality(datadir):
+    reference_histogram_file = str(
+        datadir / "lfs/histograms/raw_data/histograms_shenzhen_default.json"
+    )
+    ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file)
+
     datamodule = importlib.import_module(
         ".default", "mednet.config.data.shenzhen"
     ).datamodule
@@ -106,17 +112,23 @@ def test_loaded_image_quality(datadir):
     datamodule.model_transforms = []
     datamodule.setup("predict")
 
-    loader = datamodule.splits["train"][0][1]
-    first_sample = datamodule.splits["train"][0][0][0]
-    image_data = loader.sample(first_sample)[0].numpy()[
-        0, :, :
-    ]  # PIL expects grayscale to not have any leading dim
-    img = Image.fromarray(image_data, mode="L")
+    for split_name in ref_histogram_splits:
+        datamodule_split = datamodule.splits[split_name]
+        # Use the loader attached to the first entry of this split.
+        loader = datamodule_split[0][1]
+        # Each reference entry is a (sample path, histogram) pair.
+        for ref_data in ref_histogram_splits[split_name]:
+            sample_path = ref_data[0]
+            ref_histogram = ref_data[1]
+
+            test_sample = (
+                sample_path,
+                -1,
+            )  # Need to specify a label even if not used.
+            image_data = loader.sample(test_sample)[0]
 
-    histogram = img.histogram()
+            img = to_pil_image(image_data)
 
-    reference_histogram_file = str(datadir / "histogram_shenzhen.json")
-    with open(reference_histogram_file) as i_f:
-        ref_histogram = json.load(i_f)["histogram"]
+            histogram = img.histogram()
 
-    assert histogram == ref_histogram
+            assert histogram == ref_histogram, f"{split_name}: {sample_path}"
diff --git a/tests/test_tbpoc.py b/tests/test_tbpoc.py
index 58f762d5..75eaeca7 100644
--- a/tests/test_tbpoc.py
+++ b/tests/test_tbpoc.py
@@ -7,6 +7,10 @@ import importlib
 
 import pytest
 
+from torchvision.transforms.functional import to_pil_image
+
+from mednet.data.split import JSONDatabaseSplit
+
 
 def id_function(val):
     if isinstance(val, dict):
@@ -96,3 +100,39 @@ def test_loading(database_checkers, name: str, dataset: str):
             expected_num_labels=1,
         )
         limit -= 1
+
+
+@pytest.mark.skip_if_rc_var_not_set("datadir.tbpoc")
+def test_loaded_image_quality(datadir):
+    reference_histogram_file = str(
+        datadir / "lfs/histograms/raw_data/histograms_tbpoc_fold_0.json"
+    )
+    ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file)
+
+    datamodule = importlib.import_module(
+        ".fold_0", "mednet.config.data.tbpoc"
+    ).datamodule
+    # Clear model transforms; the reference histograms are raw_data ones.
+    datamodule.model_transforms = []
+    datamodule.setup("predict")
+
+    for split_name in ref_histogram_splits:
+        datamodule_split = datamodule.splits[split_name]
+        # Use the loader attached to the first entry of this split.
+        loader = datamodule_split[0][1]
+        # Each reference entry is a (sample path, histogram) pair.
+        for ref_data in ref_histogram_splits[split_name]:
+            sample_path = ref_data[0]
+            ref_histogram = ref_data[1]
+
+            test_sample = (
+                sample_path,
+                -1,
+            )  # Need to specify a label even if not used.
+            image_data = loader.sample(test_sample)[0]
+
+            img = to_pil_image(image_data)
+
+            histogram = img.histogram()
+            # Report the split and sample so a failure is easy to trace.
+            assert histogram == ref_histogram, f"{split_name}: {sample_path}"
diff --git a/tests/test_tbx11k.py b/tests/test_tbx11k.py
index 6c022dba..7b2a60c9 100644
--- a/tests/test_tbx11k.py
+++ b/tests/test_tbx11k.py
@@ -9,6 +9,10 @@ import typing
 import pytest
 import torch
 
+from torchvision.transforms.functional import to_pil_image
+
+from mednet.data.split import JSONDatabaseSplit
+
 
 def id_function(val):
     if isinstance(val, (dict, tuple)):
@@ -283,3 +287,39 @@ def test_loading(name: str, dataset: str, prefixes: typing.Sequence[str]):
             expected_image_shape=(3, 512, 512),
         )
         limit -= 1
+
+
+@pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k")
+def test_loaded_image_quality(datadir):
+    reference_histogram_file = str(
+        datadir / "lfs/histograms/raw_data/histograms_tbx11k_v1_fold_0.json"
+    )
+    ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file)
+
+    datamodule = importlib.import_module(
+        ".v1_fold_0", "mednet.config.data.tbx11k"
+    ).datamodule
+    # Clear model transforms; the reference histograms are raw_data ones.
+    datamodule.model_transforms = []
+    datamodule.setup("predict")
+
+    for split_name in ref_histogram_splits:
+        datamodule_split = datamodule.splits[split_name]
+        # Use the loader attached to the first entry of this split.
+        loader = datamodule_split[0][1]
+        # Each reference entry is a (sample path, histogram) pair.
+        for ref_data in ref_histogram_splits[split_name]:
+            sample_path = ref_data[0]
+            ref_histogram = ref_data[1]
+
+            test_sample = (
+                sample_path,
+                -1,
+            )  # Need to specify a label even if not used.
+            image_data = loader.sample(test_sample)[0]
+
+            img = to_pil_image(image_data)
+
+            histogram = img.histogram()
+            # Report the split and sample so a failure is easy to trace.
+            assert histogram == ref_histogram, f"{split_name}: {sample_path}"
-- 
GitLab