diff --git a/tests/conftest.py b/tests/conftest.py index 837a34f40eb6595e341fe355eed28107d676f0f8..75b802a252994df5e2b5793bcd2c3422f5b9ad15 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,10 +5,14 @@ import os import pathlib import tempfile +import typing import zipfile import pytest import tomli_w +import torch + +from ptbench.data.typing import DatabaseSplit @pytest.fixture @@ -113,3 +117,105 @@ def pytest_sessionstart(session: pytest.Session) -> None: # stash the newly created temporary directory so we can erase it when the key = pytest.StashKey[tempfile.TemporaryDirectory]() session.stash[key] = montgomery_tempdir + + +class DatabaseCheckers: + """Helpers for database tests.""" + + @staticmethod + def check_split( + split: DatabaseSplit, + lengths: dict[str, int], + prefixes: typing.Sequence[str], + possible_labels: typing.Sequence[int], + ): + """Runs a simple consistency check on the data split. + + Parameters + ---------- + + split + The loaded database split to be checked, typically the return + value of a database-specific ``make_split`` function. + + lengths + A dictionary whose keys must match those of the split (this will + be checked), and whose values correspond to the sizes of each of + the datasets in the split. + + prefixes + Each file named in a split should start with at least one of these + prefixes. + + possible_labels + The list of possible labels contained in any split. + """ + + assert len(split) == len(lengths) + + for k in lengths.keys(): + # dataset must have been declared + assert k in split + + assert len(split[k]) == lengths[k] + for s in split[k]: + assert any([s[0].startswith(p) for p in prefixes]) + assert s[1] in possible_labels + + @staticmethod + def check_loaded_batch( + batch, + batch_size: int, + color_planes: int, + prefixes: typing.Sequence[str], + possible_labels: typing.Sequence[int], + ): + """Checks the consistency of an individual (loaded) batch. + + Parameters + ---------- + + batch + The loaded batch to be checked. + + batch_size + The expected size of the mini-batch + + color_planes + The expected number of color planes in each image of the batch + + prefixes + Each file named in a split should start with at least one of these + prefixes. + + possible_labels + The list of possible labels contained in any split.
+ """ + + assert len(batch) == 2  # data, metadata + + assert isinstance(batch[0], torch.Tensor) + assert batch[0].shape[0] == batch_size  # mini-batch size + assert batch[0].shape[1] == color_planes  # color planes + assert batch[0].shape[2] == batch[0].shape[3]  # image is square + + assert isinstance(batch[1], dict)  # metadata + assert len(batch[1]) == 2  # label and name + + assert "label" in batch[1] + assert all([k in possible_labels for k in batch[1]["label"]]) + + assert "name" in batch[1] + assert all( + [any([k.startswith(j) for j in prefixes]) for k in batch[1]["name"]] + ) + + # use the code below to view generated images + # from torchvision.transforms.functional import to_pil_image + # to_pil_image(batch[0][0]).show() + # __import__("pdb").set_trace() + + +@pytest.fixture +def database_checkers(): + return DatabaseCheckers diff --git a/tests/test_11k.py b/tests/test_11k.py deleted file mode 100644 index b08a249ca592334fef3816171fbcc11aa60ad19a..0000000000000000000000000000000000000000 --- a/tests/test_11k.py +++ /dev/null @@ -1,213 +0,0 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# -# SPDX-License-Identifier: GPL-3.0-or-later -"""Tests for TBX11K simplified dataset split 1.""" - -import pytest - - -@pytest.mark.skip(reason="Test need to be updated") -def test_protocol_consistency(): - from ptbench.data.tbx11k_simplified import dataset - - # Default protocol - subset = dataset.subsets("default") - assert len(subset) == 3 - - assert "train" in subset - assert len(subset["train"]) == 2767 - for s in subset["train"]: - assert s.key.startswith("images/") - - assert "validation" in subset - assert len(subset["validation"]) == 706 - for s in subset["validation"]: - assert s.key.startswith("images/") - - assert "test" in subset - assert len(subset["test"]) == 957 - for s in subset["test"]: - assert s.key.startswith("images/") - - # Check labels - for s in subset["train"]: - assert s.label in [0.0, 1.0] - - for s in subset["validation"]: - assert s.label in [0.0, 1.0] - - for s in subset["test"]: - assert s.label in [0.0, 1.0] - - # Cross-validation fold 0-9 - for f in range(10): - subset = dataset.subsets("fold_" + str(f)) - assert len(subset) == 3 - - assert "train" in subset - assert len(subset["train"]) == 3177 - for s in subset["train"]: - assert s.key.startswith("images/") - - assert "validation" in subset - assert len(subset["validation"]) == 810 - for s in subset["validation"]: - assert s.key.startswith("images/") - - assert "test" in subset - assert len(subset["test"]) == 443 - for s in subset["test"]: - assert s.key.startswith("images/") - - # Check labels - for s in subset["train"]: - assert s.label in [0.0, 1.0] - - for s in subset["validation"]: - assert s.label in [0.0, 1.0] - - for s in subset["test"]: - assert s.label in [0.0, 1.0] - - -@pytest.mark.skip(reason="Test need to be updated") -def test_protocol_consistency_bbox(): - from ptbench.data.tbx11k_simplified import dataset_with_bboxes - - # Default protocol - subset = dataset_with_bboxes.subsets("default") - assert len(subset) == 3 - - assert "train" in subset - assert len(subset["train"]) == 2767 - for s in subset["train"]: - assert s.key.startswith("images/") - - assert "validation" in subset - assert len(subset["validation"]) == 706 - for s in subset["validation"]: - assert s.key.startswith("images/") - - assert "test" in subset - assert len(subset["test"]) == 957 - for s in subset["test"]: - assert s.key.startswith("images/") - - # Check labels - for s in
subset["train"]: - assert s.label in [0.0, 1.0] - - for s in subset["validation"]: - assert s.label in [0.0, 1.0] - - for s in subset["test"]: - assert s.label in [0.0, 1.0] - - # Check bounding boxes - for s in subset["train"]: - assert s.bboxes == "none" or s.bboxes[0].startswith("{'xmin':") - - # Cross-validation fold 0-9 - for f in range(10): - subset = dataset_with_bboxes.subsets("fold_" + str(f)) - assert len(subset) == 3 - - assert "train" in subset - assert len(subset["train"]) == 3177 - for s in subset["train"]: - assert s.key.startswith("images/") - - assert "validation" in subset - assert len(subset["validation"]) == 810 - for s in subset["validation"]: - assert s.key.startswith("images/") - - assert "test" in subset - assert len(subset["test"]) == 443 - for s in subset["test"]: - assert s.key.startswith("images/") - - # Check labels - for s in subset["train"]: - assert s.label in [0.0, 1.0] - - for s in subset["validation"]: - assert s.label in [0.0, 1.0] - - for s in subset["test"]: - assert s.label in [0.0, 1.0] - - # Check bounding boxes - for s in subset["train"]: - assert s.bboxes == "none" or s.bboxes[0].startswith("{'xmin':") - - -@pytest.mark.skip(reason="Test need to be updated") -@pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k_simplified") -def test_loading(): - from ptbench.data.tbx11k_simplified import dataset - - def _check_sample(s): - data = s.data - assert isinstance(data, dict) - assert len(data) == 2 - - assert "data" in data - assert data["data"].size == (512, 512) - - assert data["data"].mode == "L" # Check colors - - assert "label" in data - assert data["label"] in [0, 1] # Check labels - - limit = 30 # use this to limit testing to first images only, else None - - subset = dataset.subsets("default") - for s in subset["train"][:limit]: - _check_sample(s) - - -@pytest.mark.skip(reason="Test need to be updated") -@pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k_simplified") -def test_loading_bbox(): - from ptbench.data.tbx11k_simplified import dataset_with_bboxes - - def _check_sample(s): - data = s.data - assert isinstance(data, dict) - assert len(data) == 3 - - assert "data" in data - assert data["data"].size == (512, 512) - - assert data["data"].mode == "L" # Check colors - - assert "label" in data - assert data["label"] in [0, 1] # Check labels - - assert "bboxes" in data - assert data["bboxes"] == "none" or data["bboxes"][0].startswith( - "{'xmin':" - ) - - limit = 30 # use this to limit testing to first images only, else None - - subset = dataset_with_bboxes.subsets("default") - for s in subset["train"][:limit]: - _check_sample(s) - - -@pytest.mark.skip(reason="Test need to be updated") -@pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k_simplified") -def test_check(): - from ptbench.data.tbx11k_simplified import dataset - - assert dataset.check() == 0 - - -@pytest.mark.skip(reason="Test need to be updated") -@pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k_simplified") -def test_check_bbox(): - from ptbench.data.tbx11k_simplified import dataset_with_bboxes - - assert dataset_with_bboxes.check() == 0 diff --git a/tests/test_11k_RS.py b/tests/test_11k_RS.py deleted file mode 100644 index 1f9d975d95e50c1166dff3d5f8fad939948956bd..0000000000000000000000000000000000000000 --- a/tests/test_11k_RS.py +++ /dev/null @@ -1,89 +0,0 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# -# SPDX-License-Identifier: GPL-3.0-or-later -"""Tests for Extended TBX11K simplified dataset split 1.""" - -import pytest - - 
-@pytest.mark.skip(reason="Test need to be updated") -def test_protocol_consistency(): - from ptbench.data.tbx11k_simplified_RS import dataset - - # Default protocol - subset = dataset.subsets("default") - assert len(subset) == 3 - - assert "train" in subset - assert len(subset["train"]) == 2767 - - assert "validation" in subset - assert len(subset["validation"]) == 706 - - assert "test" in subset - assert len(subset["test"]) == 957 - for s in subset["test"]: - assert s.key.startswith("images/") - - # Check labels - for s in subset["train"]: - assert s.label in [0.0, 1.0] - - for s in subset["validation"]: - assert s.label in [0.0, 1.0] - - for s in subset["test"]: - assert s.label in [0.0, 1.0] - - # Cross-validation fold 0-9 - for f in range(10): - subset = dataset.subsets("fold_" + str(f)) - assert len(subset) == 3 - - assert "train" in subset - assert len(subset["train"]) == 3177 - for s in subset["train"]: - assert s.key.startswith("images/") - - assert "validation" in subset - assert len(subset["validation"]) == 810 - for s in subset["validation"]: - assert s.key.startswith("images/") - - assert "test" in subset - assert len(subset["test"]) == 443 - for s in subset["test"]: - assert s.key.startswith("images/") - - # Check labels - for s in subset["train"]: - assert s.label in [0.0, 1.0] - - for s in subset["validation"]: - assert s.label in [0.0, 1.0] - - for s in subset["test"]: - assert s.label in [0.0, 1.0] - - -@pytest.mark.skip(reason="Test need to be updated") -@pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k_simplified") -def test_loading(): - from ptbench.data.tbx11k_simplified_RS import dataset - - def _check_sample(s): - data = s.data - assert isinstance(data, dict) - assert len(data) == 2 - - assert "data" in data - assert len(data["data"]) == 14 # Check radiological signs - - assert "label" in data - assert data["label"] in [0, 1] # Check labels - - limit = 30 # use this to limit testing to first images only, else None - - subset = dataset.subsets("default") - for s in subset["train"][:limit]: - _check_sample(s) diff --git a/tests/test_11k_v2.py b/tests/test_11k_v2.py deleted file mode 100644 index 9feab3ce74ffe53d414771908ce3ba63a8db74de..0000000000000000000000000000000000000000 --- a/tests/test_11k_v2.py +++ /dev/null @@ -1,275 +0,0 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# -# SPDX-License-Identifier: GPL-3.0-or-later -"""Tests for TBX11K simplified dataset split 2.""" - -import pytest - - -@pytest.mark.skip(reason="Test need to be updated") -def test_protocol_consistency(): - from ptbench.data.tbx11k_simplified_v2 import dataset - - # Default protocol - subset = dataset.subsets("default") - assert len(subset) == 3 - - assert "train" in subset - assert len(subset["train"]) == 5241 - for s in subset["train"]: - assert s.key.startswith("images/") - - assert "validation" in subset - assert len(subset["validation"]) == 1335 - for s in subset["validation"]: - assert s.key.startswith("images/") - - assert "test" in subset - assert len(subset["test"]) == 1793 - for s in subset["test"]: - assert s.key.startswith("images/") - - # Check labels - for s in subset["train"]: - assert s.label in [0.0, 1.0] - - for s in subset["validation"]: - assert s.label in [0.0, 1.0] - - for s in subset["test"]: - assert s.label in [0.0, 1.0] - - # Cross-validation fold 0-8 - for f in range(9): - subset = dataset.subsets("fold_" + str(f)) - assert len(subset) == 3 - - assert "train" in subset - assert len(subset["train"]) == 6003 - for s in 
subset["train"]: - assert s.key.startswith("images/") - - assert "validation" in subset - assert len(subset["validation"]) == 1529 - for s in subset["validation"]: - assert s.key.startswith("images/") - - assert "test" in subset - assert len(subset["test"]) == 837 - for s in subset["test"]: - assert s.key.startswith("images/") - - # Check labels - for s in subset["train"]: - assert s.label in [0.0, 1.0] - - for s in subset["validation"]: - assert s.label in [0.0, 1.0] - - for s in subset["test"]: - assert s.label in [0.0, 1.0] - - # Cross-validation fold 9 - subset = dataset.subsets("fold_9") - assert len(subset) == 3 - - assert "train" in subset - assert len(subset["train"]) == 6003 - for s in subset["train"]: - assert s.key.startswith("images/") - - assert "validation" in subset - assert len(subset["validation"]) == 1530 - for s in subset["validation"]: - assert s.key.startswith("images/") - - assert "test" in subset - assert len(subset["test"]) == 836 - for s in subset["test"]: - assert s.key.startswith("images/") - - # Check labels - for s in subset["train"]: - assert s.label in [0.0, 1.0] - - for s in subset["validation"]: - assert s.label in [0.0, 1.0] - - for s in subset["test"]: - assert s.label in [0.0, 1.0] - - -@pytest.mark.skip(reason="Test need to be updated") -def test_protocol_consistency_bbox(): - from ptbench.data.tbx11k_simplified_v2 import dataset_with_bboxes - - # Default protocol - subset = dataset_with_bboxes.subsets("default") - assert len(subset) == 3 - - assert "train" in subset - assert len(subset["train"]) == 5241 - for s in subset["train"]: - assert s.key.startswith("images/") - - assert "validation" in subset - assert len(subset["validation"]) == 1335 - for s in subset["validation"]: - assert s.key.startswith("images/") - - assert "test" in subset - assert len(subset["test"]) == 1793 - for s in subset["test"]: - assert s.key.startswith("images/") - - # Check labels - for s in subset["train"]: - assert s.label in [0.0, 1.0] - - for s in subset["validation"]: - assert s.label in [0.0, 1.0] - - for s in subset["test"]: - assert s.label in [0.0, 1.0] - - # Check bounding boxes - for s in subset["train"]: - assert s.bboxes == "none" or s.bboxes[0].startswith("{'xmin':") - - # Cross-validation fold 0-8 - for f in range(9): - subset = dataset_with_bboxes.subsets("fold_" + str(f)) - assert len(subset) == 3 - - assert "train" in subset - assert len(subset["train"]) == 6003 - for s in subset["train"]: - assert s.key.startswith("images/") - - assert "validation" in subset - assert len(subset["validation"]) == 1529 - for s in subset["validation"]: - assert s.key.startswith("images/") - - assert "test" in subset - assert len(subset["test"]) == 837 - for s in subset["test"]: - assert s.key.startswith("images/") - - # Check labels - for s in subset["train"]: - assert s.label in [0.0, 1.0] - - for s in subset["validation"]: - assert s.label in [0.0, 1.0] - - for s in subset["test"]: - assert s.label in [0.0, 1.0] - - # Check bounding boxes - for s in subset["train"]: - assert s.bboxes == "none" or s.bboxes[0].startswith("{'xmin':") - - # Cross-validation fold 9 - subset = dataset_with_bboxes.subsets("fold_9") - assert len(subset) == 3 - - assert "train" in subset - assert len(subset["train"]) == 6003 - for s in subset["train"]: - assert s.key.startswith("images/") - - assert "validation" in subset - assert len(subset["validation"]) == 1530 - for s in subset["validation"]: - assert s.key.startswith("images/") - - assert "test" in subset - assert len(subset["test"]) == 836 - for 
s in subset["test"]: - assert s.key.startswith("images/") - - # Check labels - for s in subset["train"]: - assert s.label in [0.0, 1.0] - - for s in subset["validation"]: - assert s.label in [0.0, 1.0] - - for s in subset["test"]: - assert s.label in [0.0, 1.0] - - # Check bounding boxes - for s in subset["train"]: - assert s.bboxes == "none" or s.bboxes[0].startswith("{'xmin':") - - -@pytest.mark.skip(reason="Test need to be updated") -@pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k_simplified_v2") -def test_loading(): - from ptbench.data.tbx11k_simplified_v2 import dataset - - def _check_sample(s): - data = s.data - assert isinstance(data, dict) - assert len(data) == 2 - - assert "data" in data - assert data["data"].size == (512, 512) - - assert data["data"].mode == "L" # Check colors - - assert "label" in data - assert data["label"] in [0, 1] # Check labels - - limit = 30 # use this to limit testing to first images only, else None - - subset = dataset.subsets("default") - for s in subset["train"][:limit]: - _check_sample(s) - - -@pytest.mark.skip(reason="Test need to be updated") -@pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k_simplified_v2") -def test_loading_bbox(): - from ptbench.data.tbx11k_simplified_v2 import dataset_with_bboxes - - def _check_sample(s): - data = s.data - assert isinstance(data, dict) - assert len(data) == 3 - - assert "data" in data - assert data["data"].size == (512, 512) - - assert data["data"].mode == "L" # Check colors - - assert "label" in data - assert data["label"] in [0, 1] # Check labels - - assert "bboxes" in data - assert data["bboxes"] == "none" or data["bboxes"][0].startswith( - "{'xmin':" - ) - - limit = 30 # use this to limit testing to first images only, else None - - subset = dataset_with_bboxes.subsets("default") - for s in subset["train"][:limit]: - _check_sample(s) - - -@pytest.mark.skip(reason="Test need to be updated") -@pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k_simplified_v2") -def test_check(): - from ptbench.data.tbx11k_simplified_v2 import dataset - - assert dataset.check() == 0 - - -@pytest.mark.skip(reason="Test need to be updated") -@pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k_simplified_v2") -def test_check_bbox(): - from ptbench.data.tbx11k_simplified_v2 import dataset_with_bboxes - - assert dataset_with_bboxes.check() == 0 diff --git a/tests/test_11k_v2_RS.py b/tests/test_11k_v2_RS.py deleted file mode 100644 index 590d2872ac915ef9f61eaea91d45dd2e9d911f4c..0000000000000000000000000000000000000000 --- a/tests/test_11k_v2_RS.py +++ /dev/null @@ -1,118 +0,0 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# -# SPDX-License-Identifier: GPL-3.0-or-later -"""Tests for Extended TBX11K simplified dataset split 2.""" - -import pytest - - -@pytest.mark.skip(reason="Test need to be updated") -def test_protocol_consistency(): - from ptbench.data.tbx11k_simplified_v2_RS import dataset - - # Default protocol - subset = dataset.subsets("default") - assert len(subset) == 3 - - assert "train" in subset - assert len(subset["train"]) == 5241 - - assert "validation" in subset - assert len(subset["validation"]) == 1335 - - assert "test" in subset - assert len(subset["test"]) == 1793 - for s in subset["test"]: - assert s.key.startswith("images/") - - # Check labels - for s in subset["train"]: - assert s.label in [0.0, 1.0] - - for s in subset["validation"]: - assert s.label in [0.0, 1.0] - - for s in subset["test"]: - assert s.label in [0.0, 1.0] - - # Cross-validation fold 0-8 - for f 
in range(9): - subset = dataset.subsets("fold_" + str(f)) - assert len(subset) == 3 - - assert "train" in subset - assert len(subset["train"]) == 6003 - for s in subset["train"]: - assert s.key.startswith("images/") - - assert "validation" in subset - assert len(subset["validation"]) == 1529 - for s in subset["validation"]: - assert s.key.startswith("images/") - - assert "test" in subset - assert len(subset["test"]) == 837 - for s in subset["test"]: - assert s.key.startswith("images/") - - # Check labels - for s in subset["train"]: - assert s.label in [0.0, 1.0] - - for s in subset["validation"]: - assert s.label in [0.0, 1.0] - - for s in subset["test"]: - assert s.label in [0.0, 1.0] - - # Cross-validation fold 9 - subset = dataset.subsets("fold_9") - assert len(subset) == 3 - - assert "train" in subset - assert len(subset["train"]) == 6003 - for s in subset["train"]: - assert s.key.startswith("images/") - - assert "validation" in subset - assert len(subset["validation"]) == 1530 - for s in subset["validation"]: - assert s.key.startswith("images/") - - assert "test" in subset - assert len(subset["test"]) == 836 - for s in subset["test"]: - assert s.key.startswith("images/") - - # Check labels - for s in subset["train"]: - assert s.label in [0.0, 1.0] - - for s in subset["validation"]: - assert s.label in [0.0, 1.0] - - for s in subset["test"]: - assert s.label in [0.0, 1.0] - - -@pytest.mark.skip(reason="Test need to be updated") -@pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k_simplified") -def test_loading(): - from ptbench.data.tbx11k_simplified_v2_RS import dataset - - def _check_sample(s): - data = s.data - assert isinstance(data, dict) - assert len(data) == 2 - - assert "data" in data - assert len(data["data"]) == 14 # Check radiological signs - - assert "label" in data - assert data["label"] in [0, 1] # Check labels - - limit = 30 # use this to limit testing to first images only, else None - - subset = dataset.subsets("default") - for s in subset["train"][:limit]: - _check_sample(s) diff --git a/tests/test_ch_RS.py b/tests/test_ch_RS.py deleted file mode 100644 index 0f1fa60252a21f5658729d27dcb1cf82e8c81fe3..0000000000000000000000000000000000000000 --- a/tests/test_ch_RS.py +++ /dev/null @@ -1,119 +0,0 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# -# SPDX-License-Identifier: GPL-3.0-or-later -"""Tests for Extended Shenzhen dataset.""" - -import pytest - - -@pytest.mark.skip(reason="Test need to be updated") -def test_protocol_consistency(): - from ptbench.data.shenzhen_RS import dataset - - # Default protocol - subset = dataset.subsets("default") - assert len(subset) == 3 - - assert "train" in subset - assert len(subset["train"]) == 422 - - assert "validation" in subset - assert len(subset["validation"]) == 107 - - assert "test" in subset - assert len(subset["test"]) == 133 - for s in subset["test"]: - assert s.key.startswith("CXR_png/CHNCXR_0") - - # Check labels - for s in subset["train"]: - assert s.label in [0.0, 1.0] - - for s in subset["validation"]: - assert s.label in [0.0, 1.0] - - for s in subset["test"]: - assert s.label in [0.0, 1.0] - - # Cross-validation folds 0-1 - for f in range(2): - subset = dataset.subsets("fold_" + str(f)) - assert len(subset) == 3 - - assert "train" in subset - assert len(subset["train"]) == 476 - for s in subset["train"]: - assert s.key.startswith("CXR_png/CHNCXR_0") - - assert "validation" in subset - assert len(subset["validation"]) == 119 - for s in subset["validation"]: - assert 
s.key.startswith("CXR_png/CHNCXR_0") - - assert "test" in subset - assert len(subset["test"]) == 67 - for s in subset["test"]: - assert s.key.startswith("CXR_png/CHNCXR_0") - - # Check labels - for s in subset["train"]: - assert s.label in [0.0, 1.0] - - for s in subset["validation"]: - assert s.label in [0.0, 1.0] - - for s in subset["test"]: - assert s.label in [0.0, 1.0] - - # Cross-validation folds 2-9 - for f in range(2, 10): - subset = dataset.subsets("fold_" + str(f)) - assert len(subset) == 3 - - assert "train" in subset - assert len(subset["train"]) == 476 - for s in subset["train"]: - assert s.key.startswith("CXR_png/CHNCXR_0") - - assert "validation" in subset - assert len(subset["validation"]) == 120 - for s in subset["validation"]: - assert s.key.startswith("CXR_png/CHNCXR_0") - - assert "test" in subset - assert len(subset["test"]) == 66 - for s in subset["test"]: - assert s.key.startswith("CXR_png/CHNCXR_0") - - # Check labels - for s in subset["train"]: - assert s.label in [0.0, 1.0] - - for s in subset["validation"]: - assert s.label in [0.0, 1.0] - - for s in subset["test"]: - assert s.label in [0.0, 1.0] - - -@pytest.mark.skip(reason="Test need to be updated") -def test_loading(): - from ptbench.data.shenzhen_RS import dataset - - def _check_sample(s): - data = s.data - - assert isinstance(data, dict) - assert len(data) == 2 - - assert "data" in data - assert len(data["data"]) == 14 # Check radiological signs - - assert "label" in data - assert data["label"] in [0, 1] # Check labels - - limit = 30 # use this to limit testing to first images only, else None - - subset = dataset.subsets("default") - for s in subset["train"][:limit]: - _check_sample(s) diff --git a/tests/test_hivtb_RS.py b/tests/test_hivtb_RS.py deleted file mode 100644 index 9a440598c9316e3fc6bab487f2f2d6123c1b666c..0000000000000000000000000000000000000000 --- a/tests/test_hivtb_RS.py +++ /dev/null @@ -1,94 +0,0 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# -# SPDX-License-Identifier: GPL-3.0-or-later -"""Tests for HIV-TB_RS dataset.""" - -import pytest - - -@pytest.mark.skip(reason="Test need to be updated") -def test_protocol_consistency(): - from ptbench.data.hivtb_RS import dataset - - # Cross-validation fold 0-2 - for f in range(3): - subset = dataset.subsets("fold_" + str(f)) - assert len(subset) == 3 - - assert "train" in subset - assert len(subset["train"]) == 174 - for s in subset["train"]: - assert s.key.startswith("HIV-TB_Algorithm_study_X-rays/") - - assert "validation" in subset - assert len(subset["validation"]) == 44 - for s in subset["validation"]: - assert s.key.startswith("HIV-TB_Algorithm_study_X-rays/") - - assert "test" in subset - assert len(subset["test"]) == 25 - for s in subset["test"]: - assert s.key.startswith("HIV-TB_Algorithm_study_X-rays/") - - # Check labels - for s in subset["train"]: - assert s.label in [0.0, 1.0] - - for s in subset["validation"]: - assert s.label in [0.0, 1.0] - - for s in subset["test"]: - assert s.label in [0.0, 1.0] - - # Cross-validation fold 3-9 - for f in range(3, 10): - subset = dataset.subsets("fold_" + str(f)) - assert len(subset) == 3 - - assert "train" in subset - assert len(subset["train"]) == 175 - for s in subset["train"]: - assert s.key.startswith("HIV-TB_Algorithm_study_X-rays/") - - assert "validation" in subset - assert len(subset["validation"]) == 44 - for s in subset["validation"]: - assert s.key.startswith("HIV-TB_Algorithm_study_X-rays/") - - assert "test" in subset - assert 
len(subset["test"]) == 24 - for s in subset["test"]: - assert s.key.startswith("HIV-TB_Algorithm_study_X-rays/") - - # Check labels - for s in subset["train"]: - assert s.label in [0.0, 1.0] - - for s in subset["validation"]: - assert s.label in [0.0, 1.0] - - for s in subset["test"]: - assert s.label in [0.0, 1.0] - - -@pytest.mark.skip(reason="Test need to be updated") -def test_loading(): - from ptbench.data.hivtb_RS import dataset - - def _check_sample(s): - data = s.data - - assert isinstance(data, dict) - assert len(data) == 2 - - assert "data" in data - assert len(data["data"]) == 14 # Check radiological signs - - assert "label" in data - assert data["label"] in [0, 1] # Check labels - - limit = 30 # use this to limit testing to first images only, else None - - subset = dataset.subsets("fold_0") - for s in subset["train"][:limit]: - _check_sample(s) diff --git a/tests/test_in_RS.py b/tests/test_in_RS.py deleted file mode 100644 index 551265ad9546bff576c46f8fcb167585cacf7892..0000000000000000000000000000000000000000 --- a/tests/test_in_RS.py +++ /dev/null @@ -1,115 +0,0 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# -# SPDX-License-Identifier: GPL-3.0-or-later -"""Tests for Extended Indian dataset.""" - -import pytest - -dataset = None - - -@pytest.mark.skip(reason="Test need to be updated") -def test_protocol_consistency(): - # Default protocol - subset = dataset.subsets("default") - assert len(subset) == 3 - - assert "train" in subset - assert len(subset["train"]) == 83 - - assert "validation" in subset - assert len(subset["validation"]) == 20 - - assert "test" in subset - assert len(subset["test"]) == 52 - - # Check labels - for s in subset["train"]: - assert s.label in [0.0, 1.0] - - for s in subset["validation"]: - assert s.label in [0.0, 1.0] - - for s in subset["test"]: - assert s.label in [0.0, 1.0] - - # Cross-validation fold 0-4 - for f in range(5): - subset = dataset.subsets("fold_" + str(f)) - assert len(subset) == 3 - - assert "train" in subset - assert len(subset["train"]) == 111 - for s in subset["train"]: - assert s.key.startswith("DatasetA") - - assert "validation" in subset - assert len(subset["validation"]) == 28 - for s in subset["validation"]: - assert s.key.startswith("DatasetA") - - assert "test" in subset - assert len(subset["test"]) == 16 - for s in subset["test"]: - assert s.key.startswith("DatasetA") - - # Check labels - for s in subset["train"]: - assert s.label in [0.0, 1.0] - - for s in subset["validation"]: - assert s.label in [0.0, 1.0] - - for s in subset["test"]: - assert s.label in [0.0, 1.0] - - # Cross-validation fold 5-9 - for f in range(5, 10): - subset = dataset.subsets("fold_" + str(f)) - assert len(subset) == 3 - - assert "train" in subset - assert len(subset["train"]) == 112 - for s in subset["train"]: - assert s.key.startswith("DatasetA") - - assert "validation" in subset - assert len(subset["validation"]) == 28 - for s in subset["validation"]: - assert s.key.startswith("DatasetA") - - assert "test" in subset - assert len(subset["test"]) == 15 - for s in subset["test"]: - assert s.key.startswith("DatasetA") - - # Check labels - for s in subset["train"]: - assert s.label in [0.0, 1.0] - - for s in subset["validation"]: - assert s.label in [0.0, 1.0] - - for s in subset["test"]: - assert s.label in [0.0, 1.0] - - -@pytest.mark.skip(reason="Test need to be updated") -def test_loading(): - def _check_sample(s): - data = s.data - - assert isinstance(data, dict) - assert len(data) == 2 - - assert 
"data" in data - assert len(data["data"]) == 14 # Check radiological signs - - assert "label" in data - assert data["label"] in [0, 1] # Check labels - - limit = 30 # use this to limit testing to first images only, else None - - subset = dataset.subsets("default") - for s in subset["train"][:limit]: - _check_sample(s) diff --git a/tests/test_indian.py b/tests/test_indian.py index 91adf0d42083cfa45869e41ccbe9fc1923edd039..de428b161aa775b8dc69678b65fcb87c7da2d50f 100644 --- a/tests/test_indian.py +++ b/tests/test_indian.py @@ -3,125 +3,94 @@ # SPDX-License-Identifier: GPL-3.0-or-later """Tests for Indian (a.k.a. -database A/database B) dataset. +dataset A/dataset B) dataset. """ -import pytest -import torch - -from ptbench.data.indian.datamodule import make_split - - -def _check_split( - split_filename: str, - lengths: dict[str, int], - prefix: str = "Dataset", - possible_labels: list[int] = [0, 1], -): - """Runs a simple consistence check on the data split. - - Parameters - ---------- - - split_filename - This is the split we will check - - lenghts - A dictionary that contains keys matching those of the split (this will - be checked). The values of the dictionary should correspond to the - sizes of each of the datasets in the split. +import importlib - prefix - Each file named in a split should start with this prefix. - - possible_labels - These are the list of possible labels contained in any split. - """ - - split = make_split(split_filename) - - assert len(split) == len(lengths) - - for k in lengths.keys(): - # dataset must have been declared - assert k in split - - assert len(split[k]) == lengths[k] - for s in split[k]: - assert s[0].startswith(prefix) - assert s[1] in possible_labels +import pytest -def _check_loaded_batch( - batch, - size: int = 1, - prefix: str = "Dataset", - possible_labels: list[int] = [0, 1], +def id_function(val): + if isinstance(val, dict): + return str(val) + return repr(val) + + +@pytest.mark.parametrize( + "split,lenghts", + [ + ("default", dict(train=83, validation=20, test=52)), + ("fold-0", dict(train=111, validation=28, test=16)), + ("fold-1", dict(train=111, validation=28, test=16)), + ("fold-2", dict(train=111, validation=28, test=16)), + ("fold-3", dict(train=111, validation=28, test=16)), + ("fold-4", dict(train=111, validation=28, test=16)), + ("fold-5", dict(train=112, validation=28, test=15)), + ("fold-6", dict(train=112, validation=28, test=15)), + ("fold-7", dict(train=112, validation=28, test=15)), + ("fold-8", dict(train=112, validation=28, test=15)), + ("fold-9", dict(train=112, validation=28, test=15)), + ], + ids=id_function, # just changes how pytest prints it +) +def test_protocol_consistency( + database_checkers, split: str, lenghts: dict[str, int] ): - """Checks the consistence of an individual (loaded) batch. - - Parameters - ---------- + from ptbench.data.indian.datamodule import make_split - batch - The loaded batch to be checked. - - prefix - Each file named in a split should start with this prefix. - - possible_labels - These are the list of possible labels contained in any split. 
- """ - - assert len(batch) == 2 # data, metadata - - assert isinstance(batch[0], torch.Tensor) - assert batch[0].shape[0] == size # mini-batch size - assert batch[0].shape[1] == 1 # grayscale images - assert batch[0].shape[2] == batch[0].shape[3] # image is square - - assert isinstance(batch[1], dict) # metadata - assert len(batch[1]) == 2 # label and name - - assert "label" in batch[1] - assert all([k in possible_labels for k in batch[1]["label"]]) - - assert "name" in batch[1] - assert all([k.startswith(prefix) for k in batch[1]["name"]]) - - -def test_protocol_consistency(): - _check_split( - "default.json", - lengths=dict(train=83, validation=20, test=52), + database_checkers.check_split( + make_split(f"{split}.json"), + lengths=lenghts, + prefixes=("DatasetA/Training", "DatasetA/Testing"), + possible_labels=(0, 1), ) - # Cross-validation fold 0-4 - for k in range(5): - _check_split( - f"fold-{k}.json", - lengths=dict(train=111, validation=28, test=16), - ) - - # Cross-validation fold 5-9 - for k in range(5, 10): - _check_split( - f"fold-{k}.json", - lengths=dict(train=112, validation=28, test=15), - ) - @pytest.mark.skip_if_rc_var_not_set("datadir.indian") -def test_loading(): - from ptbench.data.indian.default import datamodule +@pytest.mark.parametrize( + "dataset", + [ + "train", + "validation", + "test", + ], +) +@pytest.mark.parametrize( + "name", + [ + "default", + "fold_0", + "fold_1", + "fold_2", + "fold_3", + "fold_4", + "fold_5", + "fold_6", + "fold_7", + "fold_8", + "fold_9", + ], +) +def test_loading(database_checkers, name: str, dataset: str): + datamodule = importlib.import_module( + f".{name}", "ptbench.data.indian" + ).datamodule datamodule.model_transforms = [] # should be done before setup() datamodule.setup("predict") # sets up all datasets - for loader in datamodule.predict_dataloader().values(): - limit = 5 # limit load checking - for batch in loader: - if limit == 0: - break - _check_loaded_batch(batch) - limit -= 1 + loader = datamodule.predict_dataloader()[dataset] + + limit = 3 # limit load checking + for batch in loader: + if limit == 0: + break + database_checkers.check_loaded_batch( + batch, + batch_size=1, + color_planes=1, + prefixes=("DatasetA/Training", "DatasetA/Testing"), + possible_labels=(0, 1), + ) + limit -= 1 diff --git a/tests/test_mc_RS.py b/tests/test_mc_RS.py deleted file mode 100644 index a0175c23f8d00556b041792b378001eec1430962..0000000000000000000000000000000000000000 --- a/tests/test_mc_RS.py +++ /dev/null @@ -1,120 +0,0 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# -# SPDX-License-Identifier: GPL-3.0-or-later -"""Tests for Extended Montgomery dataset.""" - -import pytest - - -@pytest.mark.skip(reason="Test need to be updated") -def test_protocol_consistency(): - from ptbench.data.montgomery_RS import dataset - - # Default protocol - subset = dataset.subsets("default") - assert len(subset) == 3 - - assert "train" in subset - assert len(subset["train"]) == 88 - - assert "validation" in subset - assert len(subset["validation"]) == 22 - - assert "test" in subset - assert len(subset["test"]) == 28 - for s in subset["test"]: - assert s.key.startswith("CXR_png/MCUCXR_0") - - # Check labels - for s in subset["train"]: - assert s.label in [0.0, 1.0] - - for s in subset["validation"]: - assert s.label in [0.0, 1.0] - - for s in subset["test"]: - assert s.label in [0.0, 1.0] - - # Cross-validation fold 0-7 - for f in range(8): - subset = dataset.subsets("fold_" + str(f)) - assert len(subset) == 3 - - 
assert "train" in subset - assert len(subset["train"]) == 99 - for s in subset["train"]: - assert s.key.startswith("CXR_png/MCUCXR_0") - - assert "validation" in subset - assert len(subset["validation"]) == 25 - for s in subset["validation"]: - assert s.key.startswith("CXR_png/MCUCXR_0") - - assert "test" in subset - assert len(subset["test"]) == 14 - for s in subset["test"]: - assert s.key.startswith("CXR_png/MCUCXR_0") - - # Check labels - for s in subset["train"]: - assert s.label in [0.0, 1.0] - - for s in subset["validation"]: - assert s.label in [0.0, 1.0] - - for s in subset["test"]: - assert s.label in [0.0, 1.0] - - # Cross-validation fold 8-9 - for f in range(8, 10): - subset = dataset.subsets("fold_" + str(f)) - assert len(subset) == 3 - - assert "train" in subset - assert len(subset["train"]) == 100 - for s in subset["train"]: - assert s.key.startswith("CXR_png/MCUCXR_0") - - assert "validation" in subset - assert len(subset["validation"]) == 25 - for s in subset["validation"]: - assert s.key.startswith("CXR_png/MCUCXR_0") - - assert "test" in subset - assert len(subset["test"]) == 13 - for s in subset["test"]: - assert s.key.startswith("CXR_png/MCUCXR_0") - - # Check labels - for s in subset["train"]: - assert s.label in [0.0, 1.0] - - for s in subset["validation"]: - assert s.label in [0.0, 1.0] - - for s in subset["test"]: - assert s.label in [0.0, 1.0] - - -@pytest.mark.skip(reason="Test need to be updated") -@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") -def test_loading(): - from ptbench.data.montgomery_RS import dataset - - def _check_sample(s): - data = s.data - - assert isinstance(data, dict) - assert len(data) == 2 - - assert "data" in data - assert len(data["data"]) == 14 # Check radiological signs - - assert "label" in data - assert data["label"] in [0, 1] # Check labels - - limit = 30 # use this to limit testing to first images only, else None - - subset = dataset.subsets("default") - for s in subset["train"][:limit]: - _check_sample(s) diff --git a/tests/test_mc_ch_RS.py b/tests/test_mc_ch_RS.py deleted file mode 100644 index 7e91622906cf9b6497e6d73ee3b29d0e9cf0b42b..0000000000000000000000000000000000000000 --- a/tests/test_mc_ch_RS.py +++ /dev/null @@ -1,285 +0,0 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# -# SPDX-License-Identifier: GPL-3.0-or-later -"""Tests for the aggregated Montgomery-Shenzhen dataset.""" - -import pytest - - -@pytest.mark.skip(reason="Test need to be updated") -def test_dataset_consistency(): - from ptbench.configs.datasets.mc_ch_RS import default as mc_ch_RS - from ptbench.configs.datasets.mc_ch_RS import fold_0 as mc_ch_f0 - from ptbench.configs.datasets.mc_ch_RS import fold_1 as mc_ch_f1 - from ptbench.configs.datasets.mc_ch_RS import fold_2 as mc_ch_f2 - from ptbench.configs.datasets.mc_ch_RS import fold_3 as mc_ch_f3 - from ptbench.configs.datasets.mc_ch_RS import fold_4 as mc_ch_f4 - from ptbench.configs.datasets.mc_ch_RS import fold_5 as mc_ch_f5 - from ptbench.configs.datasets.mc_ch_RS import fold_6 as mc_ch_f6 - from ptbench.configs.datasets.mc_ch_RS import fold_7 as mc_ch_f7 - from ptbench.configs.datasets.mc_ch_RS import fold_8 as mc_ch_f8 - from ptbench.configs.datasets.mc_ch_RS import fold_9 as mc_ch_f9 - from ptbench.configs.datasets.montgomery_RS import default as mc_RS - from ptbench.configs.datasets.montgomery_RS import fold_0 as mc_f0 - from ptbench.configs.datasets.montgomery_RS import fold_1 as mc_f1 - from ptbench.configs.datasets.montgomery_RS import fold_2 
as mc_f2 - from ptbench.configs.datasets.montgomery_RS import fold_3 as mc_f3 - from ptbench.configs.datasets.montgomery_RS import fold_4 as mc_f4 - from ptbench.configs.datasets.montgomery_RS import fold_5 as mc_f5 - from ptbench.configs.datasets.montgomery_RS import fold_6 as mc_f6 - from ptbench.configs.datasets.montgomery_RS import fold_7 as mc_f7 - from ptbench.configs.datasets.montgomery_RS import fold_8 as mc_f8 - from ptbench.configs.datasets.montgomery_RS import fold_9 as mc_f9 - from ptbench.configs.datasets.shenzhen_RS import default as ch_RS - from ptbench.configs.datasets.shenzhen_RS import fold_0 as ch_f0 - from ptbench.configs.datasets.shenzhen_RS import fold_1 as ch_f1 - from ptbench.configs.datasets.shenzhen_RS import fold_2 as ch_f2 - from ptbench.configs.datasets.shenzhen_RS import fold_3 as ch_f3 - from ptbench.configs.datasets.shenzhen_RS import fold_4 as ch_f4 - from ptbench.configs.datasets.shenzhen_RS import fold_5 as ch_f5 - from ptbench.configs.datasets.shenzhen_RS import fold_6 as ch_f6 - from ptbench.configs.datasets.shenzhen_RS import fold_7 as ch_f7 - from ptbench.configs.datasets.shenzhen_RS import fold_8 as ch_f8 - from ptbench.configs.datasets.shenzhen_RS import fold_9 as ch_f9 - - # Default protocol - mc_ch_RS_dataset = mc_ch_RS.dataset - assert isinstance(mc_ch_RS_dataset, dict) - - mc_RS_dataset = mc_RS.dataset - ch_RS_dataset = ch_RS.dataset - - assert "train" in mc_ch_RS_dataset - assert len(mc_ch_RS_dataset["train"]) == len(mc_RS_dataset["train"]) + len( - ch_RS_dataset["train"] - ) - - assert "validation" in mc_ch_RS_dataset - assert len(mc_ch_RS_dataset["validation"]) == len( - mc_RS_dataset["validation"] - ) + len(ch_RS_dataset["validation"]) - - assert "test" in mc_ch_RS_dataset - assert len(mc_ch_RS_dataset["test"]) == len(mc_RS_dataset["test"]) + len( - ch_RS_dataset["test"] - ) - - # f0 protocol - mc_ch_dataset = mc_ch_f0.dataset - assert isinstance(mc_ch_dataset, dict) - - mc_dataset = mc_f0.dataset - ch_dataset = ch_f0.dataset - - assert "train" in mc_ch_dataset - assert len(mc_ch_dataset["train"]) == len(mc_dataset["train"]) + len( - ch_dataset["train"] - ) - - assert "validation" in mc_ch_dataset - assert len(mc_ch_dataset["validation"]) == len( - mc_dataset["validation"] - ) + len(ch_dataset["validation"]) - - assert "test" in mc_ch_dataset - assert len(mc_ch_dataset["test"]) == len(mc_dataset["test"]) + len( - ch_dataset["test"] - ) - - # f1 protocol - mc_ch_dataset = mc_ch_f1.dataset - assert isinstance(mc_ch_dataset, dict) - - mc_dataset = mc_f1.dataset - ch_dataset = ch_f1.dataset - - assert "train" in mc_ch_dataset - assert len(mc_ch_dataset["train"]) == len(mc_dataset["train"]) + len( - ch_dataset["train"] - ) - - assert "validation" in mc_ch_dataset - assert len(mc_ch_dataset["validation"]) == len( - mc_dataset["validation"] - ) + len(ch_dataset["validation"]) - - assert "test" in mc_ch_dataset - assert len(mc_ch_dataset["test"]) == len(mc_dataset["test"]) + len( - ch_dataset["test"] - ) - - # f2 protocol - mc_ch_dataset = mc_ch_f2.dataset - assert isinstance(mc_ch_dataset, dict) - - mc_dataset = mc_f2.dataset - ch_dataset = ch_f2.dataset - - assert "train" in mc_ch_dataset - assert len(mc_ch_dataset["train"]) == len(mc_dataset["train"]) + len( - ch_dataset["train"] - ) - - assert "validation" in mc_ch_dataset - assert len(mc_ch_dataset["validation"]) == len( - mc_dataset["validation"] - ) + len(ch_dataset["validation"]) - - assert "test" in mc_ch_dataset - assert len(mc_ch_dataset["test"]) == len(mc_dataset["test"]) + len( - 
ch_dataset["test"] - ) - - # f3 protocol - mc_ch_dataset = mc_ch_f3.dataset - assert isinstance(mc_ch_dataset, dict) - - mc_dataset = mc_f3.dataset - ch_dataset = ch_f3.dataset - - assert "train" in mc_ch_dataset - assert len(mc_ch_dataset["train"]) == len(mc_dataset["train"]) + len( - ch_dataset["train"] - ) - - assert "validation" in mc_ch_dataset - assert len(mc_ch_dataset["validation"]) == len( - mc_dataset["validation"] - ) + len(ch_dataset["validation"]) - - assert "test" in mc_ch_dataset - assert len(mc_ch_dataset["test"]) == len(mc_dataset["test"]) + len( - ch_dataset["test"] - ) - - # f4 protocol - mc_ch_dataset = mc_ch_f4.dataset - assert isinstance(mc_ch_dataset, dict) - - mc_dataset = mc_f4.dataset - ch_dataset = ch_f4.dataset - - assert "train" in mc_ch_dataset - assert len(mc_ch_dataset["train"]) == len(mc_dataset["train"]) + len( - ch_dataset["train"] - ) - - assert "validation" in mc_ch_dataset - assert len(mc_ch_dataset["validation"]) == len( - mc_dataset["validation"] - ) + len(ch_dataset["validation"]) - - assert "test" in mc_ch_dataset - assert len(mc_ch_dataset["test"]) == len(mc_dataset["test"]) + len( - ch_dataset["test"] - ) - - # f5 protocol - mc_ch_dataset = mc_ch_f5.dataset - assert isinstance(mc_ch_dataset, dict) - - mc_dataset = mc_f5.dataset - ch_dataset = ch_f5.dataset - - assert "train" in mc_ch_dataset - assert len(mc_ch_dataset["train"]) == len(mc_dataset["train"]) + len( - ch_dataset["train"] - ) - - assert "validation" in mc_ch_dataset - assert len(mc_ch_dataset["validation"]) == len( - mc_dataset["validation"] - ) + len(ch_dataset["validation"]) - - assert "test" in mc_ch_dataset - assert len(mc_ch_dataset["test"]) == len(mc_dataset["test"]) + len( - ch_dataset["test"] - ) - - # f6 protocol - mc_ch_dataset = mc_ch_f6.dataset - assert isinstance(mc_ch_dataset, dict) - - mc_dataset = mc_f6.dataset - ch_dataset = ch_f6.dataset - - assert "train" in mc_ch_dataset - assert len(mc_ch_dataset["train"]) == len(mc_dataset["train"]) + len( - ch_dataset["train"] - ) - - assert "validation" in mc_ch_dataset - assert len(mc_ch_dataset["validation"]) == len( - mc_dataset["validation"] - ) + len(ch_dataset["validation"]) - - assert "test" in mc_ch_dataset - assert len(mc_ch_dataset["test"]) == len(mc_dataset["test"]) + len( - ch_dataset["test"] - ) - - # f7 protocol - mc_ch_dataset = mc_ch_f7.dataset - assert isinstance(mc_ch_dataset, dict) - - mc_dataset = mc_f7.dataset - ch_dataset = ch_f7.dataset - - assert "train" in mc_ch_dataset - assert len(mc_ch_dataset["train"]) == len(mc_dataset["train"]) + len( - ch_dataset["train"] - ) - - assert "validation" in mc_ch_dataset - assert len(mc_ch_dataset["validation"]) == len( - mc_dataset["validation"] - ) + len(ch_dataset["validation"]) - - assert "test" in mc_ch_dataset - assert len(mc_ch_dataset["test"]) == len(mc_dataset["test"]) + len( - ch_dataset["test"] - ) - - # f8 protocol - mc_ch_dataset = mc_ch_f8.dataset - assert isinstance(mc_ch_dataset, dict) - - mc_dataset = mc_f8.dataset - ch_dataset = ch_f8.dataset - - assert "train" in mc_ch_dataset - assert len(mc_ch_dataset["train"]) == len(mc_dataset["train"]) + len( - ch_dataset["train"] - ) - - assert "validation" in mc_ch_dataset - assert len(mc_ch_dataset["validation"]) == len( - mc_dataset["validation"] - ) + len(ch_dataset["validation"]) - - assert "test" in mc_ch_dataset - assert len(mc_ch_dataset["test"]) == len(mc_dataset["test"]) + len( - ch_dataset["test"] - ) - - # f9 protocol - mc_ch_dataset = mc_ch_f9.dataset - assert isinstance(mc_ch_dataset, dict) - - 
mc_dataset = mc_f9.dataset - ch_dataset = ch_f9.dataset - - assert "train" in mc_ch_dataset - assert len(mc_ch_dataset["train"]) == len(mc_dataset["train"]) + len( - ch_dataset["train"] - ) - - assert "validation" in mc_ch_dataset - assert len(mc_ch_dataset["validation"]) == len( - mc_dataset["validation"] - ) + len(ch_dataset["validation"]) - - assert "test" in mc_ch_dataset - assert len(mc_ch_dataset["test"]) == len(mc_dataset["test"]) + len( - ch_dataset["test"] - ) diff --git a/tests/test_mc_ch_in_11k_RS.py b/tests/test_mc_ch_in_11k_RS.py deleted file mode 100644 index f8843befc928558ae9917eed3f41549dfd798e31..0000000000000000000000000000000000000000 --- a/tests/test_mc_ch_in_11k_RS.py +++ /dev/null @@ -1,442 +0,0 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# -# SPDX-License-Identifier: GPL-3.0-or-later -"""Tests for the aggregated Montgomery-Shenzhen-Indian-tbx11k_simplified -dataset.""" - -import pytest - - -@pytest.mark.skip(reason="Test need to be updated") -def test_dataset_consistency(): - from ptbench.configs.datasets.indian_RS import default as indian_RS - from ptbench.configs.datasets.indian_RS import fold_0 as indian_f0 - from ptbench.configs.datasets.indian_RS import fold_1 as indian_f1 - from ptbench.configs.datasets.indian_RS import fold_2 as indian_f2 - from ptbench.configs.datasets.indian_RS import fold_3 as indian_f3 - from ptbench.configs.datasets.indian_RS import fold_4 as indian_f4 - from ptbench.configs.datasets.indian_RS import fold_5 as indian_f5 - from ptbench.configs.datasets.indian_RS import fold_6 as indian_f6 - from ptbench.configs.datasets.indian_RS import fold_7 as indian_f7 - from ptbench.configs.datasets.indian_RS import fold_8 as indian_f8 - from ptbench.configs.datasets.indian_RS import fold_9 as indian_f9 - from ptbench.configs.datasets.mc_ch_in_11k_RS import ( - default as mc_ch_in_11k_RS, - ) - from ptbench.configs.datasets.mc_ch_in_11k_RS import ( - fold_0 as mc_ch_in_11k_f0, - ) - from ptbench.configs.datasets.mc_ch_in_11k_RS import ( - fold_1 as mc_ch_in_11k_f1, - ) - from ptbench.configs.datasets.mc_ch_in_11k_RS import ( - fold_2 as mc_ch_in_11k_f2, - ) - from ptbench.configs.datasets.mc_ch_in_11k_RS import ( - fold_3 as mc_ch_in_11k_f3, - ) - from ptbench.configs.datasets.mc_ch_in_11k_RS import ( - fold_4 as mc_ch_in_11k_f4, - ) - from ptbench.configs.datasets.mc_ch_in_11k_RS import ( - fold_5 as mc_ch_in_11k_f5, - ) - from ptbench.configs.datasets.mc_ch_in_11k_RS import ( - fold_6 as mc_ch_in_11k_f6, - ) - from ptbench.configs.datasets.mc_ch_in_11k_RS import ( - fold_7 as mc_ch_in_11k_f7, - ) - from ptbench.configs.datasets.mc_ch_in_11k_RS import ( - fold_8 as mc_ch_in_11k_f8, - ) - from ptbench.configs.datasets.mc_ch_in_11k_RS import ( - fold_9 as mc_ch_in_11k_f9, - ) - from ptbench.configs.datasets.montgomery_RS import default as mc_RS - from ptbench.configs.datasets.montgomery_RS import fold_0 as mc_f0 - from ptbench.configs.datasets.montgomery_RS import fold_1 as mc_f1 - from ptbench.configs.datasets.montgomery_RS import fold_2 as mc_f2 - from ptbench.configs.datasets.montgomery_RS import fold_3 as mc_f3 - from ptbench.configs.datasets.montgomery_RS import fold_4 as mc_f4 - from ptbench.configs.datasets.montgomery_RS import fold_5 as mc_f5 - from ptbench.configs.datasets.montgomery_RS import fold_6 as mc_f6 - from ptbench.configs.datasets.montgomery_RS import fold_7 as mc_f7 - from ptbench.configs.datasets.montgomery_RS import fold_8 as mc_f8 - from ptbench.configs.datasets.montgomery_RS 
import fold_9 as mc_f9 - from ptbench.configs.datasets.shenzhen_RS import default as ch_RS - from ptbench.configs.datasets.shenzhen_RS import fold_0 as ch_f0 - from ptbench.configs.datasets.shenzhen_RS import fold_1 as ch_f1 - from ptbench.configs.datasets.shenzhen_RS import fold_2 as ch_f2 - from ptbench.configs.datasets.shenzhen_RS import fold_3 as ch_f3 - from ptbench.configs.datasets.shenzhen_RS import fold_4 as ch_f4 - from ptbench.configs.datasets.shenzhen_RS import fold_5 as ch_f5 - from ptbench.configs.datasets.shenzhen_RS import fold_6 as ch_f6 - from ptbench.configs.datasets.shenzhen_RS import fold_7 as ch_f7 - from ptbench.configs.datasets.shenzhen_RS import fold_8 as ch_f8 - from ptbench.configs.datasets.shenzhen_RS import fold_9 as ch_f9 - from ptbench.configs.datasets.tbx11k_simplified_RS import ( - default as tbx11k_RS, - ) - from ptbench.configs.datasets.tbx11k_simplified_RS import ( - fold_0 as tbx11k_f0, - ) - from ptbench.configs.datasets.tbx11k_simplified_RS import ( - fold_1 as tbx11k_f1, - ) - from ptbench.configs.datasets.tbx11k_simplified_RS import ( - fold_2 as tbx11k_f2, - ) - from ptbench.configs.datasets.tbx11k_simplified_RS import ( - fold_3 as tbx11k_f3, - ) - from ptbench.configs.datasets.tbx11k_simplified_RS import ( - fold_4 as tbx11k_f4, - ) - from ptbench.configs.datasets.tbx11k_simplified_RS import ( - fold_5 as tbx11k_f5, - ) - from ptbench.configs.datasets.tbx11k_simplified_RS import ( - fold_6 as tbx11k_f6, - ) - from ptbench.configs.datasets.tbx11k_simplified_RS import ( - fold_7 as tbx11k_f7, - ) - from ptbench.configs.datasets.tbx11k_simplified_RS import ( - fold_8 as tbx11k_f8, - ) - from ptbench.configs.datasets.tbx11k_simplified_RS import ( - fold_9 as tbx11k_f9, - ) - - # Default protocol - mc_ch_in_11k_RS_dataset = mc_ch_in_11k_RS.dataset - assert isinstance(mc_ch_in_11k_RS_dataset, dict) - - mc_RS_dataset = mc_RS.dataset - ch_RS_dataset = ch_RS.dataset - in_RS_dataset = indian_RS.dataset - tbx11k_RS_dataset = tbx11k_RS.dataset - - assert "train" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["train"]) == len( - mc_RS_dataset["train"] - ) + len(ch_RS_dataset["train"]) + len(in_RS_dataset["train"]) + len( - tbx11k_RS_dataset["train"] - ) - - assert "validation" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["validation"]) == len( - mc_RS_dataset["validation"] - ) + len(ch_RS_dataset["validation"]) + len( - in_RS_dataset["validation"] - ) + len( - tbx11k_RS_dataset["validation"] - ) - - assert "test" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["test"]) == len( - mc_RS_dataset["test"] - ) + len(ch_RS_dataset["test"]) + len(in_RS_dataset["test"]) + len( - tbx11k_RS_dataset["test"] - ) - - # Fold 0 - mc_ch_in_11k_RS_dataset = mc_ch_in_11k_f0.dataset - assert isinstance(mc_ch_in_11k_RS_dataset, dict) - - mc_dataset = mc_f0.dataset - ch_dataset = ch_f0.dataset - in_dataset = indian_f0.dataset - tbx11k_dataset = tbx11k_f0.dataset - - assert "train" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["train"]) == len( - mc_dataset["train"] - ) + len(ch_dataset["train"]) + len(in_dataset["train"]) + len( - tbx11k_dataset["train"] - ) - - assert "validation" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["validation"]) == len( - mc_dataset["validation"] - ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( - tbx11k_dataset["validation"] - ) - - assert "test" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["test"]) == len( - 
mc_dataset["test"] - ) + len(ch_dataset["test"]) + len(in_dataset["test"]) + len( - tbx11k_dataset["test"] - ) - - # Fold 1 - mc_ch_in_11k_RS_dataset = mc_ch_in_11k_f1.dataset - assert isinstance(mc_ch_in_11k_RS_dataset, dict) - - mc_dataset = mc_f1.dataset - ch_dataset = ch_f1.dataset - in_dataset = indian_f1.dataset - tbx11k_dataset = tbx11k_f1.dataset - - assert "train" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["train"]) == len( - mc_dataset["train"] - ) + len(ch_dataset["train"]) + len(in_dataset["train"]) + len( - tbx11k_dataset["train"] - ) - - assert "validation" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["validation"]) == len( - mc_dataset["validation"] - ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( - tbx11k_dataset["validation"] - ) - - assert "test" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["test"]) == len( - mc_dataset["test"] - ) + len(ch_dataset["test"]) + len(in_dataset["test"]) + len( - tbx11k_dataset["test"] - ) - - # Fold 2 - mc_ch_in_11k_RS_dataset = mc_ch_in_11k_f2.dataset - assert isinstance(mc_ch_in_11k_RS_dataset, dict) - - mc_dataset = mc_f2.dataset - ch_dataset = ch_f2.dataset - in_dataset = indian_f2.dataset - tbx11k_dataset = tbx11k_f2.dataset - - assert "train" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["train"]) == len( - mc_dataset["train"] - ) + len(ch_dataset["train"]) + len(in_dataset["train"]) + len( - tbx11k_dataset["train"] - ) - - assert "validation" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["validation"]) == len( - mc_dataset["validation"] - ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( - tbx11k_dataset["validation"] - ) - - assert "test" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["test"]) == len( - mc_dataset["test"] - ) + len(ch_dataset["test"]) + len(in_dataset["test"]) + len( - tbx11k_dataset["test"] - ) - - # Fold 3 - mc_ch_in_11k_RS_dataset = mc_ch_in_11k_f3.dataset - assert isinstance(mc_ch_in_11k_RS_dataset, dict) - - mc_dataset = mc_f3.dataset - ch_dataset = ch_f3.dataset - in_dataset = indian_f3.dataset - tbx11k_dataset = tbx11k_f3.dataset - - assert "train" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["train"]) == len( - mc_dataset["train"] - ) + len(ch_dataset["train"]) + len(in_dataset["train"]) + len( - tbx11k_dataset["train"] - ) - - assert "validation" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["validation"]) == len( - mc_dataset["validation"] - ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( - tbx11k_dataset["validation"] - ) - - assert "test" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["test"]) == len( - mc_dataset["test"] - ) + len(ch_dataset["test"]) + len(in_dataset["test"]) + len( - tbx11k_dataset["test"] - ) - - # Fold 4 - mc_ch_in_11k_RS_dataset = mc_ch_in_11k_f4.dataset - assert isinstance(mc_ch_in_11k_RS_dataset, dict) - - mc_dataset = mc_f4.dataset - ch_dataset = ch_f4.dataset - in_dataset = indian_f4.dataset - tbx11k_dataset = tbx11k_f4.dataset - - assert "train" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["train"]) == len( - mc_dataset["train"] - ) + len(ch_dataset["train"]) + len(in_dataset["train"]) + len( - tbx11k_dataset["train"] - ) - - assert "validation" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["validation"]) == len( - mc_dataset["validation"] - ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + 
len( - tbx11k_dataset["validation"] - ) - - assert "test" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["test"]) == len( - mc_dataset["test"] - ) + len(ch_dataset["test"]) + len(in_dataset["test"]) + len( - tbx11k_dataset["test"] - ) - - # Fold 5 - mc_ch_in_11k_RS_dataset = mc_ch_in_11k_f5.dataset - assert isinstance(mc_ch_in_11k_RS_dataset, dict) - - mc_dataset = mc_f5.dataset - ch_dataset = ch_f5.dataset - in_dataset = indian_f5.dataset - tbx11k_dataset = tbx11k_f5.dataset - - assert "train" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["train"]) == len( - mc_dataset["train"] - ) + len(ch_dataset["train"]) + len(in_dataset["train"]) + len( - tbx11k_dataset["train"] - ) - - assert "validation" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["validation"]) == len( - mc_dataset["validation"] - ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( - tbx11k_dataset["validation"] - ) - - assert "test" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["test"]) == len( - mc_dataset["test"] - ) + len(ch_dataset["test"]) + len(in_dataset["test"]) + len( - tbx11k_dataset["test"] - ) - - # Fold 6 - mc_ch_in_11k_RS_dataset = mc_ch_in_11k_f6.dataset - assert isinstance(mc_ch_in_11k_RS_dataset, dict) - - mc_dataset = mc_f6.dataset - ch_dataset = ch_f6.dataset - in_dataset = indian_f6.dataset - tbx11k_dataset = tbx11k_f6.dataset - - assert "train" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["train"]) == len( - mc_dataset["train"] - ) + len(ch_dataset["train"]) + len(in_dataset["train"]) + len( - tbx11k_dataset["train"] - ) - - assert "validation" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["validation"]) == len( - mc_dataset["validation"] - ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( - tbx11k_dataset["validation"] - ) - - assert "test" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["test"]) == len( - mc_dataset["test"] - ) + len(ch_dataset["test"]) + len(in_dataset["test"]) + len( - tbx11k_dataset["test"] - ) - - # Fold 7 - mc_ch_in_11k_RS_dataset = mc_ch_in_11k_f7.dataset - assert isinstance(mc_ch_in_11k_RS_dataset, dict) - - mc_dataset = mc_f7.dataset - ch_dataset = ch_f7.dataset - in_dataset = indian_f7.dataset - tbx11k_dataset = tbx11k_f7.dataset - - assert "train" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["train"]) == len( - mc_dataset["train"] - ) + len(ch_dataset["train"]) + len(in_dataset["train"]) + len( - tbx11k_dataset["train"] - ) - - assert "validation" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["validation"]) == len( - mc_dataset["validation"] - ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( - tbx11k_dataset["validation"] - ) - - assert "test" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["test"]) == len( - mc_dataset["test"] - ) + len(ch_dataset["test"]) + len(in_dataset["test"]) + len( - tbx11k_dataset["test"] - ) - - # Fold 8 - mc_ch_in_11k_RS_dataset = mc_ch_in_11k_f8.dataset - assert isinstance(mc_ch_in_11k_RS_dataset, dict) - - mc_dataset = mc_f8.dataset - ch_dataset = ch_f8.dataset - in_dataset = indian_f8.dataset - tbx11k_dataset = tbx11k_f8.dataset - - assert "train" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["train"]) == len( - mc_dataset["train"] - ) + len(ch_dataset["train"]) + len(in_dataset["train"]) + len( - tbx11k_dataset["train"] - ) - - assert "validation" in mc_ch_in_11k_RS_dataset - assert 
len(mc_ch_in_11k_RS_dataset["validation"]) == len( - mc_dataset["validation"] - ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( - tbx11k_dataset["validation"] - ) - - assert "test" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["test"]) == len( - mc_dataset["test"] - ) + len(ch_dataset["test"]) + len(in_dataset["test"]) + len( - tbx11k_dataset["test"] - ) - - # Fold 9 - mc_ch_in_11k_RS_dataset = mc_ch_in_11k_f9.dataset - assert isinstance(mc_ch_in_11k_RS_dataset, dict) - - mc_dataset = mc_f9.dataset - ch_dataset = ch_f9.dataset - in_dataset = indian_f9.dataset - tbx11k_dataset = tbx11k_f9.dataset - - assert "train" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["train"]) == len( - mc_dataset["train"] - ) + len(ch_dataset["train"]) + len(in_dataset["train"]) + len( - tbx11k_dataset["train"] - ) - - assert "validation" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["validation"]) == len( - mc_dataset["validation"] - ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( - tbx11k_dataset["validation"] - ) - - assert "test" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["test"]) == len( - mc_dataset["test"] - ) + len(ch_dataset["test"]) + len(in_dataset["test"]) + len( - tbx11k_dataset["test"] - ) diff --git a/tests/test_mc_ch_in_11kv2_RS.py b/tests/test_mc_ch_in_11kv2_RS.py deleted file mode 100644 index 3c128968613ebc33198643407283bb5810b2725f..0000000000000000000000000000000000000000 --- a/tests/test_mc_ch_in_11kv2_RS.py +++ /dev/null @@ -1,442 +0,0 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# -# SPDX-License-Identifier: GPL-3.0-or-later -"""Tests for the aggregated Montgomery-Shenzhen-Indian-tbx11k_simplified_v2 -dataset.""" - -import pytest - - -@pytest.mark.skip(reason="Test need to be updated") -def test_dataset_consistency(): - from ptbench.configs.datasets.indian_RS import default as indian_RS - from ptbench.configs.datasets.indian_RS import fold_0 as indian_f0 - from ptbench.configs.datasets.indian_RS import fold_1 as indian_f1 - from ptbench.configs.datasets.indian_RS import fold_2 as indian_f2 - from ptbench.configs.datasets.indian_RS import fold_3 as indian_f3 - from ptbench.configs.datasets.indian_RS import fold_4 as indian_f4 - from ptbench.configs.datasets.indian_RS import fold_5 as indian_f5 - from ptbench.configs.datasets.indian_RS import fold_6 as indian_f6 - from ptbench.configs.datasets.indian_RS import fold_7 as indian_f7 - from ptbench.configs.datasets.indian_RS import fold_8 as indian_f8 - from ptbench.configs.datasets.indian_RS import fold_9 as indian_f9 - from ptbench.configs.datasets.mc_ch_in_11kv2_RS import ( - default as mc_ch_in_11k_RS, - ) - from ptbench.configs.datasets.mc_ch_in_11kv2_RS import ( - fold_0 as mc_ch_in_11k_f0, - ) - from ptbench.configs.datasets.mc_ch_in_11kv2_RS import ( - fold_1 as mc_ch_in_11k_f1, - ) - from ptbench.configs.datasets.mc_ch_in_11kv2_RS import ( - fold_2 as mc_ch_in_11k_f2, - ) - from ptbench.configs.datasets.mc_ch_in_11kv2_RS import ( - fold_3 as mc_ch_in_11k_f3, - ) - from ptbench.configs.datasets.mc_ch_in_11kv2_RS import ( - fold_4 as mc_ch_in_11k_f4, - ) - from ptbench.configs.datasets.mc_ch_in_11kv2_RS import ( - fold_5 as mc_ch_in_11k_f5, - ) - from ptbench.configs.datasets.mc_ch_in_11kv2_RS import ( - fold_6 as mc_ch_in_11k_f6, - ) - from ptbench.configs.datasets.mc_ch_in_11kv2_RS import ( - fold_7 as mc_ch_in_11k_f7, - ) - from 
ptbench.configs.datasets.mc_ch_in_11kv2_RS import ( - fold_8 as mc_ch_in_11k_f8, - ) - from ptbench.configs.datasets.mc_ch_in_11kv2_RS import ( - fold_9 as mc_ch_in_11k_f9, - ) - from ptbench.configs.datasets.montgomery_RS import default as mc_RS - from ptbench.configs.datasets.montgomery_RS import fold_0 as mc_f0 - from ptbench.configs.datasets.montgomery_RS import fold_1 as mc_f1 - from ptbench.configs.datasets.montgomery_RS import fold_2 as mc_f2 - from ptbench.configs.datasets.montgomery_RS import fold_3 as mc_f3 - from ptbench.configs.datasets.montgomery_RS import fold_4 as mc_f4 - from ptbench.configs.datasets.montgomery_RS import fold_5 as mc_f5 - from ptbench.configs.datasets.montgomery_RS import fold_6 as mc_f6 - from ptbench.configs.datasets.montgomery_RS import fold_7 as mc_f7 - from ptbench.configs.datasets.montgomery_RS import fold_8 as mc_f8 - from ptbench.configs.datasets.montgomery_RS import fold_9 as mc_f9 - from ptbench.configs.datasets.shenzhen_RS import default as ch_RS - from ptbench.configs.datasets.shenzhen_RS import fold_0 as ch_f0 - from ptbench.configs.datasets.shenzhen_RS import fold_1 as ch_f1 - from ptbench.configs.datasets.shenzhen_RS import fold_2 as ch_f2 - from ptbench.configs.datasets.shenzhen_RS import fold_3 as ch_f3 - from ptbench.configs.datasets.shenzhen_RS import fold_4 as ch_f4 - from ptbench.configs.datasets.shenzhen_RS import fold_5 as ch_f5 - from ptbench.configs.datasets.shenzhen_RS import fold_6 as ch_f6 - from ptbench.configs.datasets.shenzhen_RS import fold_7 as ch_f7 - from ptbench.configs.datasets.shenzhen_RS import fold_8 as ch_f8 - from ptbench.configs.datasets.shenzhen_RS import fold_9 as ch_f9 - from ptbench.configs.datasets.tbx11k_simplified_v2_RS import ( - default as tbx11k_RS, - ) - from ptbench.configs.datasets.tbx11k_simplified_v2_RS import ( - fold_0 as tbx11k_f0, - ) - from ptbench.configs.datasets.tbx11k_simplified_v2_RS import ( - fold_1 as tbx11k_f1, - ) - from ptbench.configs.datasets.tbx11k_simplified_v2_RS import ( - fold_2 as tbx11k_f2, - ) - from ptbench.configs.datasets.tbx11k_simplified_v2_RS import ( - fold_3 as tbx11k_f3, - ) - from ptbench.configs.datasets.tbx11k_simplified_v2_RS import ( - fold_4 as tbx11k_f4, - ) - from ptbench.configs.datasets.tbx11k_simplified_v2_RS import ( - fold_5 as tbx11k_f5, - ) - from ptbench.configs.datasets.tbx11k_simplified_v2_RS import ( - fold_6 as tbx11k_f6, - ) - from ptbench.configs.datasets.tbx11k_simplified_v2_RS import ( - fold_7 as tbx11k_f7, - ) - from ptbench.configs.datasets.tbx11k_simplified_v2_RS import ( - fold_8 as tbx11k_f8, - ) - from ptbench.configs.datasets.tbx11k_simplified_v2_RS import ( - fold_9 as tbx11k_f9, - ) - - # Default protocol - mc_ch_in_11k_RS_dataset = mc_ch_in_11k_RS.dataset - assert isinstance(mc_ch_in_11k_RS_dataset, dict) - - mc_RS_dataset = mc_RS.dataset - ch_RS_dataset = ch_RS.dataset - in_RS_dataset = indian_RS.dataset - tbx11k_RS_dataset = tbx11k_RS.dataset - - assert "train" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["train"]) == len( - mc_RS_dataset["train"] - ) + len(ch_RS_dataset["train"]) + len(in_RS_dataset["train"]) + len( - tbx11k_RS_dataset["train"] - ) - - assert "validation" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["validation"]) == len( - mc_RS_dataset["validation"] - ) + len(ch_RS_dataset["validation"]) + len( - in_RS_dataset["validation"] - ) + len( - tbx11k_RS_dataset["validation"] - ) - - assert "test" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["test"]) == len( - 
mc_RS_dataset["test"] - ) + len(ch_RS_dataset["test"]) + len(in_RS_dataset["test"]) + len( - tbx11k_RS_dataset["test"] - ) - - # Fold 0 - mc_ch_in_11k_RS_dataset = mc_ch_in_11k_f0.dataset - assert isinstance(mc_ch_in_11k_RS_dataset, dict) - - mc_dataset = mc_f0.dataset - ch_dataset = ch_f0.dataset - in_dataset = indian_f0.dataset - tbx11k_dataset = tbx11k_f0.dataset - - assert "train" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["train"]) == len( - mc_dataset["train"] - ) + len(ch_dataset["train"]) + len(in_dataset["train"]) + len( - tbx11k_dataset["train"] - ) - - assert "validation" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["validation"]) == len( - mc_dataset["validation"] - ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( - tbx11k_dataset["validation"] - ) - - assert "test" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["test"]) == len( - mc_dataset["test"] - ) + len(ch_dataset["test"]) + len(in_dataset["test"]) + len( - tbx11k_dataset["test"] - ) - - # Fold 1 - mc_ch_in_11k_RS_dataset = mc_ch_in_11k_f1.dataset - assert isinstance(mc_ch_in_11k_RS_dataset, dict) - - mc_dataset = mc_f1.dataset - ch_dataset = ch_f1.dataset - in_dataset = indian_f1.dataset - tbx11k_dataset = tbx11k_f1.dataset - - assert "train" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["train"]) == len( - mc_dataset["train"] - ) + len(ch_dataset["train"]) + len(in_dataset["train"]) + len( - tbx11k_dataset["train"] - ) - - assert "validation" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["validation"]) == len( - mc_dataset["validation"] - ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( - tbx11k_dataset["validation"] - ) - - assert "test" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["test"]) == len( - mc_dataset["test"] - ) + len(ch_dataset["test"]) + len(in_dataset["test"]) + len( - tbx11k_dataset["test"] - ) - - # Fold 2 - mc_ch_in_11k_RS_dataset = mc_ch_in_11k_f2.dataset - assert isinstance(mc_ch_in_11k_RS_dataset, dict) - - mc_dataset = mc_f2.dataset - ch_dataset = ch_f2.dataset - in_dataset = indian_f2.dataset - tbx11k_dataset = tbx11k_f2.dataset - - assert "train" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["train"]) == len( - mc_dataset["train"] - ) + len(ch_dataset["train"]) + len(in_dataset["train"]) + len( - tbx11k_dataset["train"] - ) - - assert "validation" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["validation"]) == len( - mc_dataset["validation"] - ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( - tbx11k_dataset["validation"] - ) - - assert "test" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["test"]) == len( - mc_dataset["test"] - ) + len(ch_dataset["test"]) + len(in_dataset["test"]) + len( - tbx11k_dataset["test"] - ) - - # Fold 3 - mc_ch_in_11k_RS_dataset = mc_ch_in_11k_f3.dataset - assert isinstance(mc_ch_in_11k_RS_dataset, dict) - - mc_dataset = mc_f3.dataset - ch_dataset = ch_f3.dataset - in_dataset = indian_f3.dataset - tbx11k_dataset = tbx11k_f3.dataset - - assert "train" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["train"]) == len( - mc_dataset["train"] - ) + len(ch_dataset["train"]) + len(in_dataset["train"]) + len( - tbx11k_dataset["train"] - ) - - assert "validation" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["validation"]) == len( - mc_dataset["validation"] - ) + len(ch_dataset["validation"]) + 
len(in_dataset["validation"]) + len( - tbx11k_dataset["validation"] - ) - - assert "test" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["test"]) == len( - mc_dataset["test"] - ) + len(ch_dataset["test"]) + len(in_dataset["test"]) + len( - tbx11k_dataset["test"] - ) - - # Fold 4 - mc_ch_in_11k_RS_dataset = mc_ch_in_11k_f4.dataset - assert isinstance(mc_ch_in_11k_RS_dataset, dict) - - mc_dataset = mc_f4.dataset - ch_dataset = ch_f4.dataset - in_dataset = indian_f4.dataset - tbx11k_dataset = tbx11k_f4.dataset - - assert "train" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["train"]) == len( - mc_dataset["train"] - ) + len(ch_dataset["train"]) + len(in_dataset["train"]) + len( - tbx11k_dataset["train"] - ) - - assert "validation" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["validation"]) == len( - mc_dataset["validation"] - ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( - tbx11k_dataset["validation"] - ) - - assert "test" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["test"]) == len( - mc_dataset["test"] - ) + len(ch_dataset["test"]) + len(in_dataset["test"]) + len( - tbx11k_dataset["test"] - ) - - # Fold 5 - mc_ch_in_11k_RS_dataset = mc_ch_in_11k_f5.dataset - assert isinstance(mc_ch_in_11k_RS_dataset, dict) - - mc_dataset = mc_f5.dataset - ch_dataset = ch_f5.dataset - in_dataset = indian_f5.dataset - tbx11k_dataset = tbx11k_f5.dataset - - assert "train" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["train"]) == len( - mc_dataset["train"] - ) + len(ch_dataset["train"]) + len(in_dataset["train"]) + len( - tbx11k_dataset["train"] - ) - - assert "validation" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["validation"]) == len( - mc_dataset["validation"] - ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( - tbx11k_dataset["validation"] - ) - - assert "test" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["test"]) == len( - mc_dataset["test"] - ) + len(ch_dataset["test"]) + len(in_dataset["test"]) + len( - tbx11k_dataset["test"] - ) - - # Fold 6 - mc_ch_in_11k_RS_dataset = mc_ch_in_11k_f6.dataset - assert isinstance(mc_ch_in_11k_RS_dataset, dict) - - mc_dataset = mc_f6.dataset - ch_dataset = ch_f6.dataset - in_dataset = indian_f6.dataset - tbx11k_dataset = tbx11k_f6.dataset - - assert "train" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["train"]) == len( - mc_dataset["train"] - ) + len(ch_dataset["train"]) + len(in_dataset["train"]) + len( - tbx11k_dataset["train"] - ) - - assert "validation" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["validation"]) == len( - mc_dataset["validation"] - ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( - tbx11k_dataset["validation"] - ) - - assert "test" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["test"]) == len( - mc_dataset["test"] - ) + len(ch_dataset["test"]) + len(in_dataset["test"]) + len( - tbx11k_dataset["test"] - ) - - # Fold 7 - mc_ch_in_11k_RS_dataset = mc_ch_in_11k_f7.dataset - assert isinstance(mc_ch_in_11k_RS_dataset, dict) - - mc_dataset = mc_f7.dataset - ch_dataset = ch_f7.dataset - in_dataset = indian_f7.dataset - tbx11k_dataset = tbx11k_f7.dataset - - assert "train" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["train"]) == len( - mc_dataset["train"] - ) + len(ch_dataset["train"]) + len(in_dataset["train"]) + len( - tbx11k_dataset["train"] - ) - - assert "validation" in 
mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["validation"]) == len( - mc_dataset["validation"] - ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( - tbx11k_dataset["validation"] - ) - - assert "test" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["test"]) == len( - mc_dataset["test"] - ) + len(ch_dataset["test"]) + len(in_dataset["test"]) + len( - tbx11k_dataset["test"] - ) - - # Fold 8 - mc_ch_in_11k_RS_dataset = mc_ch_in_11k_f8.dataset - assert isinstance(mc_ch_in_11k_RS_dataset, dict) - - mc_dataset = mc_f8.dataset - ch_dataset = ch_f8.dataset - in_dataset = indian_f8.dataset - tbx11k_dataset = tbx11k_f8.dataset - - assert "train" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["train"]) == len( - mc_dataset["train"] - ) + len(ch_dataset["train"]) + len(in_dataset["train"]) + len( - tbx11k_dataset["train"] - ) - - assert "validation" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["validation"]) == len( - mc_dataset["validation"] - ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( - tbx11k_dataset["validation"] - ) - - assert "test" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["test"]) == len( - mc_dataset["test"] - ) + len(ch_dataset["test"]) + len(in_dataset["test"]) + len( - tbx11k_dataset["test"] - ) - - # Fold 9 - mc_ch_in_11k_RS_dataset = mc_ch_in_11k_f9.dataset - assert isinstance(mc_ch_in_11k_RS_dataset, dict) - - mc_dataset = mc_f9.dataset - ch_dataset = ch_f9.dataset - in_dataset = indian_f9.dataset - tbx11k_dataset = tbx11k_f9.dataset - - assert "train" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["train"]) == len( - mc_dataset["train"] - ) + len(ch_dataset["train"]) + len(in_dataset["train"]) + len( - tbx11k_dataset["train"] - ) - - assert "validation" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["validation"]) == len( - mc_dataset["validation"] - ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) + len( - tbx11k_dataset["validation"] - ) - - assert "test" in mc_ch_in_11k_RS_dataset - assert len(mc_ch_in_11k_RS_dataset["test"]) == len( - mc_dataset["test"] - ) + len(ch_dataset["test"]) + len(in_dataset["test"]) + len( - tbx11k_dataset["test"] - ) diff --git a/tests/test_mc_ch_in_RS.py b/tests/test_mc_ch_in_RS.py deleted file mode 100644 index 1600f60ba45e0f503f34224b144d0ed84259d361..0000000000000000000000000000000000000000 --- a/tests/test_mc_ch_in_RS.py +++ /dev/null @@ -1,306 +0,0 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# -# SPDX-License-Identifier: GPL-3.0-or-later -"""Tests for the aggregated Montgomery-Shenzhen-Indian dataset.""" - -import pytest - - -@pytest.mark.skip(reason="Test need to be updated") -def test_dataset_consistency(): - from ptbench.configs.datasets.indian_RS import default as indian_RS - from ptbench.configs.datasets.indian_RS import fold_0 as indian_f0 - from ptbench.configs.datasets.indian_RS import fold_1 as indian_f1 - from ptbench.configs.datasets.indian_RS import fold_2 as indian_f2 - from ptbench.configs.datasets.indian_RS import fold_3 as indian_f3 - from ptbench.configs.datasets.indian_RS import fold_4 as indian_f4 - from ptbench.configs.datasets.indian_RS import fold_5 as indian_f5 - from ptbench.configs.datasets.indian_RS import fold_6 as indian_f6 - from ptbench.configs.datasets.indian_RS import fold_7 as indian_f7 - from ptbench.configs.datasets.indian_RS import fold_8 as indian_f8 - from 
ptbench.configs.datasets.indian_RS import fold_9 as indian_f9 - from ptbench.configs.datasets.mc_ch_in_RS import default as mc_ch_in_RS - from ptbench.configs.datasets.mc_ch_in_RS import fold_0 as mc_ch_in_f0 - from ptbench.configs.datasets.mc_ch_in_RS import fold_1 as mc_ch_in_f1 - from ptbench.configs.datasets.mc_ch_in_RS import fold_2 as mc_ch_in_f2 - from ptbench.configs.datasets.mc_ch_in_RS import fold_3 as mc_ch_in_f3 - from ptbench.configs.datasets.mc_ch_in_RS import fold_4 as mc_ch_in_f4 - from ptbench.configs.datasets.mc_ch_in_RS import fold_5 as mc_ch_in_f5 - from ptbench.configs.datasets.mc_ch_in_RS import fold_6 as mc_ch_in_f6 - from ptbench.configs.datasets.mc_ch_in_RS import fold_7 as mc_ch_in_f7 - from ptbench.configs.datasets.mc_ch_in_RS import fold_8 as mc_ch_in_f8 - from ptbench.configs.datasets.mc_ch_in_RS import fold_9 as mc_ch_in_f9 - from ptbench.configs.datasets.montgomery_RS import default as mc_RS - from ptbench.configs.datasets.montgomery_RS import fold_0 as mc_f0 - from ptbench.configs.datasets.montgomery_RS import fold_1 as mc_f1 - from ptbench.configs.datasets.montgomery_RS import fold_2 as mc_f2 - from ptbench.configs.datasets.montgomery_RS import fold_3 as mc_f3 - from ptbench.configs.datasets.montgomery_RS import fold_4 as mc_f4 - from ptbench.configs.datasets.montgomery_RS import fold_5 as mc_f5 - from ptbench.configs.datasets.montgomery_RS import fold_6 as mc_f6 - from ptbench.configs.datasets.montgomery_RS import fold_7 as mc_f7 - from ptbench.configs.datasets.montgomery_RS import fold_8 as mc_f8 - from ptbench.configs.datasets.montgomery_RS import fold_9 as mc_f9 - from ptbench.configs.datasets.shenzhen_RS import default as ch_RS - from ptbench.configs.datasets.shenzhen_RS import fold_0 as ch_f0 - from ptbench.configs.datasets.shenzhen_RS import fold_1 as ch_f1 - from ptbench.configs.datasets.shenzhen_RS import fold_2 as ch_f2 - from ptbench.configs.datasets.shenzhen_RS import fold_3 as ch_f3 - from ptbench.configs.datasets.shenzhen_RS import fold_4 as ch_f4 - from ptbench.configs.datasets.shenzhen_RS import fold_5 as ch_f5 - from ptbench.configs.datasets.shenzhen_RS import fold_6 as ch_f6 - from ptbench.configs.datasets.shenzhen_RS import fold_7 as ch_f7 - from ptbench.configs.datasets.shenzhen_RS import fold_8 as ch_f8 - from ptbench.configs.datasets.shenzhen_RS import fold_9 as ch_f9 - - # Default protocol - mc_ch_in_RS_dataset = mc_ch_in_RS.dataset - assert isinstance(mc_ch_in_RS_dataset, dict) - - mc_RS_dataset = mc_RS.dataset - ch_RS_dataset = ch_RS.dataset - in_RS_dataset = indian_RS.dataset - - assert "train" in mc_ch_in_RS_dataset - assert len(mc_ch_in_RS_dataset["train"]) == len( - mc_RS_dataset["train"] - ) + len(ch_RS_dataset["train"]) + len(in_RS_dataset["train"]) - - assert "validation" in mc_ch_in_RS_dataset - assert len(mc_ch_in_RS_dataset["validation"]) == len( - mc_RS_dataset["validation"] - ) + len(ch_RS_dataset["validation"]) + len(in_RS_dataset["validation"]) - - assert "test" in mc_ch_in_RS_dataset - assert len(mc_ch_in_RS_dataset["test"]) == len(mc_RS_dataset["test"]) + len( - ch_RS_dataset["test"] - ) + len(in_RS_dataset["test"]) - - # Fold 0 - mc_ch_in_dataset = mc_ch_in_f0.dataset - assert isinstance(mc_ch_in_dataset, dict) - - mc_dataset = mc_f0.dataset - ch_dataset = ch_f0.dataset - in_dataset = indian_f0.dataset - - assert "train" in mc_ch_in_dataset - assert len(mc_ch_in_dataset["train"]) == len(mc_dataset["train"]) + len( - ch_dataset["train"] - ) + len(in_dataset["train"]) - assert "validation" in mc_ch_in_dataset - assert 
len(mc_ch_in_dataset["validation"]) == len( - mc_dataset["validation"] - ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) - - assert "test" in mc_ch_in_dataset - assert len(mc_ch_in_dataset["test"]) == len(mc_dataset["test"]) + len( - ch_dataset["test"] - ) + len(in_dataset["test"]) - - # Fold 1 - mc_ch_in_dataset = mc_ch_in_f1.dataset - assert isinstance(mc_ch_in_dataset, dict) - - mc_dataset = mc_f1.dataset - ch_dataset = ch_f1.dataset - in_dataset = indian_f1.dataset - - assert "train" in mc_ch_in_dataset - assert len(mc_ch_in_dataset["train"]) == len(mc_dataset["train"]) + len( - ch_dataset["train"] - ) + len(in_dataset["train"]) - - assert "validation" in mc_ch_in_dataset - assert len(mc_ch_in_dataset["validation"]) == len( - mc_dataset["validation"] - ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) - - assert "test" in mc_ch_in_dataset - assert len(mc_ch_in_dataset["test"]) == len(mc_dataset["test"]) + len( - ch_dataset["test"] - ) + len(in_dataset["test"]) - - # Fold 2 - mc_ch_in_dataset = mc_ch_in_f2.dataset - assert isinstance(mc_ch_in_dataset, dict) - - mc_dataset = mc_f2.dataset - ch_dataset = ch_f2.dataset - in_dataset = indian_f2.dataset - - assert "train" in mc_ch_in_dataset - assert len(mc_ch_in_dataset["train"]) == len(mc_dataset["train"]) + len( - ch_dataset["train"] - ) + len(in_dataset["train"]) - - assert "validation" in mc_ch_in_dataset - assert len(mc_ch_in_dataset["validation"]) == len( - mc_dataset["validation"] - ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) - - assert "test" in mc_ch_in_dataset - assert len(mc_ch_in_dataset["test"]) == len(mc_dataset["test"]) + len( - ch_dataset["test"] - ) + len(in_dataset["test"]) - - # Fold 3 - mc_ch_in_dataset = mc_ch_in_f3.dataset - assert isinstance(mc_ch_in_dataset, dict) - - mc_dataset = mc_f3.dataset - ch_dataset = ch_f3.dataset - in_dataset = indian_f3.dataset - - assert "train" in mc_ch_in_dataset - assert len(mc_ch_in_dataset["train"]) == len(mc_dataset["train"]) + len( - ch_dataset["train"] - ) + len(in_dataset["train"]) - - assert "validation" in mc_ch_in_dataset - assert len(mc_ch_in_dataset["validation"]) == len( - mc_dataset["validation"] - ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) - - assert "test" in mc_ch_in_dataset - assert len(mc_ch_in_dataset["test"]) == len(mc_dataset["test"]) + len( - ch_dataset["test"] - ) + len(in_dataset["test"]) - - # Fold 4 - mc_ch_in_dataset = mc_ch_in_f4.dataset - assert isinstance(mc_ch_in_dataset, dict) - - mc_dataset = mc_f4.dataset - ch_dataset = ch_f4.dataset - in_dataset = indian_f4.dataset - - assert "train" in mc_ch_in_dataset - assert len(mc_ch_in_dataset["train"]) == len(mc_dataset["train"]) + len( - ch_dataset["train"] - ) + len(in_dataset["train"]) - - assert "validation" in mc_ch_in_dataset - assert len(mc_ch_in_dataset["validation"]) == len( - mc_dataset["validation"] - ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) - - assert "test" in mc_ch_in_dataset - assert len(mc_ch_in_dataset["test"]) == len(mc_dataset["test"]) + len( - ch_dataset["test"] - ) + len(in_dataset["test"]) - - # Fold 5 - mc_ch_in_dataset = mc_ch_in_f5.dataset - assert isinstance(mc_ch_in_dataset, dict) - - mc_dataset = mc_f5.dataset - ch_dataset = ch_f5.dataset - in_dataset = indian_f5.dataset - - assert "train" in mc_ch_in_dataset - assert len(mc_ch_in_dataset["train"]) == len(mc_dataset["train"]) + len( - ch_dataset["train"] - ) + len(in_dataset["train"]) - - assert "validation" in mc_ch_in_dataset - 
assert len(mc_ch_in_dataset["validation"]) == len( - mc_dataset["validation"] - ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) - - assert "test" in mc_ch_in_dataset - assert len(mc_ch_in_dataset["test"]) == len(mc_dataset["test"]) + len( - ch_dataset["test"] - ) + len(in_dataset["test"]) - - # Fold 6 - mc_ch_in_dataset = mc_ch_in_f6.dataset - assert isinstance(mc_ch_in_dataset, dict) - - mc_dataset = mc_f6.dataset - ch_dataset = ch_f6.dataset - in_dataset = indian_f6.dataset - - assert "train" in mc_ch_in_dataset - assert len(mc_ch_in_dataset["train"]) == len(mc_dataset["train"]) + len( - ch_dataset["train"] - ) + len(in_dataset["train"]) - - assert "validation" in mc_ch_in_dataset - assert len(mc_ch_in_dataset["validation"]) == len( - mc_dataset["validation"] - ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) - - assert "test" in mc_ch_in_dataset - assert len(mc_ch_in_dataset["test"]) == len(mc_dataset["test"]) + len( - ch_dataset["test"] - ) + len(in_dataset["test"]) - - # Fold 7 - mc_ch_in_dataset = mc_ch_in_f7.dataset - assert isinstance(mc_ch_in_dataset, dict) - - mc_dataset = mc_f7.dataset - ch_dataset = ch_f7.dataset - in_dataset = indian_f7.dataset - - assert "train" in mc_ch_in_dataset - assert len(mc_ch_in_dataset["train"]) == len(mc_dataset["train"]) + len( - ch_dataset["train"] - ) + len(in_dataset["train"]) - - assert "validation" in mc_ch_in_dataset - assert len(mc_ch_in_dataset["validation"]) == len( - mc_dataset["validation"] - ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) - - assert "test" in mc_ch_in_dataset - assert len(mc_ch_in_dataset["test"]) == len(mc_dataset["test"]) + len( - ch_dataset["test"] - ) + len(in_dataset["test"]) - - # Fold 8 - mc_ch_in_dataset = mc_ch_in_f8.dataset - assert isinstance(mc_ch_in_dataset, dict) - - mc_dataset = mc_f8.dataset - ch_dataset = ch_f8.dataset - in_dataset = indian_f8.dataset - - assert "train" in mc_ch_in_dataset - assert len(mc_ch_in_dataset["train"]) == len(mc_dataset["train"]) + len( - ch_dataset["train"] - ) + len(in_dataset["train"]) - - assert "validation" in mc_ch_in_dataset - assert len(mc_ch_in_dataset["validation"]) == len( - mc_dataset["validation"] - ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) - - assert "test" in mc_ch_in_dataset - assert len(mc_ch_in_dataset["test"]) == len(mc_dataset["test"]) + len( - ch_dataset["test"] - ) + len(in_dataset["test"]) - - # Fold 9 - mc_ch_in_dataset = mc_ch_in_f9.dataset - assert isinstance(mc_ch_in_dataset, dict) - - mc_dataset = mc_f9.dataset - ch_dataset = ch_f9.dataset - in_dataset = indian_f9.dataset - - assert "train" in mc_ch_in_dataset - assert len(mc_ch_in_dataset["train"]) == len(mc_dataset["train"]) + len( - ch_dataset["train"] - ) + len(in_dataset["train"]) - - assert "validation" in mc_ch_in_dataset - assert len(mc_ch_in_dataset["validation"]) == len( - mc_dataset["validation"] - ) + len(ch_dataset["validation"]) + len(in_dataset["validation"]) - - assert "test" in mc_ch_in_dataset - assert len(mc_ch_in_dataset["test"]) == len(mc_dataset["test"]) + len( - ch_dataset["test"] - ) + len(in_dataset["test"]) diff --git a/tests/test_mc_ch_in_pc_RS.py b/tests/test_mc_ch_in_pc_RS.py deleted file mode 100644 index 1014fb1df1c3d606ea1abe69693e7244c06d667f..0000000000000000000000000000000000000000 --- a/tests/test_mc_ch_in_pc_RS.py +++ /dev/null @@ -1,45 +0,0 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# -# SPDX-License-Identifier: GPL-3.0-or-later -"""Tests 
for the aggregated Montgomery-Shenzhen-Indian-Padchest(TB) dataset.""" - -import pytest - - -@pytest.mark.skip(reason="Test need to be updated") -def test_dataset_consistency(): - from ptbench.configs.datasets.indian_RS import default as in_RS - from ptbench.configs.datasets.mc_ch_in_pc_RS import default as mc_ch_in_pc - from ptbench.configs.datasets.montgomery_RS import default as mc_RS - from ptbench.configs.datasets.padchest_RS import tb_idiap as pc_RS - from ptbench.configs.datasets.shenzhen_RS import default as ch_RS - - # Default protocol - mc_ch_in_pc_dataset = mc_ch_in_pc.dataset - assert isinstance(mc_ch_in_pc_dataset, dict) - - mc_RS_dataset = mc_RS.dataset - ch_RS_dataset = ch_RS.dataset - in_RS_dataset = in_RS.dataset - pc_RS_dataset = pc_RS.dataset - - assert "train" in mc_ch_in_pc_dataset - assert len(mc_ch_in_pc_dataset["train"]) == len( - mc_RS_dataset["train"] - ) + len(ch_RS_dataset["train"]) + len(in_RS_dataset["train"]) + len( - pc_RS_dataset["train"] - ) - - assert "validation" in mc_ch_in_pc_dataset - assert len(mc_ch_in_pc_dataset["validation"]) == len( - mc_RS_dataset["validation"] - ) + len(ch_RS_dataset["validation"]) + len( - in_RS_dataset["validation"] - ) + len( - pc_RS_dataset["validation"] - ) - - assert "test" in mc_ch_in_pc_dataset - assert len(mc_ch_in_pc_dataset["test"]) == len(mc_RS_dataset["test"]) + len( - ch_RS_dataset["test"] - ) + len(in_RS_dataset["test"]) + len(pc_RS_dataset["test"]) diff --git a/tests/test_montgomery.py b/tests/test_montgomery.py index 718889d2ff2703d0c175b2f9fe5149a9da8ac71b..973350aa1d8f166f4fbd943d81cffe4de7a25fe9 100644 --- a/tests/test_montgomery.py +++ b/tests/test_montgomery.py @@ -6,117 +6,43 @@ import importlib import pytest -import torch -from ptbench.data.montgomery.datamodule import make_split +def id_function(val): + if isinstance(val, dict): + return str(val) + return repr(val) -def _check_split( - split_filename: str, - lengths: dict[str, int], - prefix: str = "CXR_png/MCUCXR_0", - possible_labels: list[int] = [0, 1], -): - """Runs a simple consistence check on the data split. - - Parameters - ---------- - - split_filename - This is the split we will check - - lenghts - A dictionary that contains keys matching those of the split (this will - be checked). The values of the dictionary should correspond to the - sizes of each of the datasets in the split. - - prefix - Each file named in a split should start with this prefix. - - possible_labels - These are the list of possible labels contained in any split. 
- """ - - split = make_split(split_filename) - - assert len(split) == len(lengths) - for k in lengths.keys(): - # dataset must have been declared - assert k in split - - assert len(split[k]) == lengths[k] - for s in split[k]: - assert s[0].startswith(prefix) - assert s[1] in possible_labels - - -def _check_loaded_batch( - batch, - size: int = 1, - prefix: str = "CXR_png/MCUCXR_0", - possible_labels: list[int] = [0, 1], +@pytest.mark.parametrize( + "split,lenghts", + [ + ("default", dict(train=88, validation=22, test=28)), + ("fold-0", dict(train=99, validation=25, test=14)), + ("fold-1", dict(train=99, validation=25, test=14)), + ("fold-2", dict(train=99, validation=25, test=14)), + ("fold-3", dict(train=99, validation=25, test=14)), + ("fold-4", dict(train=99, validation=25, test=14)), + ("fold-5", dict(train=99, validation=25, test=14)), + ("fold-6", dict(train=99, validation=25, test=14)), + ("fold-7", dict(train=99, validation=25, test=14)), + ("fold-8", dict(train=100, validation=25, test=13)), + ("fold-9", dict(train=100, validation=25, test=13)), + ], + ids=id_function, # just changes how pytest prints it +) +def test_protocol_consistency( + database_checkers, split: str, lenghts: dict[str, int] ): - """Checks the consistence of an individual (loaded) batch. - - Parameters - ---------- + from ptbench.data.montgomery.datamodule import make_split - batch - The loaded batch to be checked. - - size - The mini-batch size - - prefix - Each file named in a split should start with this prefix. - - possible_labels - These are the list of possible labels contained in any split. - """ - - assert len(batch) == 2 # data, metadata - - assert isinstance(batch[0], torch.Tensor) - assert batch[0].shape[0] == size # mini-batch size - assert batch[0].shape[1] == 1 # grayscale images - assert batch[0].shape[2] == batch[0].shape[3] # image is square - - assert isinstance(batch[1], dict) # metadata - assert len(batch[1]) == 2 # label and name - - assert "label" in batch[1] - assert all([k in possible_labels for k in batch[1]["label"]]) - - assert "name" in batch[1] - assert all([k.startswith(prefix) for k in batch[1]["name"]]) - - # use the code below to view generated images - # from torchvision.transforms.functional import to_pil_image - # to_pil_image(batch[0][0]).show() - # __import__("pdb").set_trace() - - -def test_protocol_consistency(): - _check_split( - "default.json", - lengths=dict(train=88, validation=22, test=28), + database_checkers.check_split( + make_split(f"{split}.json"), + lengths=lenghts, + prefixes=("CXR_png/MCUCXR_0",), + possible_labels=(0, 1), ) - # Cross-validation fold 0-7 - for k in range(8): - _check_split( - f"fold-{k}.json", - lengths=dict(train=99, validation=25, test=14), - ) - - # Cross-validation fold 8-9 - for k in range(8, 10): - _check_split( - f"fold-{k}.json", - lengths=dict(train=100, validation=25, test=13), - ) - @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") @pytest.mark.parametrize( @@ -143,7 +69,7 @@ def test_protocol_consistency(): "fold_9", ], ) -def test_loading(name: str, dataset: str): +def test_loading(database_checkers, name: str, dataset: str): datamodule = importlib.import_module( f".{name}", "ptbench.data.montgomery" ).datamodule @@ -157,5 +83,11 @@ def test_loading(name: str, dataset: str): for batch in loader: if limit == 0: break - _check_loaded_batch(batch) + database_checkers.check_loaded_batch( + batch, + batch_size=1, + color_planes=1, + prefixes=("CXR_png/MCUCXR_0",), + possible_labels=(0, 1), + ) limit -= 1 diff --git 
a/tests/test_pc_RS.py b/tests/test_pc_RS.py deleted file mode 100644 index c62ef78001634cd04ecc93eaf0c7dc4fb0d22ffd..0000000000000000000000000000000000000000 --- a/tests/test_pc_RS.py +++ /dev/null @@ -1,57 +0,0 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# -# SPDX-License-Identifier: GPL-3.0-or-later -"""Tests for Extended Padchest dataset.""" - -import pytest - - -@pytest.mark.skip(reason="Test need to be updated") -def test_protocol_consistency(): - from ptbench.data.padchest_RS import dataset - - # tb_idiap protocol - subset = dataset.subsets("tb_idiap") - assert len(subset) == 3 - - assert "train" in subset - assert len(subset["train"]) == 160 - - assert "validation" in subset - assert len(subset["validation"]) == 40 - - assert "test" in subset - assert len(subset["test"]) == 50 - - # Check labels - for s in subset["train"]: - assert s.label in [0.0, 1.0] - - for s in subset["validation"]: - assert s.label in [0.0, 1.0] - - for s in subset["test"]: - assert s.label in [0.0, 1.0] - - -@pytest.mark.skip(reason="Test need to be updated") -def test_loading(): - from ptbench.data.padchest_RS import dataset - - def _check_sample(s): - data = s.data - - assert isinstance(data, dict) - assert len(data) == 2 - - assert "data" in data - assert len(data["data"]) == 14 # Check radiological signs - - assert "label" in data - assert data["label"] in [0.0, 1.0] # Check labels - - limit = 30 # use this to limit testing to first images only, else None - - subset = dataset.subsets("tb_idiap") - for s in subset["train"][:limit]: - _check_sample(s) diff --git a/tests/test_shenzhen.py b/tests/test_shenzhen.py index 30c69543a77d24715f2a815192d2e0abac41b9c5..12532cdee7e799cbe9cf39173e86bc02952f7f0b 100644 --- a/tests/test_shenzhen.py +++ b/tests/test_shenzhen.py @@ -3,122 +3,91 @@ # SPDX-License-Identifier: GPL-3.0-or-later """Tests for Shenzhen dataset.""" -import pytest -import torch - -from ptbench.data.shenzhen.datamodule import make_split - - -def _check_split( - split_filename: str, - lengths: dict[str, int], - prefix: str = "CXR_png/CHNCXR_0", - possible_labels: list[int] = [0, 1], -): - """Runs a simple consistence check on the data split. - - Parameters - ---------- - - split_filename - This is the split we will check - - lenghts - A dictionary that contains keys matching those of the split (this will - be checked). The values of the dictionary should correspond to the - sizes of each of the datasets in the split. +import importlib - prefix - Each file named in a split should start with this prefix. - - possible_labels - These are the list of possible labels contained in any split. 
- """ - - split = make_split(split_filename) - - assert len(split) == len(lengths) - - for k in lengths.keys(): - # dataset must have been declared - assert k in split - - assert len(split[k]) == lengths[k] - for s in split[k]: - assert s[0].startswith(prefix) - assert s[1] in possible_labels +import pytest -def _check_loaded_batch( - batch, - size: int = 1, - prefix: str = "CXR_png/CHNCXR_0", - possible_labels: list[int] = [0, 1], +def id_function(val): + if isinstance(val, dict): + return str(val) + return repr(val) + + +@pytest.mark.parametrize( + "split,lenghts", + [ + ("default", dict(train=422, validation=107, test=133)), + ("fold-0", dict(train=476, validation=119, test=67)), + ("fold-1", dict(train=476, validation=119, test=67)), + ("fold-2", dict(train=476, validation=120, test=66)), + ("fold-3", dict(train=476, validation=120, test=66)), + ("fold-4", dict(train=476, validation=120, test=66)), + ("fold-5", dict(train=476, validation=120, test=66)), + ("fold-6", dict(train=476, validation=120, test=66)), + ("fold-7", dict(train=476, validation=120, test=66)), + ("fold-8", dict(train=476, validation=120, test=66)), + ("fold-9", dict(train=476, validation=120, test=66)), + ], + ids=id_function, # just changes how pytest prints it +) +def test_protocol_consistency( + database_checkers, split: str, lenghts: dict[str, int] ): - """Checks the consistence of an individual (loaded) batch. - - Parameters - ---------- + from ptbench.data.shenzhen.datamodule import make_split - batch - The loaded batch to be checked. - - prefix - Each file named in a split should start with this prefix. - - possible_labels - These are the list of possible labels contained in any split. - """ - - assert len(batch) == 2 # data, metadata - - assert isinstance(batch[0], torch.Tensor) - assert batch[0].shape[0] == size # mini-batch size - assert batch[0].shape[1] == 1 # grayscale images - assert batch[0].shape[2] == batch[0].shape[3] # image is square - - assert isinstance(batch[1], dict) # metadata - assert len(batch[1]) == 2 # label and name - - assert "label" in batch[1] - assert all([k in possible_labels for k in batch[1]["label"]]) - - assert "name" in batch[1] - assert all([k.startswith(prefix) for k in batch[1]["name"]]) - - -def test_protocol_consistency(): - _check_split( - "default.json", - lengths=dict(train=422, validation=107, test=133), + database_checkers.check_split( + make_split(f"{split}.json"), + lengths=lenghts, + prefixes=("CXR_png/CHNCXR_0",), + possible_labels=(0, 1), ) - # Cross-validation fold 0-1 - for k in range(2): - _check_split( - f"fold-{k}.json", - lengths=dict(train=476, validation=119, test=67), - ) - - # Cross-validation fold 2-9 - for k in range(2, 10): - _check_split( - f"fold-{k}.json", - lengths=dict(train=476, validation=120, test=66), - ) - @pytest.mark.skip_if_rc_var_not_set("datadir.shenzhen") -def test_loading(): - from ptbench.data.shenzhen.default import datamodule +@pytest.mark.parametrize( + "dataset", + [ + "train", + "validation", + "test", + ], +) +@pytest.mark.parametrize( + "name", + [ + "default", + "fold_0", + "fold_1", + "fold_2", + "fold_3", + "fold_4", + "fold_5", + "fold_6", + "fold_7", + "fold_8", + "fold_9", + ], +) +def test_loading(database_checkers, name: str, dataset: str): + datamodule = importlib.import_module( + f".{name}", "ptbench.data.shenzhen" + ).datamodule datamodule.model_transforms = [] # should be done before setup() datamodule.setup("predict") # sets up all datasets - for loader in datamodule.predict_dataloader().values(): - limit = 5 # 
limit load checking - for batch in loader: - if limit == 0: - break - _check_loaded_batch(batch) - limit -= 1 + loader = datamodule.predict_dataloader()[dataset] + + limit = 3 # limit load checking + for batch in loader: + if limit == 0: + break + database_checkers.check_loaded_batch( + batch, + batch_size=1, + color_planes=1, + prefixes=("CXR_png/CHNCXR_0",), + possible_labels=(0, 1), + ) + limit -= 1 diff --git a/tests/test_tbpoc_RS.py b/tests/test_tbpoc_RS.py deleted file mode 100644 index 72f3e5968ba755319c9c36b3b206e342639b7152..0000000000000000000000000000000000000000 --- a/tests/test_tbpoc_RS.py +++ /dev/null @@ -1,92 +0,0 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# -# SPDX-License-Identifier: GPL-3.0-or-later -"""Tests for TB-POC_RS dataset.""" - -import pytest - -dataset = None - - -@pytest.mark.skip(reason="Test need to be updated") -def test_protocol_consistency(): - # Cross-validation fold 0-6 - for f in range(7): - subset = dataset.subsets("fold_" + str(f)) - assert len(subset) == 3 - - assert "train" in subset - assert len(subset["train"]) == 292 - for s in subset["train"]: - assert s.key.upper().startswith("TBPOC_CXR/TBPOC-") - - assert "validation" in subset - assert len(subset["validation"]) == 74 - for s in subset["validation"]: - assert s.key.upper().startswith("TBPOC_CXR/TBPOC-") - - assert "test" in subset - assert len(subset["test"]) == 41 - for s in subset["test"]: - assert s.key.upper().startswith("TBPOC_CXR/TBPOC-") - - # Check labels - for s in subset["train"]: - assert s.label in [0.0, 1.0] - - for s in subset["validation"]: - assert s.label in [0.0, 1.0] - - for s in subset["test"]: - assert s.label in [0.0, 1.0] - - # Cross-validation fold 7-9 - for f in range(7, 10): - subset = dataset.subsets("fold_" + str(f)) - assert len(subset) == 3 - - assert "train" in subset - assert len(subset["train"]) == 293 - for s in subset["train"]: - assert s.key.upper().startswith("TBPOC_CXR/TBPOC-") - - assert "validation" in subset - assert len(subset["validation"]) == 74 - for s in subset["validation"]: - assert s.key.upper().startswith("TBPOC_CXR/TBPOC-") - - assert "test" in subset - assert len(subset["test"]) == 40 - for s in subset["test"]: - assert s.key.upper().startswith("TBPOC_CXR/TBPOC-") - - # Check labels - for s in subset["train"]: - assert s.label in [0.0, 1.0] - - for s in subset["validation"]: - assert s.label in [0.0, 1.0] - - for s in subset["test"]: - assert s.label in [0.0, 1.0] - - -@pytest.mark.skip(reason="Test need to be updated") -def test_loading(): - def _check_sample(s): - data = s.data - - assert isinstance(data, dict) - assert len(data) == 2 - - assert "data" in data - assert len(data["data"]) == 14 # Check radiological signs - - assert "label" in data - assert data["label"] in [0, 1] # Check labels - - limit = 30 # use this to limit testing to first images only, else None - - subset = dataset.subsets("fold_0") - for s in subset["train"][:limit]: - _check_sample(s) diff --git a/tests/test_tbx11k.py b/tests/test_tbx11k.py new file mode 100644 index 0000000000000000000000000000000000000000..39c54174b6b8a21ca95cfd24adae217c569be78a --- /dev/null +++ b/tests/test_tbx11k.py @@ -0,0 +1,215 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later +"""Tests for TBX11K dataset.""" + +import importlib +import typing + +import pytest + + +def id_function(val): + if isinstance(val, (dict, tuple)): + return repr(val) + return 
repr(val)
+
+
+@pytest.mark.parametrize(
+    "split,lengths,prefixes",
+    [
+        (
+            "v1-healthy-vs-atb",
+            dict(train=2767, validation=706, test=957),
+            ("imgs/health", "imgs/tb"),
+        ),
+        (
+            "v1-fold-0",
+            dict(train=3177, validation=810, test=443),
+            ("imgs/health", "imgs/tb"),
+        ),
+        (
+            "v1-fold-1",
+            dict(train=3177, validation=810, test=443),
+            ("imgs/health", "imgs/tb"),
+        ),
+        (
+            "v1-fold-2",
+            dict(train=3177, validation=810, test=443),
+            ("imgs/health", "imgs/tb"),
+        ),
+        (
+            "v1-fold-3",
+            dict(train=3177, validation=810, test=443),
+            ("imgs/health", "imgs/tb"),
+        ),
+        (
+            "v1-fold-4",
+            dict(train=3177, validation=810, test=443),
+            ("imgs/health", "imgs/tb"),
+        ),
+        (
+            "v1-fold-5",
+            dict(train=3177, validation=810, test=443),
+            ("imgs/health", "imgs/tb"),
+        ),
+        (
+            "v1-fold-6",
+            dict(train=3177, validation=810, test=443),
+            ("imgs/health", "imgs/tb"),
+        ),
+        (
+            "v1-fold-7",
+            dict(train=3177, validation=810, test=443),
+            ("imgs/health", "imgs/tb"),
+        ),
+        (
+            "v1-fold-8",
+            dict(train=3177, validation=810, test=443),
+            ("imgs/health", "imgs/tb"),
+        ),
+        (
+            "v1-fold-9",
+            dict(train=3177, validation=810, test=443),
+            ("imgs/health", "imgs/tb"),
+        ),
+        (
+            "v2-others-vs-atb",
+            dict(train=5241, validation=1335, test=1793),
+            ("imgs/health", "imgs/sick", "imgs/tb"),
+        ),
+        (
+            "v2-fold-0",
+            dict(train=6003, validation=1529, test=837),
+            ("imgs/health", "imgs/sick", "imgs/tb"),
+        ),
+        (
+            "v2-fold-1",
+            dict(train=6003, validation=1529, test=837),
+            ("imgs/health", "imgs/sick", "imgs/tb"),
+        ),
+        (
+            "v2-fold-2",
+            dict(train=6003, validation=1529, test=837),
+            ("imgs/health", "imgs/sick", "imgs/tb"),
+        ),
+        (
+            "v2-fold-3",
+            dict(train=6003, validation=1529, test=837),
+            ("imgs/health", "imgs/sick", "imgs/tb"),
+        ),
+        (
+            "v2-fold-4",
+            dict(train=6003, validation=1529, test=837),
+            ("imgs/health", "imgs/sick", "imgs/tb"),
+        ),
+        (
+            "v2-fold-5",
+            dict(train=6003, validation=1529, test=837),
+            ("imgs/health", "imgs/sick", "imgs/tb"),
+        ),
+        (
+            "v2-fold-6",
+            dict(train=6003, validation=1529, test=837),
+            ("imgs/health", "imgs/sick", "imgs/tb"),
+        ),
+        (
+            "v2-fold-7",
+            dict(train=6003, validation=1529, test=837),
+            ("imgs/health", "imgs/sick", "imgs/tb"),
+        ),
+        (
+            "v2-fold-8",
+            dict(train=6003, validation=1529, test=837),
+            ("imgs/health", "imgs/sick", "imgs/tb"),
+        ),
+        (
+            "v2-fold-9",
+            dict(train=6003, validation=1530, test=836),
+            ("imgs/health", "imgs/sick", "imgs/tb"),
+        ),
+    ],
+    ids=id_function,  # just changes how pytest prints it
+)
+def test_protocol_consistency(
+    database_checkers,
+    split: str,
+    lengths: dict[str, int],
+    prefixes: typing.Sequence[str],
+):
+    from ptbench.data.tbx11k.datamodule import make_split
+
+    database_checkers.check_split(
+        make_split(f"{split}.json"),
+        lengths=lengths,
+        prefixes=prefixes,
+        possible_labels=(0, 1),
+    )
+
+
+@pytest.mark.skip_if_rc_var_not_set("datadir.tbx11k")
+@pytest.mark.parametrize(
+    "dataset",
+    [
+        "train",
+        "validation",
+        "test",
+    ],
+)
+@pytest.mark.parametrize(
+    "name,prefixes",
+    [
+        ("v1_healthy_vs_atb", ("imgs/health", "imgs/tb")),
+        ("v1_fold_0", ("imgs/health", "imgs/tb")),
+        ("v1_fold_1", ("imgs/health", "imgs/tb")),
+        ("v1_fold_2", ("imgs/health", "imgs/tb")),
+        ("v1_fold_3", ("imgs/health", "imgs/tb")),
+        ("v1_fold_4", ("imgs/health", "imgs/tb")),
+        ("v1_fold_5", ("imgs/health", "imgs/tb")),
+        ("v1_fold_6", ("imgs/health", "imgs/tb")),
+        ("v1_fold_7", ("imgs/health", "imgs/tb")),
+        ("v1_fold_8", ("imgs/health", "imgs/tb")),
+        ("v1_fold_9", ("imgs/health", "imgs/tb")),
+        ("v2_others_vs_atb", ("imgs/health", "imgs/sick", "imgs/tb")),
("imgs/health", "imgs/sick", "imgs/tb")), + ("v2_fold_0", ("imgs/health", "imgs/sick", "imgs/tb")), + ("v2_fold_1", ("imgs/health", "imgs/sick", "imgs/tb")), + ("v2_fold_2", ("imgs/health", "imgs/sick", "imgs/tb")), + ("v2_fold_3", ("imgs/health", "imgs/sick", "imgs/tb")), + ("v2_fold_4", ("imgs/health", "imgs/sick", "imgs/tb")), + ("v2_fold_5", ("imgs/health", "imgs/sick", "imgs/tb")), + ("v2_fold_6", ("imgs/health", "imgs/sick", "imgs/tb")), + ("v2_fold_7", ("imgs/health", "imgs/sick", "imgs/tb")), + ("v2_fold_8", ("imgs/health", "imgs/sick", "imgs/tb")), + ("v2_fold_9", ("imgs/health", "imgs/sick", "imgs/tb")), + ], +) +def test_loading( + database_checkers, name: str, dataset: str, prefixes: typing.Sequence[str] +): + datamodule = importlib.import_module( + f".{name}", "ptbench.data.tbx11k" + ).datamodule + + datamodule.model_transforms = [] # should be done before setup() + datamodule.setup("predict") # sets up all datasets + + loader = datamodule.predict_dataloader()[dataset] + + limit = 3 # limit load checking + for batch in loader: + if limit == 0: + break + database_checkers.check_loaded_batch( + batch, + batch_size=1, + color_planes=3, + prefixes=prefixes, + possible_labels=(0, 1), + ) + limit -= 1 + + +# TODO: Tests for loading bounding boxes: +# if patient has active tb, then has to have 1 or more bounding boxes +# if patient does not have active tb, there should be no bounding boxes +# bounding boxes must be within image (512 x 512 pixels)