Skip to content
Snippets Groups Projects
test_montgomery_shenzhen_indian_padchest.py 3.20 KiB
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Tests for the aggregated Montgomery-Shenzhen-Indian-PadChest dataset."""

import importlib

import pytest
from click.testing import CliRunner


@pytest.mark.parametrize(
    "name,padchest_name",
    [
        ("default", "tb_idiap"),
    ],
)
def test_split_consistency(name: str, padchest_name: str):
    """Verify the combined datamodule reuses each member dataset's splits.

    For every split, the combined datamodule is expected to hold, in order,
    entries matching the first subset of the Montgomery, Shenzhen, Indian and
    PadChest datamodules.  Each position must carry the same first element as
    the member dataset's entry, and both the member's and the combined entry's
    second element must be instances of that member's
    ``ClassificationRawDataLoader``.  The PadChest position is not checked for
    the ``validation`` split.

    Parameters
    ----------
    name
        Configuration name loaded for Montgomery, Shenzhen, Indian and the
        combined dataset.
    padchest_name
        Configuration name loaded for PadChest.
    """
    prefix = "mednet.classify.config.data."

    # (dataset package, configuration to load, checked in validation split),
    # listed in the order the datasets are expected inside the combined splits.
    members = (
        ("montgomery", name, True),
        ("shenzhen", name, True),
        ("indian", name, True),
        ("padchest", padchest_name, False),  # padchest has no validation
    )

    datamodules = [
        importlib.import_module(f".{config}", prefix + dataset).datamodule
        for dataset, config, _ in members
    ]

    combined = importlib.import_module(
        f".{name}",
        prefix + "montgomery_shenzhen_indian_padchest",
    ).datamodule

    loader_types = [
        importlib.import_module(
            ".datamodule", prefix + dataset
        ).ClassificationRawDataLoader
        for dataset, _, _ in members
    ]

    in_validation = [present for _, _, present in members]

    for split in ("train", "validation", "test"):
        for index, (member, loader_type, present) in enumerate(
            zip(datamodules, loader_types, in_validation)
        ):
            if split == "validation" and not present:
                continue
            own = member.splits[split][0]
            merged = combined.splits[split][index]
            assert own[0] == merged[0]
            assert isinstance(own[1], loader_type)
            assert isinstance(merged[1], loader_type)


@pytest.mark.slow
@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
@pytest.mark.skip_if_rc_var_not_set("datadir.shenzhen")
@pytest.mark.skip_if_rc_var_not_set("datadir.indian")
@pytest.mark.skip_if_rc_var_not_set("datadir.padchest")
def test_database_check():
    """Run the ``check`` database command on the combined dataset.

    The test invokes the ``check`` click command with the
    ``montgomery-shenzhen-indian-padchest`` argument through a
    ``CliRunner``.  It asserts a zero exit code and reports the captured
    output on failure.
    """
    from mednet.scripts.database import check

    outcome = CliRunner().invoke(check, ["montgomery-shenzhen-indian-padchest"])
    code = outcome.exit_code
    assert code == 0, f"Exit code {code} != 0 -- Output:\n{outcome.output}"