Newer
Older
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import pathlib
import typing

import pytest
import torch
from torchvision.transforms.functional import to_pil_image

from mednet.data.split import JSONDatabaseSplit
from mednet.data.typing import DatabaseSplit


@pytest.fixture
def datadir(request) -> pathlib.Path:
    """Return the ``data`` directory located next to the requesting test module.

    Check the pytest documentation for more information on the ``request``
    fixture.

    Parameters
    ----------
    request
        Information of the requesting test function.

    Returns
    -------
    pathlib.Path
        The ``data`` sub-directory of the directory in which the test module
        is sitting.
    """
    test_module_path = pathlib.Path(request.module.__file__)
    return test_module_path.parent / "data"
def pytest_configure(config):
    """Register the custom markers used by this test suite.

    This function is run once for pytest setup.

    Parameters
    ----------
    config
        Configuration values. Check the pytest documentation for more
        information.
    """
    marker_declarations = (
        "skip_if_rc_var_not_set(name): this mark skips the test if a certain "
        "~/.config/mednet.toml variable is not set",
        "slow: this mark indicates slow tests",
    )
    for declaration in marker_declarations:
        config.addinivalue_line("markers", declaration)
def pytest_runtest_setup(item):
    """Skip a test when an RC variable it depends on is not set.

    This function is run for every test candidate in this directory.  The test
    is run if this function returns ``None``.  To skip a test, call
    ``pytest.skip()``, specifying a reason.

    Parameters
    ----------
    item
        A test invocation item. Check the pytest documentation for more
        information.
    """
    # kept as a function-level import, as in the original hook
    from mednet.utils.rc import load_rc

    rc = load_rc()

    # iterates over all ``skip_if_rc_var_not_set`` markers attached to the
    # item being examined and collects the first argument of each one
    rc_names = [
        mark.args[0]
        for mark in item.iter_markers(name="skip_if_rc_var_not_set")
    ]

    # checks all names mentioned are set in ~/.config/mednet.toml, otherwise,
    # skip the test (an empty ``rc_names`` yields an empty ``missing``)
    missing = [name for name in rc_names if rc.get(name) is None]
    if missing:
        pytest.skip(
            f"Test skipped because {', '.join(missing)} is **not** "
            f"set in ~/.config/mednet.toml"
        )
from mednet.utils.rc import load_rc
rc = load_rc()
pytest.mark.skipif(
name not in rc,
reason=f"RC variable '{name}' is not set",
)
@pytest.fixture(scope="session")
def temporary_basedir(tmp_path_factory):
    """Create a session-scoped temporary directory based on the name ``test-cli``.

    Parameters
    ----------
    tmp_path_factory
        The pytest fixture used to create temporary directories. Check the
        pytest documentation for more information.

    Returns
    -------
        The value returned by ``tmp_path_factory.mktemp("test-cli")``.
    """
    basedir = tmp_path_factory.mktemp("test-cli")
    return basedir

André Anjos
committed
class DatabaseCheckers:
    """Helpers for database tests."""

    @staticmethod
    def check_split(
        split: DatabaseSplit,
        lengths: dict[str, int],
        prefixes: typing.Sequence[str],
        possible_labels: typing.Sequence[int],
    ):
        """Run a simple consistency check on the data split.

        Parameters
        ----------
        split
            An instance of DatabaseSplit.
        lengths
            A dictionary that contains keys matching those of the split (this
            will be checked). The values of the dictionary should correspond
            to the sizes of each of the datasets in the split.
        prefixes
            Each file named in a split should start with at least one of these
            prefixes.
        possible_labels
            These are the list of possible labels contained in any split.
        """
        assert len(split) == len(lengths), (
            f"Split has {len(split)} dataset(s), but {len(lengths)} were "
            f"expected"
        )

        for dataset_name, expected_length in lengths.items():
            # dataset must have been declared
            assert dataset_name in split, (
                f"Dataset `{dataset_name}` is missing from the split"
            )
            assert len(split[dataset_name]) == expected_length, (
                f"Dataset `{dataset_name}` has {len(split[dataset_name])} "
                f"sample(s), but {expected_length} were expected"
            )

            for sample in split[dataset_name]:
                # sample[0] is the sample name, sample[1] its label(s)
                assert any(sample[0].startswith(p) for p in prefixes), (
                    f"Sample with name {sample[0]} does not start with any of "
                    f"the prefixes in {prefixes}"
                )
                if isinstance(sample[1], list):
                    assert all(
                        label in possible_labels for label in sample[1]
                    ), (
                        f"Sample with name {sample[0]} has labels outside of "
                        f"{possible_labels}"
                    )
                else:
                    assert sample[1] in possible_labels, (
                        f"Sample with name {sample[0]} has a label outside of "
                        f"{possible_labels}"
                    )

    @staticmethod
    def check_loaded_batch(
        batch,
        batch_size: int,
        color_planes: int,
        prefixes: typing.Sequence[str],
        possible_labels: typing.Sequence[int],
        expected_num_labels: int,
        expected_image_shape: typing.Optional[tuple[int, ...]] = None,
    ):
        """Check the consistency of an individual (loaded) batch.

        Parameters
        ----------
        batch
            The loaded batch to be checked. It must hold two elements: a
            ``torch.Tensor`` with the data, and a dictionary with the
            metadata (keys ``label`` and ``name``).
        batch_size
            The mini-batch size.
        color_planes
            The number of color planes in the images.
        prefixes
            Each file named in a split should start with at least one of these
            prefixes.
        possible_labels
            These are the list of possible labels contained in any split.
        expected_num_labels
            The expected number of labels, compared against
            ``len(batch[1]["label"])``.  A falsy value (e.g. ``0``) disables
            this check.
        expected_image_shape
            The expected shape of each individual image in the batch (color
            planes first).  When ``None`` (or empty), shapes are not checked.
        """
        assert len(batch) == 2  # data, metadata

        data, metadata = batch[0], batch[1]

        assert isinstance(data, torch.Tensor)
        assert data.shape[0] == batch_size  # mini-batch size
        assert data.shape[1] == color_planes

        if expected_image_shape:
            assert all(
                image.shape == expected_image_shape for image in data
            ), f"Not all images in the batch have shape {expected_image_shape}"

        assert isinstance(metadata, dict)  # metadata
        assert len(metadata) == 2  # label and name

        assert "label" in metadata
        assert all(label in possible_labels for label in metadata["label"]), (
            f"Batch contains labels outside of {possible_labels}"
        )
        if expected_num_labels:
            assert len(metadata["label"]) == expected_num_labels

        assert "name" in metadata
        assert all(
            any(name.startswith(p) for p in prefixes)
            for name in metadata["name"]
        ), f"Batch contains names not starting with any of {prefixes}"

        # use the code below to view generated images
        # from torchvision.transforms.functional import to_pil_image
        # to_pil_image(batch[0][0]).show()
        # __import__("pdb").set_trace()

    @staticmethod
    def check_image_quality(datamodule, reference_histogram_file):
        """Compare histograms of loaded images against stored references.

        For every split listed in the reference file, and for every
        ``(sample name, histogram)`` entry in it, the matching image is
        loaded from the datamodule, converted to a PIL image, and its
        histogram must be exactly equal to the reference one.

        Parameters
        ----------
        datamodule
            The datamodule to check.  Raw samples are read from
            ``datamodule.splits[split_name][0][0]`` and loaded images from
            ``datamodule._datasets[split_name]``.
        reference_histogram_file
            Path to the file holding the reference histograms, loaded via
            ``JSONDatabaseSplit``.
        """
        ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file)

        for split_name in ref_histogram_splits:
            raw_samples = datamodule.splits[split_name][0][0]

            # It is not possible to get a sample from a Dataset by name/path,
            # only by index.  This creates a dict of sample name to dataset
            # index.
            raw_samples_indices = {
                raw_sample[0]: index
                for index, raw_sample in enumerate(raw_samples)
            }

            for ref_hist_path, ref_hist_data in ref_histogram_splits[
                split_name
            ]:
                # Get index in the dataset that will return the data
                # corresponding to the specified sample name
                dataset_sample_index = raw_samples_indices[ref_hist_path]

                image_tensor = datamodule._datasets[split_name][
                    dataset_sample_index
                ][0]

                histogram = to_pil_image(image_tensor).histogram()
                assert histogram == ref_hist_data, (
                    f"Histogram of sample `{ref_hist_path}` (split "
                    f"`{split_name}`) differs from the reference"
                )

André Anjos
committed
@pytest.fixture
def database_checkers():
    """Give tests access to the database checking helpers.

    Returns
    -------
        The ``DatabaseCheckers`` class itself (not an instance); its checks
        are static methods.
    """
    checkers = DatabaseCheckers
    return checkers