# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Tests for Extended Indian dataset."""

from ptbench.data.indian_RS import dataset

# Split names every protocol is expected to expose, in the order checked.
_SPLITS = ("train", "validation", "test")


def _check_split(subset, name, size, key_prefix=None):
    """Assert that split ``name`` is in ``subset`` and holds ``size`` samples.

    If ``key_prefix`` is given, the key of every sample in the split must
    also start with it.
    """
    assert name in subset
    assert len(subset[name]) == size
    if key_prefix is not None:
        for s in subset[name]:
            assert s.key.startswith(key_prefix)


def _check_labels(subset):
    """Assert that every sample of every split has a label of 0.0 or 1.0."""
    for name in _SPLITS:
        for s in subset[name]:
            assert s.label in [0.0, 1.0]


def _check_protocol(protocol, sizes, key_prefix=None):
    """Load ``protocol`` and verify its splits, sizes, keys and labels.

    ``sizes`` maps each split name to its expected number of samples.
    ``key_prefix`` is forwarded to :func:`_check_split`.
    """
    subset = dataset.subsets(protocol)
    assert len(subset) == 3

    for name in _SPLITS:
        _check_split(subset, name, sizes[name], key_prefix=key_prefix)

    # Labels are only checked once all splits are known to be present
    _check_labels(subset)


def test_protocol_consistency():
    """Check the default protocol and the ten cross-validation folds."""
    # Default protocol
    _check_protocol("default", {"train": 83, "validation": 20, "test": 52})

    # Cross-validation: folds 0-4 and folds 5-9 expect different train and
    # test sizes (111/16 versus 112/15); validation is 28 in every fold.
    fold_sizes = (
        (range(5), {"train": 111, "validation": 28, "test": 16}),
        (range(5, 10), {"train": 112, "validation": 28, "test": 15}),
    )
    for folds, sizes in fold_sizes:
        for f in folds:
            _check_protocol(f"fold_{f}", sizes, key_prefix="DatasetA")


def test_loading():
    """Load the first training samples and check the layout of their data."""

    def _check_sample(s):
        """Assert that ``s.data`` is a dict holding ``data`` and ``label``."""
        data = s.data

        assert isinstance(data, dict)
        assert len(data) == 2

        assert "data" in data
        assert len(data["data"]) == 14  # Check radiological signs

        assert "label" in data
        assert data["label"] in [0, 1]  # Check labels

    limit = 30  # use this to limit testing to first images only, else None

    subset = dataset.subsets("default")
    for s in subset["train"][:limit]:
        _check_sample(s)