diff --git a/tests/test_mc_ch_in_pc.py b/tests/test_mc_ch_in_pc.py deleted file mode 100644 index a1e4fc121fd3c5ae73aca7ac8fdde0abec6959a6..0000000000000000000000000000000000000000 --- a/tests/test_mc_ch_in_pc.py +++ /dev/null @@ -1,36 +0,0 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# -# SPDX-License-Identifier: GPL-3.0-or-later -"""Tests for the aggregated Montgomery-Shenzhen-Indian-Padchest dataset.""" - -import pytest - - -@pytest.mark.skip(reason="Test need to be updated") -def test_dataset_consistency(): - from ptbench.data.indian import default as indian - from ptbench.data.montgomery import default as mc - from ptbench.data.montgomery_shenzhen_indian_padchest import ( - default as mc_ch_in_pc, - ) - from ptbench.data.padchest import tb_idiap as pc - from ptbench.data.shenzhen import default as ch - - # Default protocol - mc_ch_in_pc_dataset = mc_ch_in_pc.dataset - assert isinstance(mc_ch_in_pc_dataset, dict) - - mc_dataset = mc.dataset - ch_dataset = ch.dataset - in_dataset = indian.dataset - pc_dataset = pc.dataset - - assert "train" in mc_ch_in_pc_dataset - assert len(mc_ch_in_pc_dataset["train"]) == len(mc_dataset["train"]) + len( - ch_dataset["train"] - ) + len(in_dataset["train"]) + len(pc_dataset["train"]) - - assert "test" in mc_ch_in_pc_dataset - assert len(mc_ch_in_pc_dataset["test"]) == len(mc_dataset["test"]) + len( - ch_dataset["test"] - ) + len(in_dataset["test"]) + len(pc_dataset["test"]) diff --git a/tests/test_montgomery_shenzhen_indian_padchest.py b/tests/test_montgomery_shenzhen_indian_padchest.py new file mode 100644 index 0000000000000000000000000000000000000000..bb1f7be4bfc0cc2779eccca0f16a72c338691043 --- /dev/null +++ b/tests/test_montgomery_shenzhen_indian_padchest.py @@ -0,0 +1,71 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later +"""Tests for the aggregated Montgomery-Shenzhen-Indian-PadChest dataset.""" + +import importlib + +import pytest + + +@pytest.mark.parametrize( + "name,padchest_name", + [ + ("default", "tb_idiap"), + ], +) +def test_split_consistency(name: str, padchest_name: str): + montgomery = importlib.import_module( + f".{name}", "ptbench.data.montgomery" + ).datamodule + + shenzhen = importlib.import_module( + f".{name}", "ptbench.data.shenzhen" + ).datamodule + + indian = importlib.import_module( + f".{name}", "ptbench.data.indian" + ).datamodule + + padchest = importlib.import_module( + f".{padchest_name}", "ptbench.data.padchest" + ).datamodule + + combined = importlib.import_module( + f".{name}", "ptbench.data.montgomery_shenzhen_indian_padchest" + ).datamodule + + MontgomeryLoader = importlib.import_module( + ".datamodule", "ptbench.data.montgomery" + ).RawDataLoader + + ShenzhenLoader = importlib.import_module( + ".datamodule", "ptbench.data.shenzhen" + ).RawDataLoader + + IndianLoader = importlib.import_module( + ".datamodule", "ptbench.data.indian" + ).RawDataLoader + + PadChestLoader = importlib.import_module( + ".datamodule", "ptbench.data.padchest" + ).RawDataLoader + + for split in ("train", "validation", "test"): + assert montgomery.splits[split][0][0] == combined.splits[split][0][0] + assert isinstance(montgomery.splits[split][0][1], MontgomeryLoader) + assert isinstance(combined.splits[split][0][1], MontgomeryLoader) + + assert shenzhen.splits[split][0][0] == combined.splits[split][1][0] + assert isinstance(shenzhen.splits[split][0][1], ShenzhenLoader) + assert isinstance(combined.splits[split][1][1], ShenzhenLoader) + + assert indian.splits[split][0][0] == combined.splits[split][2][0] + assert isinstance(indian.splits[split][0][1], IndianLoader) + assert isinstance(combined.splits[split][2][1], IndianLoader) + + if split != "validation": + # padchest has no validation + assert padchest.splits[split][0][0] == combined.splits[split][3][0] + assert isinstance(padchest.splits[split][0][1], PadChestLoader) + assert isinstance(combined.splits[split][3][1], PadChestLoader)