From 47239f3d4f3a1408772d4547e4a291d595246840 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Wed, 31 Jan 2024 12:44:16 +0100
Subject: [PATCH] [test] Check number of labels per sample in DataLoader

---
 tests/conftest.py        |  8 ++++++--
 tests/test_hivtb.py      |  1 +
 tests/test_indian.py     |  1 +
 tests/test_montgomery.py |  1 +
 tests/test_padchest.py   | 31 ++++++++++++-------------------
 tests/test_shenzhen.py   |  1 +
 tests/test_tbpoc.py      |  1 +
 tests/test_tbx11k.py     |  7 ++++++-
 8 files changed, 29 insertions(+), 22 deletions(-)

diff --git a/tests/conftest.py b/tests/conftest.py
index b4f92331..febcc24f 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -156,7 +156,7 @@ class DatabaseCheckers:
         prefixes: typing.Sequence[str],
         possible_labels: typing.Sequence[int],
     ):
-        """Run a simple consistence check on the data split.
+        """Run a simple consistency check on the data split.
 
         Parameters
         ----------
@@ -197,8 +197,9 @@ class DatabaseCheckers:
         color_planes: int,
         prefixes: typing.Sequence[str],
         possible_labels: typing.Sequence[int],
+        expected_num_labels: typing.Optional[int] = None,
     ):
-        """Check the consistence of an individual (loaded) batch.
+        """Check the consistency of an individual (loaded) batch.
 
         Parameters
         ----------
@@ -229,6 +230,9 @@ class DatabaseCheckers:
         assert "label" in batch[1]
         assert all([k in possible_labels for k in batch[1]["label"]])
 
+        if expected_num_labels is not None:  # None disables the check
+            assert len(batch[1]["label"]) == expected_num_labels
+
         assert "name" in batch[1]
         assert all(
             [any([k.startswith(j) for j in prefixes]) for k in batch[1]["name"]]
diff --git a/tests/test_hivtb.py b/tests/test_hivtb.py
index 00b7a484..03066c2f 100644
--- a/tests/test_hivtb.py
+++ b/tests/test_hivtb.py
@@ -87,5 +87,6 @@ def test_loading(database_checkers, name: str, dataset: str):
             color_planes=1,
             prefixes=("HIV-TB_Algorithm_study_X-rays",),
             possible_labels=(0, 1),
+            expected_num_labels=1,
         )
         limit -= 1
diff --git a/tests/test_indian.py b/tests/test_indian.py
index 66ffdef3..8dd07159 100644
--- a/tests/test_indian.py
+++ b/tests/test_indian.py
@@ -92,5 +92,6 @@ def test_loading(database_checkers, name: str, dataset: str):
             color_planes=1,
             prefixes=("DatasetA/Training", "DatasetA/Testing"),
             possible_labels=(0, 1),
+            expected_num_labels=1,
         )
         limit -= 1
diff --git a/tests/test_montgomery.py b/tests/test_montgomery.py
index 6933e982..91f8cc10 100644
--- a/tests/test_montgomery.py
+++ b/tests/test_montgomery.py
@@ -89,5 +89,6 @@ def test_loading(database_checkers, name: str, dataset: str):
             color_planes=1,
             prefixes=("CXR_png/MCUCXR_0",),
             possible_labels=(0, 1),
+            expected_num_labels=1,
         )
         limit -= 1
diff --git a/tests/test_padchest.py b/tests/test_padchest.py
index 97bf8a9c..5fc342e1 100644
--- a/tests/test_padchest.py
+++ b/tests/test_padchest.py
@@ -40,21 +40,18 @@ def test_protocol_consistency(
     )
 
 
+testdata = [
+    ("idiap", "train", 193),
+    ("idiap", "test", 1),
+    ("tb_idiap", "train", 1),
+    ("no_tb_idiap", "train", 14),
+    ("cardiomegaly_idiap", "train", 14),
+]
+
+
 @pytest.mark.skip_if_rc_var_not_set("datadir.padchest")
-@pytest.mark.parametrize(
-    "dataset",
-    ["train", "test"],
-)
-@pytest.mark.parametrize(
-    "name",
-    [
-        "idiap",
-        "tb_idiap",
-        "no_tb_idiap",
-        "cardiomegaly_idiap",
-    ],
-)
-def test_loading(database_checkers, name: str, dataset: str):
+@pytest.mark.parametrize("name,dataset,num_labels", testdata)
+def test_loading(database_checkers, name: str, dataset: str, num_labels: int):
     datamodule = importlib.import_module(
         f".{name}", "mednet.config.data.padchest"
     ).datamodule
@@ -86,10 +83,6 @@ def test_loading(database_checkers, name: str, dataset: str):
                 color_planes=1,
                 prefixes=("",),
                 possible_labels=(0, 1),
+                expected_num_labels=num_labels,
             )
             limit -= 1
-
-
-# TODO: check size 1024x1024
-# TODO: check there are 14 binary labels (0, 1) (in some cases, in others much
-# more)...
diff --git a/tests/test_shenzhen.py b/tests/test_shenzhen.py
index 6658831b..aeb9da85 100644
--- a/tests/test_shenzhen.py
+++ b/tests/test_shenzhen.py
@@ -89,5 +89,6 @@ def test_loading(database_checkers, name: str, dataset: str):
             color_planes=1,
             prefixes=("CXR_png/CHNCXR_0",),
             possible_labels=(0, 1),
+            expected_num_labels=1,
         )
         limit -= 1
diff --git a/tests/test_tbpoc.py b/tests/test_tbpoc.py
index a2704fa7..58f762d5 100644
--- a/tests/test_tbpoc.py
+++ b/tests/test_tbpoc.py
@@ -93,5 +93,6 @@ def test_loading(database_checkers, name: str, dataset: str):
                 "TBPOC_CXR/tbpoc-",
             ),
             possible_labels=(0, 1),
+            expected_num_labels=1,
         )
         limit -= 1
diff --git a/tests/test_tbx11k.py b/tests/test_tbx11k.py
index b1bbcc1f..231982e6 100644
--- a/tests/test_tbx11k.py
+++ b/tests/test_tbx11k.py
@@ -152,8 +152,9 @@ def check_loaded_batch(
     batch,
     batch_size: int,
     prefixes: typing.Sequence[str],
+    expected_num_labels: typing.Optional[int] = None,
 ):
-    """Check the consistence of an individual (loaded) batch.
+    """Check the consistency of an individual (loaded) batch.
 
     Parameters
     ----------
@@ -183,6 +184,9 @@ def check_loaded_batch(
     assert "label" in batch[1]
     assert all([k in (0, 1) for k in batch[1]["label"]])
 
+    if expected_num_labels is not None:  # None disables the check
+        assert len(batch[1]["label"]) == expected_num_labels
+
     assert "name" in batch[1]
     assert all(
         [any([k.startswith(j) for j in prefixes]) for k in batch[1]["name"]]
@@ -269,5 +273,6 @@ def test_loading(name: str, dataset: str, prefixes: typing.Sequence[str]):
             batch,
             batch_size=1,
             prefixes=prefixes,
+            expected_num_labels=1,
         )
         limit -= 1
-- 
GitLab