From 1e1e2daccc79f51f3647055ef60c8a14cb47d743 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Sat, 5 Aug 2023 21:27:14 +0200
Subject: [PATCH] [tests] Minimal tests for the combined NIH-CXR14 and PadChest
 database remix

---
 tests/test_nih_cxr14_padchest.py | 55 +++++++++++++++++++++-----------
 1 file changed, 36 insertions(+), 19 deletions(-)

diff --git a/tests/test_nih_cxr14_padchest.py b/tests/test_nih_cxr14_padchest.py
index 6199ad54..4a9379e3 100644
--- a/tests/test_nih_cxr14_padchest.py
+++ b/tests/test_nih_cxr14_padchest.py
@@ -1,30 +1,47 @@
 # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
-"""Tests for the aggregated NIH CXR14-PadChest dataset."""
+"""Tests for the aggregated NIH+CXR14-PadChest dataset."""
+
+import importlib
 
 import pytest
 
 
-@pytest.mark.skip(reason="Test need to be updated")
-def test_dataset_consistency():
-    from ptbench.data.nih_cxr14 import default as nih
-    from ptbench.data.nih_cxr14_padchest import idiap as nih_pc
-    from ptbench.data.padchest import no_tb_idiap as pc
+@pytest.mark.parametrize(
+    "name,padchest_name,combined_name",
+    [
+        ("default", "no_tb_idiap", "idiap"),
+    ],
+)
+def test_split_consistency(name: str, padchest_name: str, combined_name: str):
+    nih_cxr14 = importlib.import_module(
+        f".{name}", "ptbench.data.nih_cxr14"
+    ).datamodule
+
+    padchest = importlib.import_module(
+        f".{padchest_name}", "ptbench.data.padchest"
+    ).datamodule
+
+    combined = importlib.import_module(
+        f".{combined_name}", "ptbench.data.nih_cxr14_padchest"
+    ).datamodule
 
-    # Default protocol
-    nih_pc_dataset = nih_pc.dataset
-    assert isinstance(nih_pc_dataset, dict)
+    CXR14Loader = importlib.import_module(
+        ".datamodule", "ptbench.data.nih_cxr14"
+    ).RawDataLoader
 
-    nih_dataset = nih.dataset
-    pc_dataset = pc.dataset
+    PadChestLoader = importlib.import_module(
+        ".datamodule", "ptbench.data.padchest"
+    ).RawDataLoader
 
-    assert "train" in nih_pc_dataset
-    assert len(nih_pc_dataset["train"]) == len(nih_dataset["train"]) + len(
-        pc_dataset["train"]
-    )
+    for split in ("train", "validation", "test"):
+        assert nih_cxr14.splits[split][0][0] == combined.splits[split][0][0]
+        assert isinstance(nih_cxr14.splits[split][0][1], CXR14Loader)
+        assert isinstance(combined.splits[split][0][1], CXR14Loader)
 
-    assert "validation" in nih_pc_dataset
-    assert len(nih_pc_dataset["validation"]) == len(
-        nih_dataset["validation"]
-    ) + len(pc_dataset["validation"])
+        if split != "test":
+            # padchest has no test
+            assert padchest.splits[split][0][0] == combined.splits[split][1][0]
+            assert isinstance(padchest.splits[split][0][1], PadChestLoader)
+            assert isinstance(combined.splits[split][1][1], PadChestLoader)
-- 
GitLab