From fb46d40322aa9e3194356f07ae5f3aa78f38d159 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Thu, 1 Feb 2024 10:55:16 +0100
Subject: [PATCH] [test] Add checks for specific image shapes

---
 tests/conftest.py       | 25 ++++++++++++++++++++++---
 tests/test_nih_cxr14.py | 31 ++++++++++++++-----------------
 tests/test_tbx11k.py    | 17 ++++++++++++-----
 3 files changed, 48 insertions(+), 25 deletions(-)

diff --git a/tests/conftest.py b/tests/conftest.py
index febcc24f..40276ff1 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -160,9 +160,19 @@ class DatabaseCheckers:
 
         Parameters
         ----------
+<<<<<<< HEAD
         split
             An instance of DatabaseSplit.
         lengths
+=======
+
+        make_split
+            A database specific function that takes a split name and returns
+            the loaded database split.
+        split_filename
+            This is the split we will check.
+        lengths
+>>>>>>> 91bcad6 ([test] Add checks for specific image shapes)
             A dictionary that contains keys matching those of the split (this will
             be checked).  The values of the dictionary should correspond to the
             sizes of each of the datasets in the split.
@@ -197,13 +207,13 @@ class DatabaseCheckers:
         color_planes: int,
         prefixes: typing.Sequence[str],
         possible_labels: typing.Sequence[int],
-        expected_num_labels: typing.Optional[int] = None,
+        expected_num_labels: int,
+        expected_image_shape: typing.Optional[tuple[int, ...]] = None,
     ):
         """Check the consistency of an individual (loaded) batch.
 
         Parameters
         ----------
-
         batch
             The loaded batch to be checked.
         batch_size
@@ -215,15 +225,24 @@ class DatabaseCheckers:
             prefixes.
         possible_labels
             These are the list of possible labels contained in any split.
+        expected_num_labels
+            The expected number of labels each sample should have.
+        expected_image_shape
+            The expected shape of the image (num_channels, height, width).
         """
 
         assert len(batch) == 2  # data, metadata
 
         assert isinstance(batch[0], torch.Tensor)
         assert batch[0].shape[0] == batch_size  # mini-batch size
-        assert batch[0].shape[1] == color_planes  # grayscale images
+        assert batch[0].shape[1] == color_planes
         assert batch[0].shape[2] == batch[0].shape[3]  # image is square
 
+        if expected_image_shape:
+            assert all(
+                [data.shape == expected_image_shape for data in batch[0]]
+            )
+
         assert isinstance(batch[1], dict)  # metadata
         assert len(batch[1]) == 2  # label and name
 
diff --git a/tests/test_nih_cxr14.py b/tests/test_nih_cxr14.py
index 8188c7cb..e6ad485d 100644
--- a/tests/test_nih_cxr14.py
+++ b/tests/test_nih_cxr14.py
@@ -35,22 +35,18 @@ def test_protocol_consistency(
     )
 
 
-@pytest.mark.skip_if_rc_var_not_set("datadir.nih_cxr14")
-@pytest.mark.parametrize(
-    "dataset",
-    [
-        "train",
-        "validation",
-        "test",
-    ],
-)
-@pytest.mark.parametrize(
-    "name",
-    [
-        "default",
-    ],
-)
-def test_loading(database_checkers, name: str, dataset: str):
+testdata = [
+    ("default", "train", 14),
+    ("default", "validation", 14),
+    ("default", "test", 14),
+    ("cardiomegaly", "train", 14),
+    ("cardiomegaly", "validation", 14),
+]
+
+
+@pytest.mark.skip_if_rc_var_not_set("datadir.padchest")
+@pytest.mark.parametrize("name,dataset,num_labels", testdata)
+def test_loading(database_checkers, name: str, dataset: str, num_labels: int):
     datamodule = importlib.import_module(
         f".{name}", "mednet.config.data.nih_cxr14"
     ).datamodule
@@ -70,9 +66,10 @@ def test_loading(database_checkers, name: str, dataset: str):
             color_planes=1,
             prefixes=("images/000",),
             possible_labels=(0, 1),
+            expected_num_labels=num_labels,
+            expected_image_shape=(1, 1024, 1024),
         )
         limit -= 1
 
 
 # TODO: check size 1024x1024
-# TODO: check there are 14 binary labels (0, 1)
diff --git a/tests/test_tbx11k.py b/tests/test_tbx11k.py
index 231982e6..6c022dba 100644
--- a/tests/test_tbx11k.py
+++ b/tests/test_tbx11k.py
@@ -151,14 +151,16 @@ def test_protocol_consistency(
 def check_loaded_batch(
     batch,
     batch_size: int,
+    color_planes: int,
     prefixes: typing.Sequence[str],
-    expected_num_labels: typing.Optional[int] = None,
+    possible_labels: typing.Sequence[int],
+    expected_num_labels: int,
+    expected_image_shape: typing.Optional[tuple[int, ...]] = None,
 ):
     """Check the consistency of an individual (loaded) batch.
 
     Parameters
     ----------
-
     batch
         The loaded batch to be checked.
     batch_size
@@ -172,9 +174,11 @@ def check_loaded_batch(
 
     assert isinstance(batch[0], torch.Tensor)
     assert batch[0].shape[0] == batch_size  # mini-batch size
-    assert batch[0].shape[1] == 3  # grayscale images
+    assert batch[0].shape[1] == color_planes
     assert batch[0].shape[2] == batch[0].shape[3]  # image is square
-    assert batch[0].shape[2] == 512  # image is 512 pixels large
+
+    if expected_image_shape:
+        assert all([data.shape == expected_image_shape for data in batch[0]])
 
     assert isinstance(batch[1], dict)  # metadata
     assert (
@@ -182,7 +186,7 @@ def check_loaded_batch(
     )  # label, name and radiological sign bounding-boxes
 
     assert "label" in batch[1]
-    assert all([k in (0, 1) for k in batch[1]["label"]])
+    assert all([k in possible_labels for k in batch[1]["label"]])
 
     if expected_num_labels:
         assert len(batch[1]["label"]) == expected_num_labels
@@ -272,7 +276,10 @@ def test_loading(name: str, dataset: str, prefixes: typing.Sequence[str]):
         check_loaded_batch(
             batch,
             batch_size=1,
+            color_planes=3,
             prefixes=prefixes,
+            possible_labels=(0, 1),
             expected_num_labels=1,
+            expected_image_shape=(3, 512, 512),
         )
         limit -= 1
-- 
GitLab