From 513bed9df5b13eed930d3d49592a26d0cdb54084 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Wed, 2 Aug 2023 21:02:33 +0200
Subject: [PATCH] [tests.test_nih_cxr14] Adapt to new datamodule syntax

---
 tests/test_nih_cxr14.py                | 135 +++++++++++++------------
 tests/{test_pc.py => test_padchest.py} |   0
 2 files changed, 70 insertions(+), 65 deletions(-)
 rename tests/{test_pc.py => test_padchest.py} (100%)

diff --git a/tests/test_nih_cxr14.py b/tests/test_nih_cxr14.py
index d7cc149a..242ac7e9 100644
--- a/tests/test_nih_cxr14.py
+++ b/tests/test_nih_cxr14.py
@@ -1,72 +1,77 @@
 # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
-"""Tests for NIH CXR14 dataset."""
+"""Tests for NIH CXR-14 dataset."""
 
-import pytest
-
-
-@pytest.mark.skip(reason="Test need to be updated")
-def test_protocol_consistency():
-    from ptbench.data.nih_cxr14_re import dataset
-
-    # Default protocol
-    subset = dataset.subsets("default")
-    assert len(subset) == 3
-
-    assert "train" in subset
-    assert len(subset["train"]) == 98637
-    for s in subset["train"]:
-        assert s.key.startswith("images/000")
-
-    assert "validation" in subset
-    assert len(subset["validation"]) == 6350
-    for s in subset["validation"]:
-        assert s.key.startswith("images/000")
-
-    assert "test" in subset
-    assert len(subset["test"]) == 4054
-    for s in subset["test"]:
-        assert s.key.startswith("images/000")
-
-    # Check labels
-    for s in subset["train"]:
-        for element in list(set(s.label)):
-            assert element in [0.0, 1.0]
+import importlib
 
-    for s in subset["validation"]:
-        for element in list(set(s.label)):
-            assert element in [0.0, 1.0]
-
-    for s in subset["test"]:
-        for element in list(set(s.label)):
-            assert element in [0.0, 1.0]
-
-
-@pytest.mark.skip(reason="Test need to be updated")
-@pytest.mark.skip_if_rc_var_not_set("datadir.nih_cxr14_re")
-def test_loading():
-    from ptbench.data.nih_cxr14_re import dataset
-
-    def _check_size(size):
-        if size == (1024, 1024):
-            return True
-        return False
-
-    def _check_sample(s):
-        data = s.data
-        assert isinstance(data, dict)
-        assert len(data) == 2
-
-        assert "data" in data
-        assert _check_size(data["data"].size)  # Check size
-        assert data["data"].mode == "RGB"  # Check colors
-
-        assert "label" in data
-        assert len(data["label"]) == 14  # Check labels
+import pytest
 
-    limit = 30  # use this to limit testing to first images only, else None
 
-    subset = dataset.subsets("default")
-    for s in subset["train"][:limit]:
-        _check_sample(s)
+def id_function(val):
+    if isinstance(val, dict):
+        return str(val)
+    return repr(val)
+
+
+@pytest.mark.parametrize(
+    "split,lenghts",
+    [
+        ("default", dict(train=98637, validation=6350, test=4054)),
+    ],
+    ids=id_function,  # just changes how pytest prints it
+)
+def test_protocol_consistency(
+    database_checkers, split: str, lenghts: dict[str, int]
+):
+    from ptbench.data.nih_cxr14.datamodule import make_split
+
+    database_checkers.check_split(
+        make_split(f"{split}.json.bz2"),
+        lengths=lenghts,
+        prefixes=("images/000",),
+        possible_labels=(0, 1),
+    )
+
+
+@pytest.mark.skip_if_rc_var_not_set("datadir.nih_cxr14")
+@pytest.mark.parametrize(
+    "dataset",
+    [
+        "train",
+        "validation",
+        "test",
+    ],
+)
+@pytest.mark.parametrize(
+    "name",
+    [
+        "default",
+    ],
+)
+def test_loading(database_checkers, name: str, dataset: str):
+    datamodule = importlib.import_module(
+        f".{name}", "ptbench.data.nih_cxr14"
+    ).datamodule
+
+    datamodule.model_transforms = []  # should be done before setup()
+    datamodule.setup("predict")  # sets up all datasets
+
+    loader = datamodule.predict_dataloader()[dataset]
+
+    limit = 3  # limit load checking
+    for batch in loader:
+        if limit == 0:
+            break
+        database_checkers.check_loaded_batch(
+            batch,
+            batch_size=1,
+            color_planes=1,
+            prefixes=("images/000",),
+            possible_labels=(0, 1),
+        )
+        limit -= 1
+
+
+# TODO: check size 1024x1024
+# TODO: check there are 14 binary labels (0, 1)
diff --git a/tests/test_pc.py b/tests/test_padchest.py
similarity index 100%
rename from tests/test_pc.py
rename to tests/test_padchest.py
-- 
GitLab