From 513bed9df5b13eed930d3d49592a26d0cdb54084 Mon Sep 17 00:00:00 2001 From: Andre Anjos <andre.dos.anjos@gmail.com> Date: Wed, 2 Aug 2023 21:02:33 +0200 Subject: [PATCH] [tests.test_nih_cxr14] Adapt to new datamodule syntax --- tests/test_nih_cxr14.py | 135 +++++++++++++------------ tests/{test_pc.py => test_padchest.py} | 0 2 files changed, 70 insertions(+), 65 deletions(-) rename tests/{test_pc.py => test_padchest.py} (100%) diff --git a/tests/test_nih_cxr14.py b/tests/test_nih_cxr14.py index d7cc149a..242ac7e9 100644 --- a/tests/test_nih_cxr14.py +++ b/tests/test_nih_cxr14.py @@ -1,72 +1,77 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Tests for NIH CXR14 dataset.""" +"""Tests for NIH CXR-14 dataset.""" -import pytest - - -@pytest.mark.skip(reason="Test need to be updated") -def test_protocol_consistency(): - from ptbench.data.nih_cxr14_re import dataset - - # Default protocol - subset = dataset.subsets("default") - assert len(subset) == 3 - - assert "train" in subset - assert len(subset["train"]) == 98637 - for s in subset["train"]: - assert s.key.startswith("images/000") - - assert "validation" in subset - assert len(subset["validation"]) == 6350 - for s in subset["validation"]: - assert s.key.startswith("images/000") - - assert "test" in subset - assert len(subset["test"]) == 4054 - for s in subset["test"]: - assert s.key.startswith("images/000") - - # Check labels - for s in subset["train"]: - for element in list(set(s.label)): - assert element in [0.0, 1.0] +import importlib - for s in subset["validation"]: - for element in list(set(s.label)): - assert element in [0.0, 1.0] - - for s in subset["test"]: - for element in list(set(s.label)): - assert element in [0.0, 1.0] - - -@pytest.mark.skip(reason="Test need to be updated") -@pytest.mark.skip_if_rc_var_not_set("datadir.nih_cxr14_re") -def test_loading(): - from ptbench.data.nih_cxr14_re import dataset - - def _check_size(size): - if size == (1024, 1024): - return True - return False - - def _check_sample(s): - data = s.data - assert isinstance(data, dict) - assert len(data) == 2 - - assert "data" in data - assert _check_size(data["data"].size) # Check size - assert data["data"].mode == "RGB" # Check colors - - assert "label" in data - assert len(data["label"]) == 14 # Check labels +import pytest - limit = 30 # use this to limit testing to first images only, else None - subset = dataset.subsets("default") - for s in subset["train"][:limit]: - _check_sample(s) +def id_function(val): + if isinstance(val, dict): + return str(val) + return repr(val) + + +@pytest.mark.parametrize( + "split,lenghts", + [ + ("default", dict(train=98637, validation=6350, test=4054)), + ], + ids=id_function, # just changes how pytest prints it +) +def test_protocol_consistency( + database_checkers, split: str, lenghts: dict[str, int] +): + from ptbench.data.nih_cxr14.datamodule import make_split + + database_checkers.check_split( + make_split(f"{split}.json.bz2"), + lengths=lenghts, + prefixes=("images/000",), + possible_labels=(0, 1), + ) + + +@pytest.mark.skip_if_rc_var_not_set("datadir.nih_cxr14") +@pytest.mark.parametrize( + "dataset", + [ + "train", + "validation", + "test", + ], +) +@pytest.mark.parametrize( + "name", + [ + "default", + ], +) +def test_loading(database_checkers, name: str, dataset: str): + datamodule = importlib.import_module( + f".{name}", "ptbench.data.nih_cxr14" + ).datamodule + + datamodule.model_transforms = [] # should be done before setup() + datamodule.setup("predict") # sets up all datasets + + loader = datamodule.predict_dataloader()[dataset] + + limit = 3 # limit load checking + for batch in loader: + if limit == 0: + break + database_checkers.check_loaded_batch( + batch, + batch_size=1, + color_planes=1, + prefixes=("images/000",), + possible_labels=(0, 1), + ) + limit -= 1 + + +# TODO: check size 1024x1024 +# TODO: check there are 14 binary labels (0, 1) diff --git a/tests/test_pc.py b/tests/test_padchest.py similarity index 100% rename from tests/test_pc.py rename to tests/test_padchest.py -- GitLab