From b280ec3aa56c2e05f00d1675625a1dd0b6d4e138 Mon Sep 17 00:00:00 2001 From: Andre Anjos <andre.dos.anjos@gmail.com> Date: Fri, 28 Jul 2023 08:10:47 +0200 Subject: [PATCH] [tests] Make naming more explicit --- tests/{test_in.py => test_indian.py} | 0 tests/test_mc.py | 192 ------------------------- tests/{test_ch.py => test_shenzhen.py} | 0 3 files changed, 192 deletions(-) rename tests/{test_in.py => test_indian.py} (100%) delete mode 100644 tests/test_mc.py rename tests/{test_ch.py => test_shenzhen.py} (100%) diff --git a/tests/test_in.py b/tests/test_indian.py similarity index 100% rename from tests/test_in.py rename to tests/test_indian.py diff --git a/tests/test_mc.py b/tests/test_mc.py deleted file mode 100644 index 4d1a5a9e..00000000 --- a/tests/test_mc.py +++ /dev/null @@ -1,192 +0,0 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# -# SPDX-License-Identifier: GPL-3.0-or-later - -"""Tests for Montgomery dataset.""" - -import importlib - -import pytest - - -@pytest.mark.skip(reason="Test need to be updated") -def test_protocol_consistency(): - # Default protocol - datamodule = importlib.import_module( - "ptbench.data.montgomery.datamodules.default" - ).datamodule - - subset = datamodule.splits - - assert len(subset) == 3 - - assert "train" in subset - assert len(subset["train"]) == 88 - for s in subset["train"]: - assert s[0].startswith("CXR_png/MCUCXR_0") - - assert "validation" in subset - assert len(subset["validation"]) == 22 - for s in subset["validation"]: - assert s[0].startswith("CXR_png/MCUCXR_0") - - assert "test" in subset - assert len(subset["test"]) == 28 - for s in subset["test"]: - assert s[0].startswith("CXR_png/MCUCXR_0") - - # Check labels - for s in subset["train"]: - assert s[1] in [0.0, 1.0] - - for s in subset["validation"]: - assert s[1] in [0.0, 1.0] - - for s in subset["test"]: - assert s[1] in [0.0, 1.0] - - # Cross-validation fold 0-7 - for f in range(8): - datamodule = importlib.import_module( - f"ptbench.data.montgomery.datamodules.fold_{str(f)}" - ).datamodule - subset = datamodule.database_split - - assert len(subset) == 3 - - assert "train" in subset - assert len(subset["train"]) == 99 - for s in subset["train"]: - assert s[0].startswith("CXR_png/MCUCXR_0") - - assert "validation" in subset - assert len(subset["validation"]) == 25 - for s in subset["validation"]: - assert s[0].startswith("CXR_png/MCUCXR_0") - - assert "test" in subset - assert len(subset["test"]) == 14 - for s in subset["test"]: - assert s[0].startswith("CXR_png/MCUCXR_0") - - # Check labels - for s in subset["train"]: - assert s[1] in [0.0, 1.0] - - for s in subset["validation"]: - assert s[1] in [0.0, 1.0] - - for s in subset["test"]: - assert s[1] in [0.0, 1.0] - - # Cross-validation fold 8-9 - for f in range(8, 10): - datamodule = importlib.import_module( - f"ptbench.data.montgomery.datamodules.fold_{str(f)}" - ).datamodule - subset = datamodule.database_split - - assert len(subset) == 3 - - assert "train" in subset - assert len(subset["train"]) == 100 - for s in subset["train"]: - assert s[0].startswith("CXR_png/MCUCXR_0") - - assert "validation" in subset - assert len(subset["validation"]) == 25 - for s in subset["validation"]: - assert s[0].startswith("CXR_png/MCUCXR_0") - - assert "test" in subset - assert len(subset["test"]) == 13 - for s in subset["test"]: - assert s[0].startswith("CXR_png/MCUCXR_0") - - # Check labels - for s in subset["train"]: - assert s[1] in [0.0, 1.0] - - for s in subset["validation"]: - assert s[1] in [0.0, 1.0] - - for s in subset["test"]: - assert s[1] in [0.0, 1.0] - - -@pytest.mark.skip(reason="Test need to be updated") -@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") -def test_loading(): - import torch - import torchvision.transforms - - from ptbench.data.datamodule import _DelayedLoadingDataset - - def _check_sample(s): - assert len(s) == 2 - - data = s[0] - metadata = s[1] - - assert isinstance(data, torch.Tensor) - - assert data.size(0) == 1 # check single channel - assert data.size(1) == data.size(2) # check square image - - assert ( - torchvision.transforms.ToPILImage()(data).mode == "L" - ) # Check colors - - assert "label" in metadata - assert metadata["label"] in [0, 1] # Check labels - - limit = 30 # use this to limit testing to first images only, else None - - datamodule = importlib.import_module( - "ptbench.data.montgomery.datamodules.default" - ).datamodule - subset = datamodule.database_split - raw_data_loader = datamodule.raw_data_loader - - # Need to use private function so we can limit the number of samples to use - dataset = _DelayedLoadingDataset(subset["train"][:limit], raw_data_loader) - - for s in dataset: - _check_sample(s) - - -@pytest.mark.skip(reason="Test need to be updated") -@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") -def test_check(): - from ptbench.data.split import check_database_split_loading - - limit = 30 # use this to limit testing to first images only, else 0 - - # Default protocol - datamodule = importlib.import_module( - "ptbench.data.montgomery.datamodules.default" - ).datamodule - database_split = datamodule.database_split - raw_data_loader = datamodule.raw_data_loader - - assert ( - check_database_split_loading( - database_split, raw_data_loader, limit=limit - ) - == 0 - ) - - # Folds - for f in range(10): - datamodule = importlib.import_module( - f"ptbench.data.montgomery.datamodules.fold_{f}" - ).datamodule - database_split = datamodule.database_split - raw_data_loader = datamodule.raw_data_loader - - assert ( - check_database_split_loading( - database_split, raw_data_loader, limit=limit - ) - == 0 - ) diff --git a/tests/test_ch.py b/tests/test_shenzhen.py similarity index 100% rename from tests/test_ch.py rename to tests/test_shenzhen.py -- GitLab