Skip to content
Snippets Groups Projects
test_nih_cxr14_padchest.py 1.55 KiB
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Tests for the aggregated NIH-CXR14+PadChest dataset."""

import importlib

import pytest


@pytest.mark.parametrize(
    "name,padchest_name,combined_name",
    [
        ("default", "no_tb_idiap", "idiap"),
    ],
)
def test_split_consistency(name: str, padchest_name: str, combined_name: str):
    """Check that the combined datamodule mirrors the per-dataset splits.

    For each of the ``train``, ``validation`` and ``test`` splits, the first
    entry of the combined datamodule's split must equal (by its first
    element) the first entry of the NIH-CXR14 datamodule's split. Both
    entries' second element must be an NIH-CXR14 ``RawDataLoader``.

    For ``train`` and ``validation`` only, the second entry of the combined
    split must likewise equal the first entry of the PadChest split. Both
    must carry a PadChest ``RawDataLoader``.

    Parameters
    ----------
    name
        Name of the configuration module under
        ``mednet.config.data.nih_cxr14``.
    padchest_name
        Name of the configuration module under
        ``mednet.config.data.padchest``.
    combined_name
        Name of the configuration module under
        ``mednet.config.data.nih_cxr14_padchest``.
    """

    def _load(package: str, module: str):
        # relative import of ``module`` inside ``package``
        return importlib.import_module(f".{module}", package)

    cxr14_pkg = "mednet.config.data.nih_cxr14"
    padchest_pkg = "mednet.config.data.padchest"
    combined_pkg = "mednet.config.data.nih_cxr14_padchest"

    # modules are imported in the same order as the datamodules are needed
    nih_cxr14 = _load(cxr14_pkg, name).datamodule
    padchest = _load(padchest_pkg, padchest_name).datamodule
    combined = _load(combined_pkg, combined_name).datamodule

    cxr14_loader_type = _load(cxr14_pkg, "datamodule").RawDataLoader
    padchest_loader_type = _load(padchest_pkg, "datamodule").RawDataLoader

    for split in ("train", "validation", "test"):
        cxr14_entry = nih_cxr14.splits[split][0]
        combined_entries = combined.splits[split]

        # the NIH-CXR14 part always comes first in the combined split
        assert cxr14_entry[0] == combined_entries[0][0]
        assert isinstance(cxr14_entry[1], cxr14_loader_type)
        assert isinstance(combined_entries[0][1], cxr14_loader_type)

        if split == "test":
            # padchest has no test split, so there is nothing more to compare
            continue

        padchest_entry = padchest.splits[split][0]
        assert padchest_entry[0] == combined_entries[1][0]
        assert isinstance(padchest_entry[1], padchest_loader_type)
        assert isinstance(combined_entries[1][1], padchest_loader_type)