Commit bfe67849 authored by Daniel CARRON, committed by André Anjos

[tests] Add tests for drive database

parent 3bd9b318
1 merge request: !46 Create common library
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import pathlib
import typing
import numpy
import pytest
import torch
from mednet.libs.common.data.split import JSONDatabaseSplit
from mednet.libs.common.data.typing import DatabaseSplit
@pytest.fixture
def datadir(request) -> pathlib.Path:
"""Return the directory in which the test is sitting. Check the pytest
documentation for more information.
Parameters
----------
request
Information of the requesting test function.
Returns
-------
pathlib.Path
The directory in which the test is sitting.
"""
return pathlib.Path(request.module.__file__).parents[0] / "data"
def pytest_configure(config):
"""This function is run once for pytest setup.
Parameters
----------
config
Configuration values. Check the pytest documentation for more
information.
"""
config.addinivalue_line(
"markers",
"skip_if_rc_var_not_set(name): this mark skips the test if a certain "
"~/.config/mednet.libs.classification.toml variable is not set",
)
config.addinivalue_line("markers", "slow: this mark indicates slow tests")
def pytest_runtest_setup(item):
"""This function is run for every test candidate in this directory.
The test is run if this function returns ``None``. To skip a test,
call ``pytest.skip()``, specifying a reason.
Parameters
----------
item
A test invocation item. Check the pytest documentation for more
information.
"""
from mednet.libs.classification.utils.rc import load_rc
rc = load_rc()
# iterates over all markers for the item being examined, get the first
# argument and accumulate these names
rc_names = [
mark.args[0]
for mark in item.iter_markers(name="skip_if_rc_var_not_set")
]
# checks all names mentioned are set in ~/.config/mednet.libs.classification.toml, otherwise,
# skip the test
if rc_names:
missing = [k for k in rc_names if rc.get(k) is None]
if any(missing):
pytest.skip(
f"Test skipped because {', '.join(missing)} is **not** "
f"set in ~/.config/mednet.libs.classification.toml",
)
def rc_variable_set(name):
    """Return a ``skipif`` mark that skips a test when RC variable ``name``
    is not set."""
    from mednet.libs.classification.utils.rc import load_rc

    rc = load_rc()
    return pytest.mark.skipif(
        name not in rc,
        reason=f"RC variable '{name}' is not set",
    )
@pytest.fixture(scope="session")
def temporary_basedir(tmp_path_factory):
return tmp_path_factory.mktemp("test-cli")
class DatabaseCheckers:
"""Helpers for database tests."""
@staticmethod
def check_split(
split: DatabaseSplit,
lengths: dict[str, int],
prefixes: typing.Sequence[str],
possible_labels: typing.Sequence[int],
):
"""Run a simple consistency check on the data split.
Parameters
----------
split
An instance of DatabaseSplit.
lengths
A dictionary that contains keys matching those of the split (this
will be checked). The values of the dictionary should correspond
to the sizes of each of the datasets in the split.
prefixes
Each file named in a split should start with at least one of these
prefixes.
possible_labels
These are the list of possible labels contained in any split.
"""
assert len(split) == len(lengths)
for k in lengths.keys():
# dataset must have been declared
assert k in split
assert len(split[k]) == lengths[k]
for s in split[k]:
assert any([s[0].startswith(k) for k in prefixes]), (
f"Sample with name {s[0]} does not start with any of the "
f"prefixes in {prefixes}"
)
if isinstance(s[1], list):
assert all([k in possible_labels for k in s[1]])
else:
assert s[1] in possible_labels
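    # Illustrative call from a database test (the split object, sizes,
    # prefixes, and labels below are hypothetical):
    #
    #   database_checkers.check_split(
    #       some_split,
    #       lengths=dict(train=20, test=20),
    #       prefixes=("training/", "test/"),
    #       possible_labels=(0, 1),
    #   )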
@staticmethod
def check_loaded_batch(
batch,
batch_size: int,
color_planes: int,
prefixes: typing.Sequence[str],
possible_labels: typing.Sequence[int],
expected_num_labels: int,
expected_image_shape: tuple[int, ...] | None = None,
):
"""Check the consistency of an individual (loaded) batch.
Parameters
----------
batch
The loaded batch to be checked.
batch_size
The mini-batch size.
color_planes
The number of color planes in the images.
prefixes
Each file named in a split should start with at least one of these
prefixes.
possible_labels
These are the list of possible labels contained in any split.
expected_num_labels
The expected number of labels each sample should have.
expected_image_shape
The expected shape of the image (num_channels, width, height).
"""
assert len(batch) == 2 # data, metadata
assert isinstance(batch[0], torch.Tensor)
assert batch[0].shape[0] == batch_size # mini-batch size
assert batch[0].shape[1] == color_planes
if expected_image_shape:
assert all(
[data.shape == expected_image_shape for data in batch[0]],
)
assert isinstance(batch[1], dict) # metadata
assert len(batch[1]) == 2 # label and name
assert "target" in batch[1]
assert all([k in possible_labels for k in batch[1]["target"]])
if expected_num_labels:
assert len(batch[1]["target"]) == expected_num_labels
assert "name" in batch[1]
assert all(
[
any([k.startswith(j) for j in prefixes])
for k in batch[1]["name"]
],
)
# use the code below to view generated images
# from torchvision.transforms.functional import to_pil_image
# to_pil_image(batch[0][0]).show()
# __import__("pdb").set_trace()
@staticmethod
def check_image_quality(
datamodule,
reference_histogram_file,
compare_type="equal",
pearson_coeff_threshold=0.005,
    ):
        """Compare sample histograms from ``datamodule`` against the reference
        stored in ``reference_histogram_file``, either exactly or via a
        Pearson correlation kept within ``pearson_coeff_threshold`` of 1."""
ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file)
for split_name in ref_histogram_splits:
raw_samples = datamodule.splits[split_name][0][0]
# It is not possible to get a sample from a Dataset by name/path,
# only by index. This creates a dict of sample name to dataset
# index.
raw_samples_indices = {}
for idx, rs in enumerate(raw_samples):
raw_samples_indices[rs[0]] = idx
for ref_hist_path, ref_hist_data in ref_histogram_splits[
split_name
]:
# Get index in the dataset that will return the data
# corresponding to the specified sample name
dataset_sample_index = raw_samples_indices[ref_hist_path]
image_tensor = datamodule._datasets[split_name][ # noqa: SLF001
dataset_sample_index
][0]
histogram = []
for color_channel in image_tensor:
color_channel = numpy.multiply(
color_channel.numpy(),
255,
).astype(int)
histogram.extend(
numpy.histogram(
color_channel,
bins=256,
range=(0, 256),
)[0].tolist(),
)
if compare_type == "statistical":
# Compute pearson coefficients between histogram and
# reference and check the similarity within a certain
# threshold
pearson_coeffs = numpy.corrcoef(histogram, ref_hist_data)
assert (
1 - pearson_coeff_threshold <= pearson_coeffs[0][1] <= 1
)
else:
assert histogram == ref_hist_data
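        # Note on the format inferred from the code above (not a formal
        # specification): the reference file is a JSON split-like mapping from
        # split name to ``[sample_name, histogram]`` pairs, where each
        # histogram concatenates one 256-bin histogram per color channel.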
@pytest.fixture
def database_checkers():
return DatabaseCheckers
@@ -4,14 +4,8 @@
import json
import pathlib
import typing
import numpy
import numpy.typing
import pytest
import torch
from mednet.libs.common.data.split import JSONDatabaseSplit
from mednet.libs.common.data.typing import DatabaseSplit
@pytest.fixture
@@ -31,251 +25,3 @@ def datadir(request) -> pathlib.Path:
"""
return pathlib.Path(request.module.__file__).parents[0] / "data"
def pytest_configure(config):
"""This function is run once for pytest setup.
Parameters
----------
config
Configuration values. Check the pytest documentation for more
information.
"""
config.addinivalue_line(
"markers",
"skip_if_rc_var_not_set(name): this mark skips the test if a certain "
"~/.config/mednet.libs.classification.toml variable is not set",
)
config.addinivalue_line("markers", "slow: this mark indicates slow tests")
def pytest_runtest_setup(item):
"""This function is run for every test candidate in this directory.
The test is run if this function returns ``None``. To skip a test,
call ``pytest.skip()``, specifying a reason.
Parameters
----------
item
A test invocation item. Check the pytest documentation for more
information.
"""
from mednet.libs.classification.utils.rc import load_rc
rc = load_rc()
# iterates over all markers for the item being examined, get the first
# argument and accumulate these names
rc_names = [
mark.args[0] for mark in item.iter_markers(name="skip_if_rc_var_not_set")
]
# checks all names mentioned are set in ~/.config/mednet.libs.classification.toml, otherwise,
# skip the test
if rc_names:
missing = [k for k in rc_names if rc.get(k) is None]
if any(missing):
pytest.skip(
f"Test skipped because {', '.join(missing)} is **not** "
f"set in ~/.config/mednet.libs.classification.toml",
)
def rc_variable_set(name):
from mednet.libs.classification.utils.rc import load_rc
rc = load_rc()
pytest.mark.skipif(
name not in rc,
reason=f"RC variable '{name}' is not set",
)
@pytest.fixture(scope="session")
def temporary_basedir(tmp_path_factory):
return tmp_path_factory.mktemp("test-cli")
class DatabaseCheckers:
"""Helpers for database tests."""
@staticmethod
def check_split(
split: DatabaseSplit,
lengths: dict[str, int],
prefixes: typing.Sequence[str],
possible_labels: typing.Sequence[int],
) -> None:
"""Run a simple consistency check on the data split.
Parameters
----------
split
An instance of DatabaseSplit.
lengths
A dictionary that contains keys matching those of the split (this
will be checked). The values of the dictionary should correspond
to the sizes of each of the datasets in the split.
prefixes
Each file named in a split should start with at least one of these
prefixes.
possible_labels
These are the list of possible labels contained in any split.
"""
assert len(split) == len(lengths)
for k in lengths.keys():
# dataset must have been declared
assert k in split
assert len(split[k]) == lengths[k]
for s in split[k]:
assert any([s[0].startswith(k) for k in prefixes]), (
f"Sample with name {s[0]} does not start with any of the "
f"prefixes in {prefixes}"
)
if isinstance(s[1], list):
assert all([k in possible_labels for k in s[1]])
else:
assert s[1] in possible_labels
@staticmethod
def check_loaded_batch(
batch,
batch_size: int,
color_planes: int,
prefixes: typing.Sequence[str],
possible_labels: typing.Sequence[int],
expected_num_labels: int,
expected_image_shape: tuple[int, ...] | None = None,
) -> None:
"""Check the consistency of an individual (loaded) batch.
Parameters
----------
batch
The loaded batch to be checked.
batch_size
The mini-batch size.
color_planes
The number of color planes in the images.
prefixes
Each file named in a split should start with at least one of these
prefixes.
possible_labels
These are the list of possible labels contained in any split.
expected_num_labels
The expected number of labels each sample should have.
expected_image_shape
The expected shape of the image (num_channels, width, height).
"""
assert len(batch) == 2 # data, metadata
assert isinstance(batch[0], torch.Tensor)
assert batch[0].shape[0] == batch_size # mini-batch size
assert batch[0].shape[1] == color_planes
if expected_image_shape:
assert all(
[data.shape == expected_image_shape for data in batch[0]],
)
assert isinstance(batch[1], dict) # metadata
assert len(batch[1]) == 2 # label and name
assert "target" in batch[1]
assert all([k in possible_labels for k in batch[1]["target"]])
if expected_num_labels:
assert len(batch[1]["target"]) == expected_num_labels
assert "name" in batch[1]
assert all(
[any([k.startswith(j) for j in prefixes]) for k in batch[1]["name"]],
)
# use the code below to view generated images
# from torchvision.transforms.functional import to_pil_image
# to_pil_image(batch[0][0]).show()
# __import__("pdb").set_trace()
@staticmethod
def _make_histo(data: numpy.typing.NDArray[numpy.uint8]) -> list[int]:
from itertools import chain
def _mk_single_channel(data: numpy.typing.NDArray[numpy.uint8]) -> list[int]:
return numpy.histogram(data, bins=256, range=(0, 256))[0].tolist()
return list(chain(*[_mk_single_channel(k) for k in data[0, :]]))
@staticmethod
def check_image_quality(
datamodule,
reference_histogram_file: pathlib.Path,
compare_type: typing.Literal["equal", "statistical"] = "equal",
pearson_coeff_threshold: float = 0.005,
) -> None:
reference = {}
for split, values in JSONDatabaseSplit(reference_histogram_file).items():
reference[split] = {k[0]: k[1] for k in values}
for split_name, loader in datamodule.predict_dataloader().items():
for sample in loader:
uint8_array = (255 * sample[0]).byte().numpy()
histogram = DatabaseCheckers._make_histo(uint8_array)
if sample[1]["name"][0] in reference[split_name]:
ref_histogram = reference[split_name].pop(sample[1]["name"][0])
else:
continue
if compare_type == "statistical":
# Compute pearson coefficients between histogram and
# reference and check the similarity within a certain
# threshold
pearson_coeffs = numpy.corrcoef(histogram, ref_histogram)
assert (1 - pearson_coeff_threshold) <= pearson_coeffs[0][1] <= 1
else:
assert histogram == ref_histogram, (
f"Current and reference histograms for sample "
f"`{sample[1]['name'][0]}` at split `{split_name}` "
f"do not match: current = {histogram}, "
f"reference = {ref_histogram}"
)
# all references must have been consumed
for ref_split, left_values in reference.items():
assert len(left_values) == 0, (
f"Not all references at split `{ref_split}` were consumed: "
f"{len(left_values)} are left"
)
@staticmethod
def write_image_quality_histogram(
datamodule,
reference_histogram_file: pathlib.Path,
) -> None:
data: dict[str, list] = {}
for split_name, loader in datamodule.predict_dataloader().items():
data[split_name] = []
for sample in loader:
uint8_array = (255 * sample[0]).byte().numpy()
data[split_name].append(
[sample[1]["name"][0], DatabaseCheckers._make_histo(uint8_array)]
)
with reference_histogram_file.open("w") as f:
json.dump(data, f)
@pytest.fixture
def database_checkers():
return DatabaseCheckers
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import pathlib
import typing
import numpy
import pytest
import torch
from mednet.libs.common.data.split import JSONDatabaseSplit
from mednet.libs.common.data.typing import DatabaseSplit
@pytest.fixture
def datadir(request) -> pathlib.Path:
"""Return the directory in which the test is sitting. Check the pytest
documentation for more information.
Parameters
----------
request
Information of the requesting test function.
Returns
-------
pathlib.Path
The directory in which the test is sitting.
"""
return pathlib.Path(request.module.__file__).parents[0] / "data"
def pytest_configure(config):
"""This function is run once for pytest setup.
Parameters
----------
config
Configuration values. Check the pytest documentation for more
information.
"""
config.addinivalue_line(
"markers",
"skip_if_rc_var_not_set(name): this mark skips the test if a certain "
"~/.config/mednet.libs.segmentation.toml variable is not set",
)
config.addinivalue_line("markers", "slow: this mark indicates slow tests")
def pytest_runtest_setup(item):
"""This function is run for every test candidate in this directory.
The test is run if this function returns ``None``. To skip a test,
call ``pytest.skip()``, specifying a reason.
Parameters
----------
item
A test invocation item. Check the pytest documentation for more
information.
"""
from mednet.libs.segmentation.utils.rc import load_rc
rc = load_rc()
# iterates over all markers for the item being examined, get the first
# argument and accumulate these names
rc_names = [
mark.args[0]
for mark in item.iter_markers(name="skip_if_rc_var_not_set")
]
# checks all names mentioned are set in ~/.config/mednet.libs.segmentation.toml, otherwise,
# skip the test
if rc_names:
missing = [k for k in rc_names if rc.get(k) is None]
if any(missing):
pytest.skip(
f"Test skipped because {', '.join(missing)} is **not** "
f"set in ~/.config/mednet.libs.segmentation.toml",
)
def rc_variable_set(name):
    """Return a ``skipif`` mark that skips a test when RC variable ``name``
    is not set."""
    from mednet.libs.segmentation.utils.rc import load_rc

    rc = load_rc()
    return pytest.mark.skipif(
        name not in rc,
        reason=f"RC variable '{name}' is not set",
    )
@pytest.fixture(scope="session")
def temporary_basedir(tmp_path_factory):
return tmp_path_factory.mktemp("test-cli")
class DatabaseCheckers:
"""Helpers for database tests."""
@staticmethod
def check_split(
split: DatabaseSplit,
lengths: dict[str, int],
        prefixes: typing.Sequence[str] | None = None,
):
"""Run a simple consistency check on the data split.
Parameters
----------
split
An instance of DatabaseSplit.
lengths
A dictionary that contains keys matching those of the split (this
will be checked). The values of the dictionary should correspond
to the sizes of each of the datasets in the split.
prefixes
Each file named in a split should start with at least one of these
prefixes.
"""
assert len(split) == len(lengths)
for k in lengths.keys():
# dataset must have been declared
assert k in split
assert len(split[k]) == lengths[k]
for s in split[k]:
if prefixes is not None:
assert any([s[0].startswith(k) for k in prefixes]), (
f"Sample with name {s[0]} does not start with any of the "
f"prefixes in {prefixes}"
)
@staticmethod
def check_loaded_batch(
batch,
batch_size: int,
color_planes: int,
expected_num_targets: int,
        prefixes: typing.Sequence[str] | None = None,
expected_image_shape: tuple[int, ...] | None = None,
):
"""Check the consistency of an individual (loaded) batch.
Parameters
----------
batch
The loaded batch to be checked.
batch_size
The mini-batch size.
color_planes
The number of color planes in the images.
expected_num_targets
The expected number of labels each sample should have.
prefixes
Each file named in a split should start with at least one of these
prefixes.
expected_image_shape
The expected shape of the image (num_channels, width, height).
"""
assert len(batch) == 2 # sample, metadata
assert isinstance(batch[0], torch.Tensor)
assert batch[0].shape[0] == batch_size # mini-batch size
assert batch[0].shape[1] == color_planes
assert all([isinstance(image, torch.Tensor) for image in batch[0]])
if expected_image_shape:
assert all(
[data.shape == expected_image_shape for data in batch[0]],
)
assert isinstance(batch[1], dict) # metadata
assert len(batch[1]) in [2, 3] # target, Optional(mask), name
assert "target" in batch[1]
assert all(
[isinstance(target, torch.Tensor) for target in batch[1]["target"]]
)
if expected_num_targets:
assert len(batch[1]["target"]) == expected_num_targets
if "mask" in batch[1]:
assert all(
[isinstance(mask, torch.Tensor) for mask in batch[1]["mask"]]
)
assert "name" in batch[1]
if prefixes is not None:
assert all(
[
any([k.startswith(j) for j in prefixes])
for k in batch[1]["name"]
],
)
@staticmethod
def check_image_quality(
datamodule,
reference_histogram_file,
compare_type="equal",
pearson_coeff_threshold=0.005,
):
ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file)
for split_name in ref_histogram_splits:
raw_samples = datamodule.splits[split_name][0][0]
# It is not possible to get a sample from a Dataset by name/path,
# only by index. This creates a dict of sample name to dataset
# index.
raw_samples_indices = {}
for idx, rs in enumerate(raw_samples):
raw_samples_indices[rs[0]] = idx
for ref_hist_path, ref_hist_data in ref_histogram_splits[
split_name
]:
# Get index in the dataset that will return the data
# corresponding to the specified sample name
dataset_sample_index = raw_samples_indices[ref_hist_path]
image_tensor = datamodule._datasets[split_name][ # noqa: SLF001
dataset_sample_index
][0]
histogram = []
for color_channel in image_tensor:
color_channel = numpy.multiply(
color_channel.numpy(),
255,
).astype(int)
histogram.extend(
numpy.histogram(
color_channel,
bins=256,
range=(0, 256),
)[0].tolist(),
)
if compare_type == "statistical":
# Compute pearson coefficients between histogram and
# reference and check the similarity within a certain
# threshold
pearson_coeffs = numpy.corrcoef(histogram, ref_hist_data)
assert (
1 - pearson_coeff_threshold <= pearson_coeffs[0][1] <= 1
)
else:
assert histogram == ref_hist_data
@pytest.fixture
def database_checkers():
return DatabaseCheckers
# SPDX-FileCopyrightText: Copyright © 2024 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Tests for drive dataset."""
import importlib
import pytest
from click.testing import CliRunner
def id_function(val):
if isinstance(val, dict):
return str(val)
return repr(val)
@pytest.mark.parametrize(
"split,lengths",
[
("default", dict(train=20, test=20)),
],
ids=id_function, # just changes how pytest prints it
)
def test_protocol_consistency(
database_checkers,
split: str,
lengths: dict[str, int],
):
from mednet.libs.segmentation.config.data.drive.datamodule import (
make_split,
)
database_checkers.check_split(
make_split(f"{split}.json"),
lengths=lengths,
)
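# For reference, make_split() above loads a JSON file shipped with the
# datamodule that maps each dataset name to its list of samples.  A rough,
# illustrative sketch of that layout (actual file names and any extra fields,
# such as mask paths, may differ):
#
#   {"train": [["training/images/21_training.tif",
#               "training/1st_manual/21_manual1.gif"], ...],
#    "test": [...]}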
@pytest.mark.skip_if_rc_var_not_set("datadir.drive")
def test_database_check():
from mednet.libs.segmentation.scripts.database import check
runner = CliRunner()
result = runner.invoke(check, ["drive"])
assert (
result.exit_code == 0
), f"Exit code {result.exit_code} != 0 -- Output:\n{result.output}"
@pytest.mark.skip_if_rc_var_not_set("datadir.drive")
@pytest.mark.parametrize(
"dataset",
[
"train",
"test",
],
)
@pytest.mark.parametrize(
"name",
[
"default",
],
)
def test_loading(database_checkers, name: str, dataset: str):
datamodule = importlib.import_module(
f".{name}",
"mednet.libs.segmentation.config.data.drive",
).datamodule
datamodule.model_transforms = [] # should be done before setup()
datamodule.setup("predict") # sets up all datasets
loader = datamodule.predict_dataloader()[dataset]
limit = 3 # limit load checking
for batch in loader:
if limit == 0:
break
database_checkers.check_loaded_batch(
batch,
batch_size=1,
color_planes=3,
expected_num_targets=1,
)
limit -= 1
@pytest.mark.skip_if_rc_var_not_set("datadir.drive")
def test_raw_transforms_image_quality(database_checkers, datadir):
reference_histogram_file = str(
datadir / "histograms/raw_data/histograms_drive_default.json",
)
datamodule = importlib.import_module(
".default",
"mednet.libs.segmentation.config.data.drive",
).datamodule
datamodule.model_transforms = []
datamodule.setup("predict")
database_checkers.check_image_quality(datamodule, reference_histogram_file)
@pytest.mark.skip_if_rc_var_not_set("datadir.drive")
@pytest.mark.parametrize(
"model_name",
["lwnet"],
)
def test_model_transforms_image_quality(database_checkers, datadir, model_name):
    # DenseNet's model.name is "densenet-121" and does not correspond to its module name.
if model_name == "densenet":
reference_histogram_file = str(
datadir
/ "histograms/models/histograms_densenet-121_drive_default.json",
)
else:
reference_histogram_file = str(
datadir
/ f"histograms/models/histograms_{model_name}_drive_default.json",
)
datamodule = importlib.import_module(
".default",
"mednet.libs.segmentation.config.data.drive",
).datamodule
model = importlib.import_module(
f".{model_name}",
"mednet.libs.segmentation.config.models",
).model
datamodule.model_transforms = model.model_transforms
datamodule.setup("predict")
database_checkers.check_image_quality(
datamodule,
reference_histogram_file,
compare_type="statistical",
pearson_coeff_threshold=0.005,
)
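# Hypothetical local invocation of just these tests, assuming "datadir.drive"
# is configured in ~/.config/mednet.libs.segmentation.toml (the path below is
# illustrative):
#
#   pytest -x -k drive <path-to-mednet-segmentation-tests>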