# test_montgomery_shenzhen_indian.py
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Tests for the aggregated Montgomery-Shenzhen-Indian dataset."""

import importlib

import pytest


@pytest.mark.parametrize(
    "name",
    ["default"] + [f"fold_{index}" for index in range(10)],
)
def test_split_consistency(name: str):
    """Check that the combined splits mirror each source dataset's splits.

    For each of the ``train``, ``validation`` and ``test`` splits, the first
    entry of every individual datamodule's split is compared against the
    entry found at the matching position of the combined datamodule
    (Montgomery at index 0, Shenzhen at index 1, Indian at index 2).  The
    first element of both entries must be equal, and the second element of
    both entries must be an instance of that dataset's ``RawDataLoader``.

    Parameters
    ----------
    name
        Name of the split configuration module imported from every dataset
        package (``default`` or one of ``fold_0`` ... ``fold_9``).
    """
    config_root = "mednet.config.data"

    # Order matters: the position of a dataset in this tuple is the position
    # at which its entry is expected inside each combined split.
    sources = ("montgomery", "shenzhen", "indian")

    individual = [
        importlib.import_module(
            f".{name}", f"{config_root}.{source}"
        ).datamodule
        for source in sources
    ]

    aggregated = importlib.import_module(
        f".{name}", f"{config_root}.montgomery_shenzhen_indian"
    ).datamodule

    loader_types = [
        importlib.import_module(
            ".datamodule", f"{config_root}.{source}"
        ).RawDataLoader
        for source in sources
    ]

    for split in ("train", "validation", "test"):
        pairs = zip(individual, loader_types)
        for position, (single, loader_type) in enumerate(pairs):
            # Each individual datamodule contributes exactly one entry
            # (index 0) per split; the combined one stores it at `position`.
            own_entry = single.splits[split][0]
            merged_entry = aggregated.splits[split][position]

            assert own_entry[0] == merged_entry[0]
            assert isinstance(own_entry[1], loader_type)
            assert isinstance(merged_entry[1], loader_type)