diff --git a/tests/conftest.py b/tests/conftest.py index 75b802a252994df5e2b5793bcd2c3422f5b9ad15..6e992cdd04092d04ff40a0efdb6293c8cff7dd7c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -162,7 +162,9 @@ class DatabaseCheckers: assert len(split[k]) == lengths[k] for s in split[k]: - assert any([s[0].startswith(k) for k in prefixes]) + assert any( + [s[0].startswith(k) for k in prefixes] + ), f"Sample with name {s[0]} does not start with any of the prefixes in {prefixes}" assert s[1] in possible_labels @staticmethod diff --git a/tests/test_hivtb.py b/tests/test_hivtb.py index 920f1574ccabe924c9609d1fd1533a0524fb83e4..eb6513f0046d70a7846e310a7e47a837c7222d8e 100644 --- a/tests/test_hivtb.py +++ b/tests/test_hivtb.py @@ -3,127 +3,89 @@ # SPDX-License-Identifier: GPL-3.0-or-later """Tests for HIV-TB dataset.""" -import pytest -import torch - -from ptbench.data.hivtb.datamodule import make_split - - -def _check_split( - split_filename: str, - lengths: dict[str, int], - prefix: str = "HIV-TB_Algorithm_study_X-rays/", - extension: str = ".BMP", - possible_labels: list[int] = [0, 1], -): - """Runs a simple consistence check on the data split. - - Parameters - ---------- - - split_filename - This is the split we will check - - lenghts - A dictionary that contains keys matching those of the split (this will - be checked). The values of the dictionary should correspond to the - sizes of each of the datasets in the split. +import importlib - prefix - Each file named in a split should start with this prefix. - - extension - Each file named in a split should end with this extension. - - possible_labels - These are the list of possible labels contained in any split. - """ - - split = make_split(split_filename) - - assert len(split) == len(lengths) - - for k in lengths.keys(): - # dataset must have been declared - assert k in split - - assert len(split[k]) == lengths[k] - for s in split[k]: - assert s[0].startswith(prefix) - assert s[0].endswith(extension) - assert s[1] in possible_labels +import pytest -def _check_loaded_batch( - batch, - size: int = 1, - prefix: str = "HIV-TB_Algorithm_study_X-rays/", - extension: str = ".BMP", - possible_labels: list[int] = [0, 1], +def id_function(val): + if isinstance(val, dict): + return str(val) + return repr(val) + + +@pytest.mark.parametrize( + "split,lenghts", + [ + ("fold-0", dict(train=174, validation=44, test=25)), + ("fold-1", dict(train=174, validation=44, test=25)), + ("fold-2", dict(train=174, validation=44, test=25)), + ("fold-3", dict(train=175, validation=44, test=24)), + ("fold-4", dict(train=175, validation=44, test=24)), + ("fold-5", dict(train=175, validation=44, test=24)), + ("fold-6", dict(train=175, validation=44, test=24)), + ("fold-7", dict(train=175, validation=44, test=24)), + ("fold-8", dict(train=175, validation=44, test=24)), + ("fold-9", dict(train=175, validation=44, test=24)), + ], + ids=id_function, # just changes how pytest prints it +) +def test_protocol_consistency( + database_checkers, split: str, lenghts: dict[str, int] ): - """Checks the consistence of an individual (loaded) batch. + from ptbench.data.hivtb.datamodule import make_split - Parameters - ---------- - - batch - The loaded batch to be checked. - - prefix - Each file named in a split should start with this prefix. - - extension - Each file named in a split should end with this extension. - - possible_labels - These are the list of possible labels contained in any split. - """ - - assert len(batch) == 2 # data, metadata - - assert isinstance(batch[0], torch.Tensor) - assert batch[0].shape[0] == size # mini-batch size - assert batch[0].shape[1] == 1 # grayscale images - assert batch[0].shape[2] == batch[0].shape[3] # image is square - - assert isinstance(batch[1], dict) # metadata - assert len(batch[1]) == 2 # label and name - - assert "label" in batch[1] - assert all([k in possible_labels for k in batch[1]["label"]]) - - assert "name" in batch[1] - assert all([k.startswith(prefix) for k in batch[1]["name"]]) - assert all([k.endswith(extension) for k in batch[1]["name"]]) - - -def test_protocol_consistency(): - # Cross-validation fold 0-2 - for k in range(3): - _check_split( - f"fold-{k}.json", - lengths=dict(train=174, validation=44, test=25), - ) - - # Cross-validation fold 3-9 - for k in range(3, 10): - _check_split( - f"fold-{k}.json", - lengths=dict(train=175, validation=44, test=24), - ) + database_checkers.check_split( + make_split(f"{split}.json"), + lengths=lenghts, + prefixes=("HIV-TB_Algorithm_study_X-rays",), + possible_labels=(0, 1), + ) @pytest.mark.skip_if_rc_var_not_set("datadir.hivtb") -def test_loading(): - from ptbench.data.hivtb.fold_0 import datamodule +@pytest.mark.parametrize( + "dataset", + [ + "train", + "validation", + "test", + ], +) +@pytest.mark.parametrize( + "name", + [ + "fold_0", + "fold_1", + "fold_2", + "fold_3", + "fold_4", + "fold_5", + "fold_6", + "fold_7", + "fold_8", + "fold_9", + ], +) +def test_loading(database_checkers, name: str, dataset: str): + datamodule = importlib.import_module( + f".{name}", "ptbench.data.hivtb" + ).datamodule datamodule.model_transforms = [] # should be done before setup() datamodule.setup("predict") # sets up all datasets - for loader in datamodule.predict_dataloader().values(): - limit = 5 # limit load checking - for batch in loader: - if limit == 0: - break - _check_loaded_batch(batch) - limit -= 1 + loader = datamodule.predict_dataloader()[dataset] + + limit = 3 # limit load checking + for batch in loader: + if limit == 0: + break + database_checkers.check_loaded_batch( + batch, + batch_size=1, + color_planes=1, + prefixes=("HIV-TB_Algorithm_study_X-rays",), + possible_labels=(0, 1), + ) + limit -= 1 diff --git a/tests/test_tbpoc.py b/tests/test_tbpoc.py index ee34d8d09d14379dfc08dcc4adebc33ffedd18b9..7125efb195a058f36c785268c110c765f0890d65 100644 --- a/tests/test_tbpoc.py +++ b/tests/test_tbpoc.py @@ -3,127 +3,95 @@ # SPDX-License-Identifier: GPL-3.0-or-later """Tests for TB-POC dataset.""" -import pytest -import torch - -from ptbench.data.tbpoc.datamodule import make_split - - -def _check_split( - split_filename: str, - lengths: dict[str, int], - prefix: str = "TBPOC_CXR/", - extension: str = ".jpeg", - possible_labels: list[int] = [0, 1], -): - """Runs a simple consistence check on the data split. - - Parameters - ---------- - - split_filename - This is the split we will check - - lenghts - A dictionary that contains keys matching those of the split (this will - be checked). The values of the dictionary should correspond to the - sizes of each of the datasets in the split. - - prefix - Each file named in a split should start with this prefix. +import importlib - extension - Each file named in a split should end with this extension. - - possible_labels - These are the list of possible labels contained in any split. - """ - - split = make_split(split_filename) - - assert len(split) == len(lengths) - - for k in lengths.keys(): - # dataset must have been declared - assert k in split - - assert len(split[k]) == lengths[k] - for s in split[k]: - # assert s[0].startswith(prefix) - assert s[0].endswith(extension) - assert s[1] in possible_labels +import pytest -def _check_loaded_batch( - batch, - size: int = 1, - prefix: str = "TBPOC_CXR/", - extension: str = ".jpeg", - possible_labels: list[int] = [0, 1], +def id_function(val): + if isinstance(val, dict): + return str(val) + return repr(val) + + +@pytest.mark.parametrize( + "split,lenghts", + [ + ("fold-0", dict(train=292, validation=74, test=41)), + ("fold-1", dict(train=292, validation=74, test=41)), + ("fold-2", dict(train=292, validation=74, test=41)), + ("fold-3", dict(train=292, validation=74, test=41)), + ("fold-4", dict(train=292, validation=74, test=41)), + ("fold-5", dict(train=292, validation=74, test=41)), + ("fold-6", dict(train=292, validation=74, test=41)), + ("fold-7", dict(train=293, validation=74, test=40)), + ("fold-8", dict(train=293, validation=74, test=40)), + ("fold-9", dict(train=293, validation=74, test=40)), + ], + ids=id_function, # just changes how pytest prints it +) +def test_protocol_consistency( + database_checkers, split: str, lenghts: dict[str, int] ): - """Checks the consistence of an individual (loaded) batch. - - Parameters - ---------- - - batch - The loaded batch to be checked. - - prefix - Each file named in a split should start with this prefix. - - extension - Each file named in a split should end with this extension. - - possible_labels - These are the list of possible labels contained in any split. - """ - - assert len(batch) == 2 # data, metadata - - assert isinstance(batch[0], torch.Tensor) - assert batch[0].shape[0] == size # mini-batch size - assert batch[0].shape[1] == 1 # grayscale images - assert batch[0].shape[2] == batch[0].shape[3] # image is square - - assert isinstance(batch[1], dict) # metadata - assert len(batch[1]) == 2 # label and name - - assert "label" in batch[1] - assert all([k in possible_labels for k in batch[1]["label"]]) - - assert "name" in batch[1] - # assert all([k.startswith(prefix) for k in batch[1]["name"]]) - assert all([k.endswith(extension) for k in batch[1]["name"]]) - - -def test_protocol_consistency(): - # Cross-validation fold 0-6 - for k in range(7): - _check_split( - f"fold-{k}.json", - lengths=dict(train=292, validation=74, test=41), - ) - - # Cross-validation fold 7-9 - for k in range(7, 10): - _check_split( - f"fold-{k}.json", - lengths=dict(train=293, validation=74, test=40), - ) - - -@pytest.mark.skip_if_rc_var_not_set("datadir.hivtb") -def test_loading(): - from ptbench.data.tbpoc.fold_0 import datamodule + from ptbench.data.tbpoc.datamodule import make_split + + database_checkers.check_split( + make_split(f"{split}.json"), + lengths=lenghts, + prefixes=( + "TBPOC_CXR/TBPOC-", + "TBPOC_CXR/tbpoc-", + ), + possible_labels=(0, 1), + ) + + +@pytest.mark.skip_if_rc_var_not_set("datadir.tbpoc") +@pytest.mark.parametrize( + "dataset", + [ + "train", + "validation", + "test", + ], +) +@pytest.mark.parametrize( + "name", + [ + "fold_0", + "fold_1", + "fold_2", + "fold_3", + "fold_4", + "fold_5", + "fold_6", + "fold_7", + "fold_8", + "fold_9", + ], +) +def test_loading(database_checkers, name: str, dataset: str): + datamodule = importlib.import_module( + f".{name}", "ptbench.data.tbpoc" + ).datamodule datamodule.model_transforms = [] # should be done before setup() datamodule.setup("predict") # sets up all datasets - for loader in datamodule.predict_dataloader().values(): - limit = 5 # limit load checking - for batch in loader: - if limit == 0: - break - _check_loaded_batch(batch) - limit -= 1 + loader = datamodule.predict_dataloader()[dataset] + + limit = 3 # limit load checking + for batch in loader: + if limit == 0: + break + database_checkers.check_loaded_batch( + batch, + batch_size=1, + color_planes=1, + prefixes=( + "TBPOC_CXR/TBPOC-", + "TBPOC_CXR/tbpoc-", + ), + possible_labels=(0, 1), + ) + limit -= 1