# André Anjos authored
# test_montgomery_shenzhen_indian_tbx11k.py 3.64 KiB
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Tests for the aggregated Montgomery-Shenzhen-Indian-TBX11k dataset."""
import importlib
import pytest
from click.testing import CliRunner
@pytest.mark.parametrize(
    "name,tbx11k_name",
    [
        # TBX11k "v1" configurations, paired with the matching split of the
        # three other databases.
        ("default", "v1_healthy_vs_atb"),
        *((f"fold_{k}", f"v1_fold_{k}") for k in range(10)),
        # TBX11k "v2" configurations.
        ("default", "v2_others_vs_atb"),
        *((f"fold_{k}", f"v2_fold_{k}") for k in range(10)),
    ],
)
def test_split_consistency(name: str, tbx11k_name: str):
    """Check that the aggregated splits agree with each source database.

    For every one of the ``train``, ``validation`` and ``test`` splits, the
    entry at position 0, 1, 2 and 3 of the aggregated datamodule is compared
    with the first entry of the montgomery, shenzhen, indian and tbx11k
    datamodules, respectively.  Each entry is indexed as a pair: its first
    element must compare equal between the two datamodules, and its second
    element must be an instance of the ``RawDataLoader`` class of the
    corresponding database.

    Parameters
    ----------
    name
        Name of the configuration module to import for the montgomery,
        shenzhen and indian databases.
    tbx11k_name
        Name of the configuration module to import for both the tbx11k
        database and the aggregated database.
    """
    config_root = "mednet.config.data"

    # The order of this tuple is significant: it is the position each
    # database is expected to occupy within an aggregated split.
    databases = ("montgomery", "shenzhen", "indian", "tbx11k")

    def _load(database: str, module: str):
        # Relative import of ``module`` from the database's config package.
        return importlib.import_module(
            f".{module}",
            f"{config_root}.{database}",
        )

    individual = [
        _load(db, tbx11k_name if db == "tbx11k" else name).datamodule
        for db in databases
    ]
    combined = _load("montgomery_shenzhen_indian_tbx11k", tbx11k_name).datamodule
    loader_types = [_load(db, "datamodule").RawDataLoader for db in databases]

    for split in ("train", "validation", "test"):
        for position, (single, loader_type) in enumerate(
            zip(individual, loader_types)
        ):
            assert single.splits[split][0][0] == combined.splits[split][position][0]
            assert isinstance(single.splits[split][0][1], loader_type)
            assert isinstance(combined.splits[split][position][1], loader_type)
@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
@pytest.mark.skip_if_rc_var_not_set("datadir.shenzhen")
@pytest.mark.skip_if_rc_var_not_set("datadir.indian")
@pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k")
def test_database_check():
    """Invoke the ``check`` command on the aggregated v1 database.

    The command is run in-process through click's ``CliRunner``; the test
    passes only when it exits with code 0.  On failure, the captured output
    of the command is included in the assertion message.
    """
    # Imported here rather than at module level, as in the original test.
    from mednet.scripts.database import check

    cli_arguments = ["montgomery-shenzhen-indian-tbx11k-v1"]
    cli = CliRunner()
    result = cli.invoke(check, cli_arguments)

    assert (
        result.exit_code == 0
    ), f"Exit code {result.exit_code} != 0 -- Output:\n{result.output}"