From 1e1e2daccc79f51f3647055ef60c8a14cb47d743 Mon Sep 17 00:00:00 2001 From: Andre Anjos <andre.dos.anjos@gmail.com> Date: Sat, 5 Aug 2023 21:27:14 +0200 Subject: [PATCH] [tests] Minimal tests for the combined NIH-CXR14 and PadChest database remix --- tests/test_nih_cxr14_padchest.py | 55 +++++++++++++++++++++----------- 1 file changed, 36 insertions(+), 19 deletions(-) diff --git a/tests/test_nih_cxr14_padchest.py b/tests/test_nih_cxr14_padchest.py index 6199ad54..4a9379e3 100644 --- a/tests/test_nih_cxr14_padchest.py +++ b/tests/test_nih_cxr14_padchest.py @@ -1,30 +1,47 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Tests for the aggregated NIH CXR14-PadChest dataset.""" +"""Tests for the aggregated NIH+CXR14-PadChest dataset.""" + +import importlib import pytest -@pytest.mark.skip(reason="Test need to be updated") -def test_dataset_consistency(): - from ptbench.data.nih_cxr14 import default as nih - from ptbench.data.nih_cxr14_padchest import idiap as nih_pc - from ptbench.data.padchest import no_tb_idiap as pc +@pytest.mark.parametrize( + "name,padchest_name,combined_name", + [ + ("default", "no_tb_idiap", "idiap"), + ], +) +def test_split_consistency(name: str, padchest_name: str, combined_name: str): + nih_cxr14 = importlib.import_module( + f".{name}", "ptbench.data.nih_cxr14" + ).datamodule + + padchest = importlib.import_module( + f".{padchest_name}", "ptbench.data.padchest" + ).datamodule + + combined = importlib.import_module( + f".{combined_name}", "ptbench.data.nih_cxr14_padchest" + ).datamodule - # Default protocol - nih_pc_dataset = nih_pc.dataset - assert isinstance(nih_pc_dataset, dict) + CXR14Loader = importlib.import_module( + ".datamodule", "ptbench.data.nih_cxr14" + ).RawDataLoader - nih_dataset = nih.dataset - pc_dataset = pc.dataset + PadChestLoader = importlib.import_module( + ".datamodule", "ptbench.data.padchest" + ).RawDataLoader - assert "train" in nih_pc_dataset - assert len(nih_pc_dataset["train"]) == len(nih_dataset["train"]) + len( - pc_dataset["train"] - ) + for split in ("train", "validation", "test"): + assert nih_cxr14.splits[split][0][0] == combined.splits[split][0][0] + assert isinstance(nih_cxr14.splits[split][0][1], CXR14Loader) + assert isinstance(combined.splits[split][0][1], CXR14Loader) - assert "validation" in nih_pc_dataset - assert len(nih_pc_dataset["validation"]) == len( - nih_dataset["validation"] - ) + len(pc_dataset["validation"]) + if split != "test": + # padchest has no test + assert padchest.splits[split][0][0] == combined.splits[split][1][0] + assert isinstance(padchest.splits[split][0][1], PadChestLoader) + assert isinstance(combined.splits[split][1][1], PadChestLoader) -- GitLab