Skip to content
Snippets Groups Projects
Commit 993600d9 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[tests] Rewrite tests for hiv-tb and tb-poc

parent d0743428
No related branches found
No related tags found
1 merge request: !6 "Making use of LightningDataModule and simplification of data loading"
Pipeline #76736 failed
...@@ -162,7 +162,9 @@ class DatabaseCheckers: ...@@ -162,7 +162,9 @@ class DatabaseCheckers:
assert len(split[k]) == lengths[k] assert len(split[k]) == lengths[k]
for s in split[k]: for s in split[k]:
assert any([s[0].startswith(k) for k in prefixes]) assert any(
[s[0].startswith(k) for k in prefixes]
), f"Sample with name {s[0]} does not start with any of the prefixes in {prefixes}"
assert s[1] in possible_labels assert s[1] in possible_labels
@staticmethod @staticmethod
......
...@@ -3,127 +3,89 @@ ...@@ -3,127 +3,89 @@
# SPDX-License-Identifier: GPL-3.0-or-later # SPDX-License-Identifier: GPL-3.0-or-later
"""Tests for HIV-TB dataset.""" """Tests for HIV-TB dataset."""
import pytest import importlib
import torch
from ptbench.data.hivtb.datamodule import make_split
def _check_split(
split_filename: str,
lengths: dict[str, int],
prefix: str = "HIV-TB_Algorithm_study_X-rays/",
extension: str = ".BMP",
possible_labels: list[int] = [0, 1],
):
"""Runs a simple consistence check on the data split.
Parameters
----------
split_filename
This is the split we will check
lenghts
A dictionary that contains keys matching those of the split (this will
be checked). The values of the dictionary should correspond to the
sizes of each of the datasets in the split.
prefix import pytest
Each file named in a split should start with this prefix.
extension
Each file named in a split should end with this extension.
possible_labels
These are the list of possible labels contained in any split.
"""
split = make_split(split_filename)
assert len(split) == len(lengths)
for k in lengths.keys():
# dataset must have been declared
assert k in split
assert len(split[k]) == lengths[k]
for s in split[k]:
assert s[0].startswith(prefix)
assert s[0].endswith(extension)
assert s[1] in possible_labels
def _check_loaded_batch( def id_function(val):
batch, if isinstance(val, dict):
size: int = 1, return str(val)
prefix: str = "HIV-TB_Algorithm_study_X-rays/", return repr(val)
extension: str = ".BMP",
possible_labels: list[int] = [0, 1],
@pytest.mark.parametrize(
"split,lenghts",
[
("fold-0", dict(train=174, validation=44, test=25)),
("fold-1", dict(train=174, validation=44, test=25)),
("fold-2", dict(train=174, validation=44, test=25)),
("fold-3", dict(train=175, validation=44, test=24)),
("fold-4", dict(train=175, validation=44, test=24)),
("fold-5", dict(train=175, validation=44, test=24)),
("fold-6", dict(train=175, validation=44, test=24)),
("fold-7", dict(train=175, validation=44, test=24)),
("fold-8", dict(train=175, validation=44, test=24)),
("fold-9", dict(train=175, validation=44, test=24)),
],
ids=id_function, # just changes how pytest prints it
)
def test_protocol_consistency(
database_checkers, split: str, lenghts: dict[str, int]
): ):
"""Checks the consistence of an individual (loaded) batch. from ptbench.data.hivtb.datamodule import make_split
Parameters database_checkers.check_split(
---------- make_split(f"{split}.json"),
lengths=lenghts,
batch prefixes=("HIV-TB_Algorithm_study_X-rays",),
The loaded batch to be checked. possible_labels=(0, 1),
)
prefix
Each file named in a split should start with this prefix.
extension
Each file named in a split should end with this extension.
possible_labels
These are the list of possible labels contained in any split.
"""
assert len(batch) == 2 # data, metadata
assert isinstance(batch[0], torch.Tensor)
assert batch[0].shape[0] == size # mini-batch size
assert batch[0].shape[1] == 1 # grayscale images
assert batch[0].shape[2] == batch[0].shape[3] # image is square
assert isinstance(batch[1], dict) # metadata
assert len(batch[1]) == 2 # label and name
assert "label" in batch[1]
assert all([k in possible_labels for k in batch[1]["label"]])
assert "name" in batch[1]
assert all([k.startswith(prefix) for k in batch[1]["name"]])
assert all([k.endswith(extension) for k in batch[1]["name"]])
def test_protocol_consistency():
# Cross-validation fold 0-2
for k in range(3):
_check_split(
f"fold-{k}.json",
lengths=dict(train=174, validation=44, test=25),
)
# Cross-validation fold 3-9
for k in range(3, 10):
_check_split(
f"fold-{k}.json",
lengths=dict(train=175, validation=44, test=24),
)
@pytest.mark.skip_if_rc_var_not_set("datadir.hivtb") @pytest.mark.skip_if_rc_var_not_set("datadir.hivtb")
def test_loading(): @pytest.mark.parametrize(
from ptbench.data.hivtb.fold_0 import datamodule "dataset",
[
"train",
"validation",
"test",
],
)
@pytest.mark.parametrize(
"name",
[
"fold_0",
"fold_1",
"fold_2",
"fold_3",
"fold_4",
"fold_5",
"fold_6",
"fold_7",
"fold_8",
"fold_9",
],
)
def test_loading(database_checkers, name: str, dataset: str):
datamodule = importlib.import_module(
f".{name}", "ptbench.data.hivtb"
).datamodule
datamodule.model_transforms = [] # should be done before setup() datamodule.model_transforms = [] # should be done before setup()
datamodule.setup("predict") # sets up all datasets datamodule.setup("predict") # sets up all datasets
for loader in datamodule.predict_dataloader().values(): loader = datamodule.predict_dataloader()[dataset]
limit = 5 # limit load checking
for batch in loader: limit = 3 # limit load checking
if limit == 0: for batch in loader:
break if limit == 0:
_check_loaded_batch(batch) break
limit -= 1 database_checkers.check_loaded_batch(
batch,
batch_size=1,
color_planes=1,
prefixes=("HIV-TB_Algorithm_study_X-rays",),
possible_labels=(0, 1),
)
limit -= 1
...@@ -3,127 +3,95 @@ ...@@ -3,127 +3,95 @@
# SPDX-License-Identifier: GPL-3.0-or-later # SPDX-License-Identifier: GPL-3.0-or-later
"""Tests for TB-POC dataset.""" """Tests for TB-POC dataset."""
import pytest import importlib
import torch
from ptbench.data.tbpoc.datamodule import make_split
def _check_split(
split_filename: str,
lengths: dict[str, int],
prefix: str = "TBPOC_CXR/",
extension: str = ".jpeg",
possible_labels: list[int] = [0, 1],
):
"""Runs a simple consistence check on the data split.
Parameters
----------
split_filename
This is the split we will check
lenghts
A dictionary that contains keys matching those of the split (this will
be checked). The values of the dictionary should correspond to the
sizes of each of the datasets in the split.
prefix
Each file named in a split should start with this prefix.
extension import pytest
Each file named in a split should end with this extension.
possible_labels
These are the list of possible labels contained in any split.
"""
split = make_split(split_filename)
assert len(split) == len(lengths)
for k in lengths.keys():
# dataset must have been declared
assert k in split
assert len(split[k]) == lengths[k]
for s in split[k]:
# assert s[0].startswith(prefix)
assert s[0].endswith(extension)
assert s[1] in possible_labels
def _check_loaded_batch( def id_function(val):
batch, if isinstance(val, dict):
size: int = 1, return str(val)
prefix: str = "TBPOC_CXR/", return repr(val)
extension: str = ".jpeg",
possible_labels: list[int] = [0, 1],
@pytest.mark.parametrize(
"split,lenghts",
[
("fold-0", dict(train=292, validation=74, test=41)),
("fold-1", dict(train=292, validation=74, test=41)),
("fold-2", dict(train=292, validation=74, test=41)),
("fold-3", dict(train=292, validation=74, test=41)),
("fold-4", dict(train=292, validation=74, test=41)),
("fold-5", dict(train=292, validation=74, test=41)),
("fold-6", dict(train=292, validation=74, test=41)),
("fold-7", dict(train=293, validation=74, test=40)),
("fold-8", dict(train=293, validation=74, test=40)),
("fold-9", dict(train=293, validation=74, test=40)),
],
ids=id_function, # just changes how pytest prints it
)
def test_protocol_consistency(
database_checkers, split: str, lenghts: dict[str, int]
): ):
"""Checks the consistence of an individual (loaded) batch. from ptbench.data.tbpoc.datamodule import make_split
Parameters database_checkers.check_split(
---------- make_split(f"{split}.json"),
lengths=lenghts,
batch prefixes=(
The loaded batch to be checked. "TBPOC_CXR/TBPOC-",
"TBPOC_CXR/tbpoc-",
prefix ),
Each file named in a split should start with this prefix. possible_labels=(0, 1),
)
extension
Each file named in a split should end with this extension.
@pytest.mark.skip_if_rc_var_not_set("datadir.tbpoc")
possible_labels @pytest.mark.parametrize(
These are the list of possible labels contained in any split. "dataset",
""" [
"train",
assert len(batch) == 2 # data, metadata "validation",
"test",
assert isinstance(batch[0], torch.Tensor) ],
assert batch[0].shape[0] == size # mini-batch size )
assert batch[0].shape[1] == 1 # grayscale images @pytest.mark.parametrize(
assert batch[0].shape[2] == batch[0].shape[3] # image is square "name",
[
assert isinstance(batch[1], dict) # metadata "fold_0",
assert len(batch[1]) == 2 # label and name "fold_1",
"fold_2",
assert "label" in batch[1] "fold_3",
assert all([k in possible_labels for k in batch[1]["label"]]) "fold_4",
"fold_5",
assert "name" in batch[1] "fold_6",
# assert all([k.startswith(prefix) for k in batch[1]["name"]]) "fold_7",
assert all([k.endswith(extension) for k in batch[1]["name"]]) "fold_8",
"fold_9",
],
def test_protocol_consistency(): )
# Cross-validation fold 0-6 def test_loading(database_checkers, name: str, dataset: str):
for k in range(7): datamodule = importlib.import_module(
_check_split( f".{name}", "ptbench.data.tbpoc"
f"fold-{k}.json", ).datamodule
lengths=dict(train=292, validation=74, test=41),
)
# Cross-validation fold 7-9
for k in range(7, 10):
_check_split(
f"fold-{k}.json",
lengths=dict(train=293, validation=74, test=40),
)
@pytest.mark.skip_if_rc_var_not_set("datadir.hivtb")
def test_loading():
from ptbench.data.tbpoc.fold_0 import datamodule
datamodule.model_transforms = [] # should be done before setup() datamodule.model_transforms = [] # should be done before setup()
datamodule.setup("predict") # sets up all datasets datamodule.setup("predict") # sets up all datasets
for loader in datamodule.predict_dataloader().values(): loader = datamodule.predict_dataloader()[dataset]
limit = 5 # limit load checking
for batch in loader: limit = 3 # limit load checking
if limit == 0: for batch in loader:
break if limit == 0:
_check_loaded_batch(batch) break
limit -= 1 database_checkers.check_loaded_batch(
batch,
batch_size=1,
color_planes=1,
prefixes=(
"TBPOC_CXR/TBPOC-",
"TBPOC_CXR/tbpoc-",
),
possible_labels=(0, 1),
)
limit -= 1
0% — Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.