Skip to content
Snippets Groups Projects
Commit fc88dade authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[tests] Implement tests for montgomery_shenzhen_indian_padchest

parent 075d3ce8
No related branches found
No related tags found
1 merge request!6Making use of LightningDataModule and simplification of data loading
Pipeline #76783 passed
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Tests for the aggregated Montgomery-Shenzhen-Indian-Padchest dataset."""
import pytest
@pytest.mark.skip(reason="Test need to be updated")
def test_dataset_consistency():
from ptbench.data.indian import default as indian
from ptbench.data.montgomery import default as mc
from ptbench.data.montgomery_shenzhen_indian_padchest import (
default as mc_ch_in_pc,
)
from ptbench.data.padchest import tb_idiap as pc
from ptbench.data.shenzhen import default as ch
# Default protocol
mc_ch_in_pc_dataset = mc_ch_in_pc.dataset
assert isinstance(mc_ch_in_pc_dataset, dict)
mc_dataset = mc.dataset
ch_dataset = ch.dataset
in_dataset = indian.dataset
pc_dataset = pc.dataset
assert "train" in mc_ch_in_pc_dataset
assert len(mc_ch_in_pc_dataset["train"]) == len(mc_dataset["train"]) + len(
ch_dataset["train"]
) + len(in_dataset["train"]) + len(pc_dataset["train"])
assert "test" in mc_ch_in_pc_dataset
assert len(mc_ch_in_pc_dataset["test"]) == len(mc_dataset["test"]) + len(
ch_dataset["test"]
) + len(in_dataset["test"]) + len(pc_dataset["test"])
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Tests for the aggregated Montgomery-Shenzhen-Indian-PadChest dataset."""
import importlib
import pytest
@pytest.mark.parametrize(
"name,padchest_name",
[
("default", "tb_idiap"),
],
)
def test_split_consistency(name: str, padchest_name: str):
montgomery = importlib.import_module(
f".{name}", "ptbench.data.montgomery"
).datamodule
shenzhen = importlib.import_module(
f".{name}", "ptbench.data.shenzhen"
).datamodule
indian = importlib.import_module(
f".{name}", "ptbench.data.indian"
).datamodule
padchest = importlib.import_module(
f".{padchest_name}", "ptbench.data.padchest"
).datamodule
combined = importlib.import_module(
f".{name}", "ptbench.data.montgomery_shenzhen_indian_padchest"
).datamodule
MontgomeryLoader = importlib.import_module(
".datamodule", "ptbench.data.montgomery"
).RawDataLoader
ShenzhenLoader = importlib.import_module(
".datamodule", "ptbench.data.shenzhen"
).RawDataLoader
IndianLoader = importlib.import_module(
".datamodule", "ptbench.data.indian"
).RawDataLoader
PadChestLoader = importlib.import_module(
".datamodule", "ptbench.data.padchest"
).RawDataLoader
for split in ("train", "validation", "test"):
assert montgomery.splits[split][0][0] == combined.splits[split][0][0]
assert isinstance(montgomery.splits[split][0][1], MontgomeryLoader)
assert isinstance(combined.splits[split][0][1], MontgomeryLoader)
assert shenzhen.splits[split][0][0] == combined.splits[split][1][0]
assert isinstance(shenzhen.splits[split][0][1], ShenzhenLoader)
assert isinstance(combined.splits[split][1][1], ShenzhenLoader)
assert indian.splits[split][0][0] == combined.splits[split][2][0]
assert isinstance(indian.splits[split][0][1], IndianLoader)
assert isinstance(combined.splits[split][2][1], IndianLoader)
if split != "validation":
# padchest has no validation
assert padchest.splits[split][0][0] == combined.splits[split][3][0]
assert isinstance(padchest.splits[split][0][1], PadChestLoader)
assert isinstance(combined.splits[split][3][1], PadChestLoader)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment