# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Tests for Extended Montgomery dataset."""

from ptbench.data.montgomery_RS import dataset

# Prefix every sample key of this dataset is expected to start with.
_KEY_PREFIX = "CXR_png/MCUCXR_0"


def _check_subset(subset, sizes):
    """Check the splits, sizes, sample keys and labels of one protocol.

    Parameters
    ----------
    subset : dict
        Mapping from split name to a sequence of samples, as returned by
        ``dataset.subsets(protocol)``.
    sizes : dict
        Expected number of samples for each split name.  The protocol must
        contain exactly these splits and no others.
    """
    assert len(subset) == len(sizes)
    for name, size in sizes.items():
        assert name in subset
        assert len(subset[name]) == size
        for s in subset[name]:
            # Every protocol draws from the same 138 images, whose keys all
            # share this prefix, so the check applies to every split.
            assert s.key.startswith(_KEY_PREFIX)
            # Labels are binary.
            assert s.label in [0.0, 1.0]


def test_protocol_consistency():
    """Check split sizes, sample keys and labels of every protocol."""
    # Default protocol
    _check_subset(
        dataset.subsets("default"),
        {"train": 88, "validation": 22, "test": 28},
    )

    # Cross-validation folds 0-7 have one more training sample and one fewer
    # test sample than folds 8-9 (138 images in total either way).
    for f in range(8):
        _check_subset(
            dataset.subsets(f"fold_{f}"),
            {"train": 99, "validation": 25, "test": 14},
        )

    for f in range(8, 10):
        _check_subset(
            dataset.subsets(f"fold_{f}"),
            {"train": 100, "validation": 25, "test": 13},
        )


def test_loading():
    """Check the structure of loaded samples from the default training split."""

    def _check_sample(s):
        data = s.data
        assert isinstance(data, dict)
        assert len(data) == 2

        # Check radiological signs: 14 values are expected per sample.
        assert "data" in data
        assert len(data["data"]) == 14

        # Check labels
        assert "label" in data
        assert data["label"] in [0, 1]

    limit = 30  # use this to limit testing to first images only, else None
    subset = dataset.subsets("default")
    for s in subset["train"][:limit]:
        _check_sample(s)