From b42c68f9760f5e4b70d608bc054daa3bdc43ede4 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Wed, 26 Jun 2024 12:32:14 +0200
Subject: [PATCH] [tests] Update tests

---
 tests/conftest.py                     | 86 +++++++++++++++++++--------
 tests/segmentation/test_chasedb1.py   |  2 +-
 tests/segmentation/test_cxr8.py       |  2 +-
 tests/segmentation/test_drhagis.py    |  2 +-
 tests/segmentation/test_drionsdb.py   |  2 +-
 tests/segmentation/test_drishtigs1.py |  2 +-
 tests/segmentation/test_drive.py      |  2 +-
 tests/segmentation/test_hrf.py        |  2 +-
 tests/segmentation/test_iostar.py     |  2 +-
 tests/segmentation/test_jsrt.py       |  2 +-
 tests/segmentation/test_montgomery.py |  2 +-
 tests/segmentation/test_refuge.py     |  2 +-
 tests/segmentation/test_rimoner3.py   |  2 +-
 tests/segmentation/test_shenzhen.py   |  2 +-
 tests/segmentation/test_stare.py      |  2 +-
 15 files changed, 76 insertions(+), 38 deletions(-)

diff --git a/tests/conftest.py b/tests/conftest.py
index 98d102e7..fad88f46 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -185,30 +185,65 @@ class DatabaseCheckers:
 
         assert len(batch) == 2  # data, metadata
 
-        assert isinstance(batch[0], torch.Tensor)
-        assert batch[0].shape[0] == batch_size  # mini-batch size
-        assert batch[0].shape[1] == color_planes
-
-        if expected_image_shape:
-            assert all(
-                [data.shape == expected_image_shape for data in batch[0]],
-            )
-
-        assert isinstance(batch[1], dict)  # metadata
-        assert len(batch[1]) == expected_meta_size
-
-        assert "target" in batch[1]
-        if possible_labels:
-            assert all([k in possible_labels for k in batch[1]["target"]])
-
-        if expected_num_labels:
-            assert len(batch[1]["target"]) == expected_num_labels
-
-        assert "name" in batch[1]
-        if prefixes:
-            assert all(
-                [any([k.startswith(j) for j in prefixes]) for k in batch[1]["name"]],
-            )
+        if isinstance(batch[0], dict):
+            assert isinstance(batch[0]["image"], torch.Tensor)
+
+            assert batch[0]["image"].shape[0] == batch_size  # mini-batch size
+            assert batch[0]["image"].shape[1] == color_planes
+
+            if expected_image_shape:
+                assert all(
+                    [data.shape == expected_image_shape for data in batch[0]["image"]],
+                )
+
+            assert isinstance(batch[1], dict)  # metadata
+            assert len(batch[1]) == expected_meta_size
+
+            assert "target" in batch[0]
+            if possible_labels:
+                assert all([k in possible_labels for k in batch[0]["target"]])
+
+            if expected_num_labels:
+                assert len(batch[0]["target"]) == expected_num_labels
+
+            assert "name" in batch[1]
+            if prefixes:
+                assert all(
+                    [
+                        any([k.startswith(j) for j in prefixes])
+                        for k in batch[1]["name"]
+                    ],
+                )
+
+        else:
+            assert isinstance(batch[0], torch.Tensor)
+
+            assert batch[0].shape[0] == batch_size  # mini-batch size
+            assert batch[0].shape[1] == color_planes
+
+            if expected_image_shape:
+                assert all(
+                    [data.shape == expected_image_shape for data in batch[0]],
+                )
+
+            assert isinstance(batch[1], dict)  # metadata
+            assert len(batch[1]) == expected_meta_size
+
+            assert "target" in batch[1]
+            if possible_labels:
+                assert all([k in possible_labels for k in batch[1]["target"]])
+
+            if expected_num_labels:
+                assert len(batch[1]["target"]) == expected_num_labels
+
+            assert "name" in batch[1]
+            if prefixes:
+                assert all(
+                    [
+                        any([k.startswith(j) for j in prefixes])
+                        for k in batch[1]["name"]
+                    ],
+                )
 
         # use the code below to view generated images
         # from torchvision.transforms.functional import to_pil_image
@@ -243,6 +278,9 @@ class DatabaseCheckers:
                     dataset_sample_index
                 ][0]
 
+                if isinstance(image_tensor, dict):
+                    image_tensor = image_tensor["image"]
+
                 histogram = []
                 for color_channel in image_tensor:
                     color_channel = numpy.multiply(
diff --git a/tests/segmentation/test_chasedb1.py b/tests/segmentation/test_chasedb1.py
index 04983002..c9bdd898 100644
--- a/tests/segmentation/test_chasedb1.py
+++ b/tests/segmentation/test_chasedb1.py
@@ -84,7 +84,7 @@ def test_loading(database_checkers, name: str, dataset: str):
             batch_size=1,
             color_planes=3,
             expected_num_labels=1,
-            expected_meta_size=3,
+            expected_meta_size=1,
             prefixes=["Image_"],
             possible_labels=[],
         )
diff --git a/tests/segmentation/test_cxr8.py b/tests/segmentation/test_cxr8.py
index eb6de32e..756afb50 100644
--- a/tests/segmentation/test_cxr8.py
+++ b/tests/segmentation/test_cxr8.py
@@ -84,7 +84,7 @@ def test_loading(database_checkers, name: str, dataset: str):
             prefixes=[],
             possible_labels=[],
             expected_num_labels=1,
-            expected_meta_size=3,
+            expected_meta_size=1,
         )
         limit -= 1
 
diff --git a/tests/segmentation/test_drhagis.py b/tests/segmentation/test_drhagis.py
index 8eaf73e0..b18d1571 100644
--- a/tests/segmentation/test_drhagis.py
+++ b/tests/segmentation/test_drhagis.py
@@ -82,7 +82,7 @@ def test_loading(database_checkers, name: str, dataset: str):
             batch_size=1,
             color_planes=3,
             expected_num_labels=1,
-            expected_meta_size=3,
+            expected_meta_size=1,
             prefixes=["Fundus_Images/"],
             possible_labels=[],
         )
diff --git a/tests/segmentation/test_drionsdb.py b/tests/segmentation/test_drionsdb.py
index 6c04191a..81c77314 100644
--- a/tests/segmentation/test_drionsdb.py
+++ b/tests/segmentation/test_drionsdb.py
@@ -84,7 +84,7 @@ def test_loading(database_checkers, name: str, dataset: str):
             batch_size=1,
             color_planes=3,
             expected_num_labels=1,
-            expected_meta_size=3,
+            expected_meta_size=1,
             prefixes=["images/image_"],
             possible_labels=[],
         )
diff --git a/tests/segmentation/test_drishtigs1.py b/tests/segmentation/test_drishtigs1.py
index 8f704979..5e4eb289 100644
--- a/tests/segmentation/test_drishtigs1.py
+++ b/tests/segmentation/test_drishtigs1.py
@@ -89,7 +89,7 @@ def test_loading(database_checkers, name: str, dataset: str):
             batch_size=1,
             color_planes=3,
             expected_num_labels=1,
-            expected_meta_size=3,
+            expected_meta_size=1,
             prefixes=[
                 "Drishti-GS1_files/Training/Images/drishtiGS_",
                 "Drishti-GS1_files/Test/Images/drishtiGS_",
diff --git a/tests/segmentation/test_drive.py b/tests/segmentation/test_drive.py
index 33fdd65a..59239042 100644
--- a/tests/segmentation/test_drive.py
+++ b/tests/segmentation/test_drive.py
@@ -83,7 +83,7 @@ def test_loading(database_checkers, name: str, dataset: str):
             batch_size=1,
             color_planes=3,
             expected_num_labels=1,
-            expected_meta_size=3,
+            expected_meta_size=1,
             possible_labels=[],
             prefixes=["training/", "test/"],
         )
diff --git a/tests/segmentation/test_hrf.py b/tests/segmentation/test_hrf.py
index 5f7a29bd..fafd1ba1 100644
--- a/tests/segmentation/test_hrf.py
+++ b/tests/segmentation/test_hrf.py
@@ -82,7 +82,7 @@ def test_loading(database_checkers, name: str, dataset: str):
             batch_size=1,
             color_planes=3,
             expected_num_labels=1,
-            expected_meta_size=3,
+            expected_meta_size=1,
             prefixes=["images/"],
             possible_labels=[],
         )
diff --git a/tests/segmentation/test_iostar.py b/tests/segmentation/test_iostar.py
index 082131f2..3776513b 100644
--- a/tests/segmentation/test_iostar.py
+++ b/tests/segmentation/test_iostar.py
@@ -81,7 +81,7 @@ def test_loading(database_checkers, name: str, dataset: str):
             batch_size=1,
             color_planes=3,
             expected_num_labels=1,
-            expected_meta_size=3,
+            expected_meta_size=1,
             prefixes=["image/STAR "],
             possible_labels=[],
         )
diff --git a/tests/segmentation/test_jsrt.py b/tests/segmentation/test_jsrt.py
index e40b6d18..a792994f 100644
--- a/tests/segmentation/test_jsrt.py
+++ b/tests/segmentation/test_jsrt.py
@@ -83,7 +83,7 @@ def test_loading(database_checkers, name: str, dataset: str):
             batch_size=1,
             color_planes=3,
             expected_num_labels=1,
-            expected_meta_size=3,
+            expected_meta_size=1,
             prefixes=["All247images/JPC"],
             possible_labels=[],
         )
diff --git a/tests/segmentation/test_montgomery.py b/tests/segmentation/test_montgomery.py
index 74043eb8..cb6223ac 100644
--- a/tests/segmentation/test_montgomery.py
+++ b/tests/segmentation/test_montgomery.py
@@ -83,7 +83,7 @@ def test_loading(database_checkers, name: str, dataset: str):
             batch_size=1,
             color_planes=3,
             expected_num_labels=1,
-            expected_meta_size=3,
+            expected_meta_size=1,
             prefixes=["CXR_png"],
             possible_labels=[],
         )
diff --git a/tests/segmentation/test_refuge.py b/tests/segmentation/test_refuge.py
index df854840..14b04b69 100644
--- a/tests/segmentation/test_refuge.py
+++ b/tests/segmentation/test_refuge.py
@@ -84,7 +84,7 @@ def test_loading(database_checkers, name: str, dataset: str):
             batch_size=1,
             color_planes=3,
             expected_num_labels=1,
-            expected_meta_size=3,
+            expected_meta_size=1,
             prefixes=["Training400/", "REFUGE-Validation400/V", "Test400/T0"],
             possible_labels=[],
         )
diff --git a/tests/segmentation/test_rimoner3.py b/tests/segmentation/test_rimoner3.py
index 895dc3c0..2a5bde70 100644
--- a/tests/segmentation/test_rimoner3.py
+++ b/tests/segmentation/test_rimoner3.py
@@ -88,7 +88,7 @@ def test_loading(database_checkers, name: str, dataset: str):
             batch_size=1,
             color_planes=3,
             expected_num_labels=1,
-            expected_meta_size=3,
+            expected_meta_size=1,
             prefixes=[
                 "Healthy/Stereo Images/N-",
                 "Glaucoma and suspects/Stereo Images/",
diff --git a/tests/segmentation/test_shenzhen.py b/tests/segmentation/test_shenzhen.py
index 8de3a36e..71ea78f2 100644
--- a/tests/segmentation/test_shenzhen.py
+++ b/tests/segmentation/test_shenzhen.py
@@ -83,7 +83,7 @@ def test_loading(database_checkers, name: str, dataset: str):
             batch_size=1,
             color_planes=3,
             expected_num_labels=1,
-            expected_meta_size=3,
+            expected_meta_size=1,
             prefixes=["CXR_png/CHNCXR_"],
             possible_labels=[],
         )
diff --git a/tests/segmentation/test_stare.py b/tests/segmentation/test_stare.py
index 08697089..6b6307d8 100644
--- a/tests/segmentation/test_stare.py
+++ b/tests/segmentation/test_stare.py
@@ -84,7 +84,7 @@ def test_loading(database_checkers, name: str, dataset: str):
             batch_size=1,
             color_planes=3,
             expected_num_labels=1,
-            expected_meta_size=3,
+            expected_meta_size=1,
             prefixes=["stare-images/im0"],
             possible_labels=[],
         )
-- 
GitLab