From 0b81d69f96ba6b72e827d1036976953304dc6234 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Thu, 22 Feb 2024 15:39:20 +0100
Subject: [PATCH] [test] Add tests for database checking

---
 tests/test_hivtb.py                            | 13 +++++++++++++
 tests/test_indian.py                           | 13 +++++++++++++
 tests/test_montgomery.py                       | 13 +++++++++++++
 tests/test_montgomery_shenzhen.py              | 14 ++++++++++++++
 tests/test_montgomery_shenzhen_indian.py       | 15 +++++++++++++++
 ...test_montgomery_shenzhen_indian_padchest.py | 16 ++++++++++++++++
 .../test_montgomery_shenzhen_indian_tbx11k.py  | 16 ++++++++++++++++
 tests/test_nih_cxr14.py                        | 13 +++++++++++++
 tests/test_nih_cxr14_padchest.py               | 14 ++++++++++++++
 tests/test_padchest.py                         | 13 +++++++++++++
 tests/test_shenzhen.py                         | 13 +++++++++++++
 tests/test_tbpoc.py                            | 13 +++++++++++++
 tests/test_tbx11k.py                           | 18 ++++++++++++++++++
 13 files changed, 184 insertions(+)

diff --git a/tests/test_hivtb.py b/tests/test_hivtb.py
index 44c1cca3..efc11f40 100644
--- a/tests/test_hivtb.py
+++ b/tests/test_hivtb.py
@@ -7,6 +7,8 @@ import importlib
 
 import pytest
 
+from click.testing import CliRunner
+
 
 def id_function(val):
     if isinstance(val, dict):
@@ -43,6 +45,17 @@ def test_protocol_consistency(
     )
 
 
+@pytest.mark.skip_if_rc_var_not_set("datadir.hivtb")
+def test_database_check():
+    from mednet.scripts.database import check
+
+    runner = CliRunner()
+    result = runner.invoke(check, ["--limit=10", "hivtb-f0"])
+    assert (
+        result.exit_code == 0
+    ), f"Exit code {result.exit_code} != 0 -- Output:\n{result.output}"
+
+
 @pytest.mark.skip_if_rc_var_not_set("datadir.hivtb")
 @pytest.mark.parametrize(
     "dataset",
diff --git a/tests/test_indian.py b/tests/test_indian.py
index 3a959e60..5b76f243 100644
--- a/tests/test_indian.py
+++ b/tests/test_indian.py
@@ -10,6 +10,8 @@ import importlib
 
 import pytest
 
+from click.testing import CliRunner
+
 
 def id_function(val):
     if isinstance(val, dict):
@@ -47,6 +49,17 @@ def test_protocol_consistency(
     )
 
 
+@pytest.mark.skip_if_rc_var_not_set("datadir.indian")
+def test_database_check():
+    from mednet.scripts.database import check
+
+    runner = CliRunner()
+    result = runner.invoke(check, ["indian"])
+    assert (
+        result.exit_code == 0
+    ), f"Exit code {result.exit_code} != 0 -- Output:\n{result.output}"
+
+
 @pytest.mark.skip_if_rc_var_not_set("datadir.indian")
 @pytest.mark.parametrize(
     "dataset",
diff --git a/tests/test_montgomery.py b/tests/test_montgomery.py
index 4ce62580..42b2b9b7 100644
--- a/tests/test_montgomery.py
+++ b/tests/test_montgomery.py
@@ -7,6 +7,8 @@ import importlib
 
 import pytest
 
+from click.testing import CliRunner
+
 
 def id_function(val):
     if isinstance(val, dict):
@@ -44,6 +46,17 @@ def test_protocol_consistency(
     )
 
 
+@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
+def test_database_check():
+    from mednet.scripts.database import check
+
+    runner = CliRunner()
+    result = runner.invoke(check, ["montgomery"])
+    assert (
+        result.exit_code == 0
+    ), f"Exit code {result.exit_code} != 0 -- Output:\n{result.output}"
+
+
 @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
 @pytest.mark.parametrize(
     "dataset",
diff --git a/tests/test_montgomery_shenzhen.py b/tests/test_montgomery_shenzhen.py
index 7f711905..8bd70932 100644
--- a/tests/test_montgomery_shenzhen.py
+++ b/tests/test_montgomery_shenzhen.py
@@ -7,6 +7,8 @@ import importlib
 
 import pytest
 
+from click.testing import CliRunner
+
 
 @pytest.mark.parametrize(
     "name",
@@ -53,3 +55,15 @@ def test_split_consistency(name: str):
         assert shenzhen.splits[split][0][0] == combined.splits[split][1][0]
         assert isinstance(shenzhen.splits[split][0][1], ShenzhenLoader)
         assert isinstance(combined.splits[split][1][1], ShenzhenLoader)
+
+
+@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
+@pytest.mark.skip_if_rc_var_not_set("datadir.shenzhen")
+def test_database_check():
+    from mednet.scripts.database import check
+
+    runner = CliRunner()
+    result = runner.invoke(check, ["montgomery-shenzhen"])
+    assert (
+        result.exit_code == 0
+    ), f"Exit code {result.exit_code} != 0 -- Output:\n{result.output}"
diff --git a/tests/test_montgomery_shenzhen_indian.py b/tests/test_montgomery_shenzhen_indian.py
index 3134574e..7534f299 100644
--- a/tests/test_montgomery_shenzhen_indian.py
+++ b/tests/test_montgomery_shenzhen_indian.py
@@ -7,6 +7,8 @@ import importlib
 
 import pytest
 
+from click.testing import CliRunner
+
 
 @pytest.mark.parametrize(
     "name",
@@ -65,3 +67,16 @@ def test_split_consistency(name: str):
         assert indian.splits[split][0][0] == combined.splits[split][2][0]
         assert isinstance(indian.splits[split][0][1], IndianLoader)
         assert isinstance(combined.splits[split][2][1], IndianLoader)
+
+
+@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
+@pytest.mark.skip_if_rc_var_not_set("datadir.indian")
+@pytest.mark.skip_if_rc_var_not_set("datadir.shenzhen")
+def test_database_check():
+    from mednet.scripts.database import check
+
+    runner = CliRunner()
+    result = runner.invoke(check, ["montgomery-shenzhen-indian"])
+    assert (
+        result.exit_code == 0
+    ), f"Exit code {result.exit_code} != 0 -- Output:\n{result.output}"
diff --git a/tests/test_montgomery_shenzhen_indian_padchest.py b/tests/test_montgomery_shenzhen_indian_padchest.py
index 84b0aca9..b70d9c25 100644
--- a/tests/test_montgomery_shenzhen_indian_padchest.py
+++ b/tests/test_montgomery_shenzhen_indian_padchest.py
@@ -7,6 +7,8 @@ import importlib
 
 import pytest
 
+from click.testing import CliRunner
+
 
 @pytest.mark.parametrize(
     "name,padchest_name",
@@ -69,3 +71,17 @@ def test_split_consistency(name: str, padchest_name: str):
             assert padchest.splits[split][0][0] == combined.splits[split][3][0]
             assert isinstance(padchest.splits[split][0][1], PadChestLoader)
             assert isinstance(combined.splits[split][3][1], PadChestLoader)
+
+
+@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
+@pytest.mark.skip_if_rc_var_not_set("datadir.shenzhen")
+@pytest.mark.skip_if_rc_var_not_set("datadir.indian")
+@pytest.mark.skip_if_rc_var_not_set("datadir.padchest")
+def test_database_check():
+    from mednet.scripts.database import check
+
+    runner = CliRunner()
+    result = runner.invoke(check, ["montgomery-shenzhen-indian-padchest"])
+    assert (
+        result.exit_code == 0
+    ), f"Exit code {result.exit_code} != 0 -- Output:\n{result.output}"
diff --git a/tests/test_montgomery_shenzhen_indian_tbx11k.py b/tests/test_montgomery_shenzhen_indian_tbx11k.py
index d38cd77a..e8f167f1 100644
--- a/tests/test_montgomery_shenzhen_indian_tbx11k.py
+++ b/tests/test_montgomery_shenzhen_indian_tbx11k.py
@@ -7,6 +7,8 @@ import importlib
 
 import pytest
 
+from click.testing import CliRunner
+
 
 @pytest.mark.parametrize(
     "name,tbx11k_name",
@@ -89,3 +91,17 @@ def test_split_consistency(name: str, tbx11k_name: str):
         assert tbx11k.splits[split][0][0] == combined.splits[split][3][0]
         assert isinstance(tbx11k.splits[split][0][1], TBX11kLoader)
         assert isinstance(combined.splits[split][3][1], TBX11kLoader)
+
+
+@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
+@pytest.mark.skip_if_rc_var_not_set("datadir.shenzhen")
+@pytest.mark.skip_if_rc_var_not_set("datadir.indian")
+@pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k")
+def test_database_check():
+    from mednet.scripts.database import check
+
+    runner = CliRunner()
+    result = runner.invoke(check, ["montgomery-shenzhen-indian-tbx11k-v1"])
+    assert (
+        result.exit_code == 0
+    ), f"Exit code {result.exit_code} != 0 -- Output:\n{result.output}"
diff --git a/tests/test_nih_cxr14.py b/tests/test_nih_cxr14.py
index 1790dbc7..8a16fec9 100644
--- a/tests/test_nih_cxr14.py
+++ b/tests/test_nih_cxr14.py
@@ -7,6 +7,8 @@ import importlib
 
 import pytest
 
+from click.testing import CliRunner
+
 
 def id_function(val):
     if isinstance(val, dict):
@@ -44,6 +46,17 @@ testdata = [
 ]
 
 
+@pytest.mark.skip_if_rc_var_not_set("datadir.nih_cxr14")
+def test_database_check():
+    from mednet.scripts.database import check
+
+    runner = CliRunner()
+    result = runner.invoke(check, ["--limit=10", "nih-cxr14"])
+    assert (
+        result.exit_code == 0
+    ), f"Exit code {result.exit_code} != 0 -- Output:\n{result.output}"
+
+
 @pytest.mark.skip_if_rc_var_not_set("datadir.nih_cxr14")
 @pytest.mark.parametrize("name,dataset,num_labels", testdata)
 def test_loading(database_checkers, name: str, dataset: str, num_labels: int):
diff --git a/tests/test_nih_cxr14_padchest.py b/tests/test_nih_cxr14_padchest.py
index 2c013398..d98e096d 100644
--- a/tests/test_nih_cxr14_padchest.py
+++ b/tests/test_nih_cxr14_padchest.py
@@ -7,6 +7,8 @@ import importlib
 
 import pytest
 
+from click.testing import CliRunner
+
 
 @pytest.mark.parametrize(
     "name,padchest_name,combined_name",
@@ -45,3 +47,15 @@ def test_split_consistency(name: str, padchest_name: str, combined_name: str):
             assert padchest.splits[split][0][0] == combined.splits[split][1][0]
             assert isinstance(padchest.splits[split][0][1], PadChestLoader)
             assert isinstance(combined.splits[split][1][1], PadChestLoader)
+
+
+@pytest.mark.skip_if_rc_var_not_set("datadir.nih_cxr14")
+@pytest.mark.skip_if_rc_var_not_set("datadir.padchest")
+def test_database_check():
+    from mednet.scripts.database import check
+
+    runner = CliRunner()
+    result = runner.invoke(check, ["--limit=10", "nih-cxr14-padchest"])
+    assert (
+        result.exit_code == 0
+    ), f"Exit code {result.exit_code} != 0 -- Output:\n{result.output}"
diff --git a/tests/test_padchest.py b/tests/test_padchest.py
index 68e24077..262b32c0 100644
--- a/tests/test_padchest.py
+++ b/tests/test_padchest.py
@@ -7,6 +7,8 @@ import importlib
 
 import pytest
 
+from click.testing import CliRunner
+
 
 def id_function(val):
     if isinstance(val, dict):
@@ -40,6 +42,17 @@ def test_protocol_consistency(
     )
 
 
+@pytest.mark.skip_if_rc_var_not_set("datadir.padchest")
+def test_database_check():
+    from mednet.scripts.database import check
+
+    runner = CliRunner()
+    result = runner.invoke(check, ["--limit=10", "padchest-idiap"])
+    assert (
+        result.exit_code == 0
+    ), f"Exit code {result.exit_code} != 0 -- Output:\n{result.output}"
+
+
 testdata = [
     ("idiap", "train", 193),
     ("idiap", "test", 1),
diff --git a/tests/test_shenzhen.py b/tests/test_shenzhen.py
index 42b23ce1..3c5fc661 100644
--- a/tests/test_shenzhen.py
+++ b/tests/test_shenzhen.py
@@ -7,6 +7,8 @@ import importlib
 
 import pytest
 
+from click.testing import CliRunner
+
 
 def id_function(val):
     if isinstance(val, dict):
@@ -44,6 +46,17 @@ def test_protocol_consistency(
     )
 
 
+@pytest.mark.skip_if_rc_var_not_set("datadir.shenzhen")
+def test_database_check():
+    from mednet.scripts.database import check
+
+    runner = CliRunner()
+    result = runner.invoke(check, ["shenzhen"])
+    assert (
+        result.exit_code == 0
+    ), f"Exit code {result.exit_code} != 0 -- Output:\n{result.output}"
+
+
 @pytest.mark.skip_if_rc_var_not_set("datadir.shenzhen")
 @pytest.mark.parametrize(
     "dataset",
diff --git a/tests/test_tbpoc.py b/tests/test_tbpoc.py
index 44e74278..c5d37bbb 100644
--- a/tests/test_tbpoc.py
+++ b/tests/test_tbpoc.py
@@ -7,6 +7,8 @@ import importlib
 
 import pytest
 
+from click.testing import CliRunner
+
 
 def id_function(val):
     if isinstance(val, dict):
@@ -46,6 +48,17 @@ def test_protocol_consistency(
     )
 
 
+@pytest.mark.skip_if_rc_var_not_set("datadir.tbpoc")
+def test_database_check():
+    from mednet.scripts.database import check
+
+    runner = CliRunner()
+    result = runner.invoke(check, ["tbpoc-f0"])
+    assert (
+        result.exit_code == 0
+    ), f"Exit code {result.exit_code} != 0 -- Output:\n{result.output}"
+
+
 @pytest.mark.skip_if_rc_var_not_set("datadir.tbpoc")
 @pytest.mark.parametrize(
     "dataset",
diff --git a/tests/test_tbx11k.py b/tests/test_tbx11k.py
index 5d1f584c..96bc79e4 100644
--- a/tests/test_tbx11k.py
+++ b/tests/test_tbx11k.py
@@ -9,6 +9,8 @@ import typing
 import pytest
 import torch
 
+from click.testing import CliRunner
+
 
 def id_function(val):
     if isinstance(val, (dict, tuple)):
@@ -231,6 +233,22 @@ def check_loaded_batch(
     # __import__("pdb").set_trace()
 
 
+@pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k")
+def test_database_check():
+    from mednet.scripts.database import check
+
+    runner = CliRunner()
+    result = runner.invoke(check, ["--limit=10", "tbx11k-v1-f0"])
+    assert (
+        result.exit_code == 0
+    ), f"Exit code {result.exit_code} != 0 -- Output:\n{result.output}"
+
+    result = runner.invoke(check, ["--limit=10", "tbx11k-v2-f0"])
+    assert (
+        result.exit_code == 0
+    ), f"Exit code {result.exit_code} != 0 -- Output:\n{result.output}"
+
+
 @pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k")
 @pytest.mark.parametrize(
     "dataset",
-- 
GitLab