# Authored by André Anjos
# test_montgomery_shenzhen_indian_padchest.py (3.20 KiB)
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Tests for the aggregated Montgomery-Shenzhen-Indian-PadChest dataset."""
import importlib
import pytest
from click.testing import CliRunner
@pytest.mark.parametrize(
    "name,padchest_name",
    [
        ("default", "tb_idiap"),
    ],
)
def test_split_consistency(name: str, padchest_name: str):
    """Check the combined datamodule lines up with its member datamodules.

    For each of the ``train``, ``validation`` and ``test`` splits, entry ``i``
    of the combined split is compared against entry ``0`` of the matching
    standalone dataset split. The samples (element ``0``) must compare equal,
    and the loader (element ``1``) of both must be an instance of that
    dataset's ``ClassificationRawDataLoader``. Member positions in the
    combined split are Montgomery (0), Shenzhen (1), Indian (2) and
    PadChest (3). PadChest is not checked for ``validation``.

    Parameters
    ----------
    name
        Configuration module name used for Montgomery, Shenzhen, Indian and
        the combined Montgomery-Shenzhen-Indian-PadChest dataset.
    padchest_name
        Configuration module name used for PadChest.
    """
    package = "mednet.classify.config.data"

    def _datamodule(dataset: str, config: str):
        # ``datamodule`` object exposed by ``<package>.<dataset>.<config>``.
        module = importlib.import_module(f".{config}", f"{package}.{dataset}")
        return module.datamodule

    def _loader_class(dataset: str):
        # Loader class declared in ``<package>.<dataset>.datamodule``.
        module = importlib.import_module(".datamodule", f"{package}.{dataset}")
        return module.ClassificationRawDataLoader

    # Member datasets, in the order they appear inside the combined split.
    dataset_names = ("montgomery", "shenzhen", "indian", "padchest")
    configs = (name, name, name, padchest_name)

    standalone = [
        _datamodule(dataset, config)
        for dataset, config in zip(dataset_names, configs)
    ]
    combined = _datamodule("montgomery_shenzhen_indian_padchest", name)
    loaders = [_loader_class(dataset) for dataset in dataset_names]

    for split in ("train", "validation", "test"):
        members = zip(dataset_names, standalone, loaders)
        for position, (dataset, datamodule, loader) in enumerate(members):
            if dataset == "padchest" and split == "validation":
                # padchest has no validation
                continue
            assert (
                datamodule.splits[split][0][0]
                == combined.splits[split][position][0]
            )
            assert isinstance(datamodule.splits[split][0][1], loader)
            assert isinstance(combined.splits[split][position][1], loader)
@pytest.mark.slow
@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
@pytest.mark.skip_if_rc_var_not_set("datadir.shenzhen")
@pytest.mark.skip_if_rc_var_not_set("datadir.indian")
@pytest.mark.skip_if_rc_var_not_set("datadir.padchest")
def test_database_check():
    """Invoke the ``check`` database command on the combined dataset.

    The command must exit with status 0. On failure, the exit code and the
    captured command output are reported.

    NOTE(review): the ``skip_if_rc_var_not_set`` markers presumably skip this
    test when the corresponding ``datadir.*`` setting is absent. Confirm
    against the project's conftest.
    """
    # Imported lazily, inside the test, as in the original.
    from mednet.scripts.database import check

    database = "montgomery-shenzhen-indian-padchest"
    result = CliRunner().invoke(check, [database])
    assert result.exit_code == 0, (
        f"Exit code {result.exit_code} != 0 -- Output:\n{result.output}"
    )