# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Tests for Montgomery dataset."""

import itertools
from collections.abc import Sequence

import pytest
import torch

from ptbench.data.montgomery.datamodule import make_split


def _check_split(
    split_filename: str,
    lengths: dict[str, int],
    prefix: str = "CXR_png/MCUCXR_0",
    possible_labels: Sequence[int] = (0, 1),
) -> None:
    """Run a simple consistency check on a data split.

    Parameters
    ----------

    split_filename
        Name of the split to check.  It is handed unchanged to
        ``make_split``.

    lengths
        A dictionary that contains keys matching those of the split (this
        is checked).  The values of the dictionary should correspond to the
        sizes of each of the datasets in the split.

    prefix
        Each file named in a split should start with this prefix.

    possible_labels
        The labels that may appear in any split.
    """

    split = make_split(split_filename)

    # Same number of datasets on both sides; together with the membership
    # test below this ensures the split declares exactly the expected keys.
    assert len(split) == len(lengths)

    for name, expected_size in lengths.items():
        # dataset must have been declared
        assert name in split

        assert len(split[name]) == expected_size
        for sample in split[name]:
            # sample[0] is compared against the filename prefix and
            # sample[1] against the allowed labels.
            assert sample[0].startswith(prefix)
            assert sample[1] in possible_labels


def _check_loaded_batch(
    batch,
    size: int = 1,
    prefix: str = "CXR_png/MCUCXR_0",
    possible_labels: Sequence[int] = (0, 1),
) -> None:
    """Check the consistency of an individual (loaded) batch.

    Parameters
    ----------

    batch
        The loaded batch to be checked.  It must be a pair whose first
        element is a ``torch.Tensor`` and whose second element is a
        metadata dictionary with exactly the keys ``label`` and ``name``.

    size
        The expected mini-batch size, i.e. the length of the first
        dimension of the data tensor.

    prefix
        Each file named in the batch metadata should start with this
        prefix.

    possible_labels
        The labels that may appear in the batch metadata.
    """

    assert len(batch) == 2  # data, metadata

    data, metadata = batch[0], batch[1]

    assert isinstance(data, torch.Tensor)
    assert data.shape[0] == size  # mini-batch size
    assert data.shape[1] == 1  # grayscale images
    assert data.shape[2] == data.shape[3]  # image is square

    assert isinstance(metadata, dict)
    assert len(metadata) == 2  # label and name

    assert "label" in metadata
    assert all(k in possible_labels for k in metadata["label"])

    assert "name" in metadata
    assert all(k.startswith(prefix) for k in metadata["name"])


def test_protocol_consistency():
    """Check the dataset sizes of the default split and of the 10 folds."""

    _check_split(
        "default.json",
        lengths=dict(train=88, validation=22, test=28),
    )

    # Cross-validation fold 0-7
    for k in range(8):
        _check_split(
            f"fold_{k}.json",
            lengths=dict(train=99, validation=25, test=14),
        )

    # Cross-validation fold 8-9
    for k in range(8, 10):
        _check_split(
            f"fold_{k}.json",
            lengths=dict(train=100, validation=25, test=13),
        )


@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
def test_loading():
    """Check the first few batches produced by every predict dataloader."""

    # NOTE(review): imported here rather than at module level, presumably so
    # that collecting this file does not require a configured data
    # directory -- confirm against the ``default`` module.
    from ptbench.data.montgomery.default import datamodule

    datamodule.model_transforms = []  # should be done before setup()
    datamodule.setup("predict")  # sets up all datasets

    for loader in datamodule.predict_dataloader().values():
        # Limit load checking to 5 batches per loader; ``islice`` stops
        # without requesting a further batch from the loader.
        for batch in itertools.islice(loader, 5):
            _check_loaded_batch(batch)