diff --git a/src/mednet/config/data/montgomery_shenzhen_indian/datamodule.py b/src/mednet/config/data/montgomery_shenzhen_indian/datamodule.py index 676fa8ef96d26b794341dee5c5cd09caf55d06d0..92617fddc7f93608d22297d173abf47df0a46a2b 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian/datamodule.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian/datamodule.py @@ -4,6 +4,7 @@ """Aggregated DataModule composed of Montgomery, Shenzhen and Indian datasets.""" from ....data.datamodule import ConcatDataModule +from ..indian.datamodule import CONFIGURATION_KEY_DATADIR as INDIAN_KEY_DATADIR from ..indian.datamodule import RawDataLoader as IndianLoader from ..indian.datamodule import make_split as make_indian_split from ..montgomery.datamodule import RawDataLoader as MontgomeryLoader @@ -26,7 +27,7 @@ class DataModule(ConcatDataModule): montgomery_split = make_montgomery_split(split_filename) shenzhen_loader = ShenzhenLoader() shenzhen_split = make_shenzhen_split(split_filename) - indian_loader = IndianLoader() + indian_loader = IndianLoader(INDIAN_KEY_DATADIR) indian_split = make_indian_split(split_filename) super().__init__( diff --git a/src/mednet/config/data/montgomery_shenzhen_indian_padchest/datamodule.py b/src/mednet/config/data/montgomery_shenzhen_indian_padchest/datamodule.py index 2876af8f0335acd87313c02ca18ad8b47dd21bcc..7cff19e0efb6b9e4417992a866b2da315f81845e 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian_padchest/datamodule.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian_padchest/datamodule.py @@ -4,6 +4,7 @@ """Aggregated DataModule composed of Montgomery, Shenzhen, Indian, and PadChest datasets.""" from ....data.datamodule import ConcatDataModule +from ..indian.datamodule import CONFIGURATION_KEY_DATADIR as INDIAN_KEY_DATADIR from ..indian.datamodule import RawDataLoader as IndianLoader from ..indian.datamodule import make_split as make_indian_split from ..montgomery.datamodule import RawDataLoader as MontgomeryLoader @@ -31,7 +32,7 @@ class DataModule(ConcatDataModule): montgomery_split = make_montgomery_split(split_filename) shenzhen_loader = ShenzhenLoader() shenzhen_split = make_shenzhen_split(split_filename) - indian_loader = IndianLoader() + indian_loader = IndianLoader(INDIAN_KEY_DATADIR) indian_split = make_indian_split(split_filename) padchest_loader = PadchestLoader() padchest_split = make_padchest_split(padchest_split_filename) diff --git a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/datamodule.py b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/datamodule.py index 8dd831981465af7afa276d67287de5c09008bc36..358648d83f0f5cf824604950d7221b779803e3c0 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/datamodule.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/datamodule.py @@ -4,6 +4,7 @@ """Aggregated DataModule composed of Montgomery, Shenzhen, Indian, and TBX11k datasets.""" from ....data.datamodule import ConcatDataModule +from ..indian.datamodule import CONFIGURATION_KEY_DATADIR as INDIAN_KEY_DATADIR from ..indian.datamodule import RawDataLoader as IndianLoader from ..indian.datamodule import make_split as make_indian_split from ..montgomery.datamodule import RawDataLoader as MontgomeryLoader @@ -31,7 +32,7 @@ class DataModule(ConcatDataModule): montgomery_split = make_montgomery_split(split_filename) shenzhen_loader = ShenzhenLoader() shenzhen_split = make_shenzhen_split(split_filename) - indian_loader = IndianLoader() + indian_loader = IndianLoader(INDIAN_KEY_DATADIR) indian_split = make_indian_split(split_filename) tbx11k_loader = TBX11kLoader() tbx11k_split = make_tbx11k_split(tbx11k_split_filename) diff --git a/tests/test_hivtb.py b/tests/test_hivtb.py index 44c1cca3f702dc8776395c2702bc7c0b6533f105..efc11f40650eda91ec0d47a26f94214a33fc3419 100644 --- a/tests/test_hivtb.py +++ b/tests/test_hivtb.py @@ -7,6 +7,8 @@ import importlib import pytest +from click.testing import CliRunner + def id_function(val): if isinstance(val, dict): @@ -43,6 +45,17 @@ def test_protocol_consistency( ) +@pytest.mark.skip_if_rc_var_not_set("datadir.hivtb") +def test_database_check(): + from mednet.scripts.database import check + + runner = CliRunner() + result = runner.invoke(check, ["--limit=10", "hivtb-f0"]) + assert ( + result.exit_code == 0 + ), f"Exit code {result.exit_code} != 0 -- Output:\n{result.output}" + + @pytest.mark.skip_if_rc_var_not_set("datadir.hivtb") @pytest.mark.parametrize( "dataset", diff --git a/tests/test_indian.py b/tests/test_indian.py index 3a959e60783c4f6534c55352b0b933095ed9c093..5b76f2439d01c0057a1b184ecb37ac1e83733031 100644 --- a/tests/test_indian.py +++ b/tests/test_indian.py @@ -10,6 +10,8 @@ import importlib import pytest +from click.testing import CliRunner + def id_function(val): if isinstance(val, dict): @@ -47,6 +49,17 @@ def test_protocol_consistency( ) +@pytest.mark.skip_if_rc_var_not_set("datadir.indian") +def test_database_check(): + from mednet.scripts.database import check + + runner = CliRunner() + result = runner.invoke(check, ["indian"]) + assert ( + result.exit_code == 0 + ), f"Exit code {result.exit_code} != 0 -- Output:\n{result.output}" + + @pytest.mark.skip_if_rc_var_not_set("datadir.indian") @pytest.mark.parametrize( "dataset", diff --git a/tests/test_montgomery.py b/tests/test_montgomery.py index 4ce6258064a64507ccdc2714b9aa16c957961a36..42b2b9b72f986ca8d8e38e1d19e97bd8a7b89be8 100644 --- a/tests/test_montgomery.py +++ b/tests/test_montgomery.py @@ -7,6 +7,8 @@ import importlib import pytest +from click.testing import CliRunner + def id_function(val): if isinstance(val, dict): @@ -44,6 +46,17 @@ def test_protocol_consistency( ) +@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") +def test_database_check(): + from mednet.scripts.database import check + + runner = CliRunner() + result = runner.invoke(check, ["montgomery"]) + assert ( + result.exit_code == 0 + ), f"Exit code {result.exit_code} != 0 -- Output:\n{result.output}" + + @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") @pytest.mark.parametrize( "dataset", diff --git a/tests/test_montgomery_shenzhen.py b/tests/test_montgomery_shenzhen.py index 7f7119054299e7552c567d9df7cb45c7f8d9d028..8bd7093229d2e104f6e2ccde0f589d04cb5abba1 100644 --- a/tests/test_montgomery_shenzhen.py +++ b/tests/test_montgomery_shenzhen.py @@ -7,6 +7,8 @@ import importlib import pytest +from click.testing import CliRunner + @pytest.mark.parametrize( "name", @@ -53,3 +55,15 @@ def test_split_consistency(name: str): assert shenzhen.splits[split][0][0] == combined.splits[split][1][0] assert isinstance(shenzhen.splits[split][0][1], ShenzhenLoader) assert isinstance(combined.splits[split][1][1], ShenzhenLoader) + + +@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") +@pytest.mark.skip_if_rc_var_not_set("datadir.shenzhen") +def test_database_check(): + from mednet.scripts.database import check + + runner = CliRunner() + result = runner.invoke(check, ["montgomery-shenzhen"]) + assert ( + result.exit_code == 0 + ), f"Exit code {result.exit_code} != 0 -- Output:\n{result.output}" diff --git a/tests/test_montgomery_shenzhen_indian.py b/tests/test_montgomery_shenzhen_indian.py index 3134574e5b665cceeb48f80bf5806c436af5d44b..7534f2995226e913222a414a7216d40bc83659b2 100644 --- a/tests/test_montgomery_shenzhen_indian.py +++ b/tests/test_montgomery_shenzhen_indian.py @@ -7,6 +7,8 @@ import importlib import pytest +from click.testing import CliRunner + @pytest.mark.parametrize( "name", @@ -65,3 +67,16 @@ def test_split_consistency(name: str): assert indian.splits[split][0][0] == combined.splits[split][2][0] assert isinstance(indian.splits[split][0][1], IndianLoader) assert isinstance(combined.splits[split][2][1], IndianLoader) + + +@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") +@pytest.mark.skip_if_rc_var_not_set("datadir.indian") +@pytest.mark.skip_if_rc_var_not_set("datadir.shenzhen") +def test_database_check(): + from mednet.scripts.database import check + + runner = CliRunner() + result = runner.invoke(check, ["montgomery-shenzhen-indian"]) + assert ( + result.exit_code == 0 + ), f"Exit code {result.exit_code} != 0 -- Output:\n{result.output}" diff --git a/tests/test_montgomery_shenzhen_indian_padchest.py b/tests/test_montgomery_shenzhen_indian_padchest.py index 84b0aca9eff04d2d5f54e1236ae1028df2204e3c..b70d9c2532ab31954dd0d7f9d9b5f701a5c582ce 100644 --- a/tests/test_montgomery_shenzhen_indian_padchest.py +++ b/tests/test_montgomery_shenzhen_indian_padchest.py @@ -7,6 +7,8 @@ import importlib import pytest +from click.testing import CliRunner + @pytest.mark.parametrize( "name,padchest_name", @@ -69,3 +71,17 @@ def test_split_consistency(name: str, padchest_name: str): assert padchest.splits[split][0][0] == combined.splits[split][3][0] assert isinstance(padchest.splits[split][0][1], PadChestLoader) assert isinstance(combined.splits[split][3][1], PadChestLoader) + + +@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") +@pytest.mark.skip_if_rc_var_not_set("datadir.shenzhen") +@pytest.mark.skip_if_rc_var_not_set("datadir.indian") +@pytest.mark.skip_if_rc_var_not_set("datadir.padchest") +def test_database_check(): + from mednet.scripts.database import check + + runner = CliRunner() + result = runner.invoke(check, ["montgomery-shenzhen-indian-padchest"]) + assert ( + result.exit_code == 0 + ), f"Exit code {result.exit_code} != 0 -- Output:\n{result.output}" diff --git a/tests/test_montgomery_shenzhen_indian_tbx11k.py b/tests/test_montgomery_shenzhen_indian_tbx11k.py index d38cd77a2e845bb2d863e26f5e9517935e545ca8..e8f167f13c21057b75129c2f0506d6072e617cf1 100644 --- a/tests/test_montgomery_shenzhen_indian_tbx11k.py +++ b/tests/test_montgomery_shenzhen_indian_tbx11k.py @@ -7,6 +7,8 @@ import importlib import pytest +from click.testing import CliRunner + @pytest.mark.parametrize( "name,tbx11k_name", @@ -89,3 +91,17 @@ def test_split_consistency(name: str, tbx11k_name: str): assert tbx11k.splits[split][0][0] == combined.splits[split][3][0] assert isinstance(tbx11k.splits[split][0][1], TBX11kLoader) assert isinstance(combined.splits[split][3][1], TBX11kLoader) + + +@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") +@pytest.mark.skip_if_rc_var_not_set("datadir.shenzhen") +@pytest.mark.skip_if_rc_var_not_set("datadir.indian") +@pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k") +def test_database_check(): + from mednet.scripts.database import check + + runner = CliRunner() + result = runner.invoke(check, ["montgomery-shenzhen-indian-tbx11k-v1"]) + assert ( + result.exit_code == 0 + ), f"Exit code {result.exit_code} != 0 -- Output:\n{result.output}" diff --git a/tests/test_nih_cxr14.py b/tests/test_nih_cxr14.py index 1790dbc77007c77a320745251ea0b93b3ffcd3a8..8a16fec9ff19141b35e56e4d3b0e10c6b32e806e 100644 --- a/tests/test_nih_cxr14.py +++ b/tests/test_nih_cxr14.py @@ -7,6 +7,8 @@ import importlib import pytest +from click.testing import CliRunner + def id_function(val): if isinstance(val, dict): @@ -44,6 +46,17 @@ testdata = [ ] +@pytest.mark.skip_if_rc_var_not_set("datadir.nih_cxr14") +def test_database_check(): + from mednet.scripts.database import check + + runner = CliRunner() + result = runner.invoke(check, ["--limit=10", "nih-cxr14"]) + assert ( + result.exit_code == 0 + ), f"Exit code {result.exit_code} != 0 -- Output:\n{result.output}" + + @pytest.mark.skip_if_rc_var_not_set("datadir.nih_cxr14") @pytest.mark.parametrize("name,dataset,num_labels", testdata) def test_loading(database_checkers, name: str, dataset: str, num_labels: int): diff --git a/tests/test_nih_cxr14_padchest.py b/tests/test_nih_cxr14_padchest.py index 2c013398690a367b0c9c184f7710a05f345424e4..d98e096d4ed74fd3906137e2f1374e84ac1e9248 100644 --- a/tests/test_nih_cxr14_padchest.py +++ b/tests/test_nih_cxr14_padchest.py @@ -7,6 +7,8 @@ import importlib import pytest +from click.testing import CliRunner + @pytest.mark.parametrize( "name,padchest_name,combined_name", @@ -45,3 +47,15 @@ def test_split_consistency(name: str, padchest_name: str, combined_name: str): assert padchest.splits[split][0][0] == combined.splits[split][1][0] assert isinstance(padchest.splits[split][0][1], PadChestLoader) assert isinstance(combined.splits[split][1][1], PadChestLoader) + + +@pytest.mark.skip_if_rc_var_not_set("datadir.nih_cxr14") +@pytest.mark.skip_if_rc_var_not_set("datadir.padchest") +def test_database_check(): + from mednet.scripts.database import check + + runner = CliRunner() + result = runner.invoke(check, ["--limit=10", "nih-cxr14-padchest"]) + assert ( + result.exit_code == 0 + ), f"Exit code {result.exit_code} != 0 -- Output:\n{result.output}" diff --git a/tests/test_padchest.py b/tests/test_padchest.py index 68e24077725172230bc7ecd76acf367eb1ffde0d..262b32c0291daf926dc6940750b8e2d905bea464 100644 --- a/tests/test_padchest.py +++ b/tests/test_padchest.py @@ -7,6 +7,8 @@ import importlib import pytest +from click.testing import CliRunner + def id_function(val): if isinstance(val, dict): @@ -40,6 +42,17 @@ def test_protocol_consistency( ) +@pytest.mark.skip_if_rc_var_not_set("datadir.padchest") +def test_database_check(): + from mednet.scripts.database import check + + runner = CliRunner() + result = runner.invoke(check, ["--limit=10", "padchest-idiap"]) + assert ( + result.exit_code == 0 + ), f"Exit code {result.exit_code} != 0 -- Output:\n{result.output}" + + testdata = [ ("idiap", "train", 193), ("idiap", "test", 1), diff --git a/tests/test_shenzhen.py b/tests/test_shenzhen.py index 42b23ce16b50d7303bdd1e14bfe6d0199c57623e..3c5fc66122483c2e0c3ac2b9f258ed70d17120a2 100644 --- a/tests/test_shenzhen.py +++ b/tests/test_shenzhen.py @@ -7,6 +7,8 @@ import importlib import pytest +from click.testing import CliRunner + def id_function(val): if isinstance(val, dict): @@ -44,6 +46,17 @@ def test_protocol_consistency( ) +@pytest.mark.skip_if_rc_var_not_set("datadir.shenzhen") +def test_database_check(): + from mednet.scripts.database import check + + runner = CliRunner() + result = runner.invoke(check, ["shenzhen"]) + assert ( + result.exit_code == 0 + ), f"Exit code {result.exit_code} != 0 -- Output:\n{result.output}" + + @pytest.mark.skip_if_rc_var_not_set("datadir.shenzhen") @pytest.mark.parametrize( "dataset", diff --git a/tests/test_tbpoc.py b/tests/test_tbpoc.py index 44e742789684bf823cccdaeca539e0ee0f2f0482..c5d37bbb6899bdf6b0d0dd8bc03190436c2ce184 100644 --- a/tests/test_tbpoc.py +++ b/tests/test_tbpoc.py @@ -7,6 +7,8 @@ import importlib import pytest +from click.testing import CliRunner + def id_function(val): if isinstance(val, dict): @@ -46,6 +48,17 @@ def test_protocol_consistency( ) +@pytest.mark.skip_if_rc_var_not_set("datadir.tbpoc") +def test_database_check(): + from mednet.scripts.database import check + + runner = CliRunner() + result = runner.invoke(check, ["tbpoc-f0"]) + assert ( + result.exit_code == 0 + ), f"Exit code {result.exit_code} != 0 -- Output:\n{result.output}" + + @pytest.mark.skip_if_rc_var_not_set("datadir.tbpoc") @pytest.mark.parametrize( "dataset", diff --git a/tests/test_tbx11k.py b/tests/test_tbx11k.py index 5d1f584c1c79e0b420ec813d6dcc71d11ca3ecc1..96bc79e4994d7f26c98252255fb415f9709581b1 100644 --- a/tests/test_tbx11k.py +++ b/tests/test_tbx11k.py @@ -9,6 +9,8 @@ import typing import pytest import torch +from click.testing import CliRunner + def id_function(val): if isinstance(val, (dict, tuple)): @@ -231,6 +233,22 @@ def check_loaded_batch( # __import__("pdb").set_trace() +@pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k") +def test_database_check(): + from mednet.scripts.database import check + + runner = CliRunner() + result = runner.invoke(check, ["--limit=10", "tbx11k-v1-f0"]) + assert ( + result.exit_code == 0 + ), f"Exit code {result.exit_code} != 0 -- Output:\n{result.output}" + + result = runner.invoke(check, ["--limit=10", "tbx11k-v2-f0"]) + assert ( + result.exit_code == 0 + ), f"Exit code {result.exit_code} != 0 -- Output:\n{result.output}" + + @pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k") @pytest.mark.parametrize( "dataset",