Skip to content
Snippets Groups Projects
Commit 451950dd authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

Revamp tests; Delete outdated; Make common conftest.py; Improve segmentation...

Revamp tests; Delete outdated; Make common conftest.py; Improve segmentation tests; Allow segmentation tests to run on CI
parent 9a399fad
No related branches found
No related tags found
1 merge request!46Create common library
Showing
with 81 additions and 266 deletions
File moved
File moved
[datadir]
# classification
indian = "/idiap/resource/database/TBXpredict"
tbx11k = "/idiap/resource/database/tbx11k"
# segmentation
chasedb1 = "/idiap/resource/database/CHASE-DB1"
drive = "/idiap/resource/database/DRIVE"
hrf = "/idiap/resource/database/HRF"
jsrt = "/idiap/resource/database/JSRT"
stare = "/idiap/resource/database/STARE"
# classification and segmentation
#montgomery = "/idiap/resource/database/MontgomeryXraySet"
montgomery = "/idiap/resource/database/montgomery-preprocessed"
shenzhen = "/idiap/resource/database/ShenzhenXraySet"
indian = "/idiap/resource/database/TBXpredict"
tbx11k = "/idiap/resource/database/tbx11k"
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import pathlib
import typing
import numpy
import pytest
import torch
from mednet.libs.common.data.split import JSONDatabaseSplit
from mednet.libs.common.data.typing import DatabaseSplit
@pytest.fixture
def datadir(request) -> pathlib.Path:
"""Return the directory in which the test is sitting. Check the pytest
documentation for more information.
Parameters
----------
request
Information of the requesting test function.
Returns
-------
pathlib.Path
The directory in which the test is sitting.
"""
return pathlib.Path(request.module.__file__).parents[0] / "data"
def pytest_configure(config):
"""This function is run once for pytest setup.
Parameters
----------
config
Configuration values. Check the pytest documentation for more
information.
"""
config.addinivalue_line(
"markers",
"skip_if_rc_var_not_set(name): this mark skips the test if a certain "
"~/.config/mednet.libs.segmentation.toml variable is not set",
)
config.addinivalue_line("markers", "slow: this mark indicates slow tests")
def pytest_runtest_setup(item):
"""This function is run for every test candidate in this directory.
The test is run if this function returns ``None``. To skip a test,
call ``pytest.skip()``, specifying a reason.
Parameters
----------
item
A test invocation item. Check the pytest documentation for more
information.
"""
from mednet.libs.common.utils.rc import load_rc
rc = load_rc()
# iterates over all markers for the item being examined, get the first
# argument and accumulate these names
rc_names = [
mark.args[0] for mark in item.iter_markers(name="skip_if_rc_var_not_set")
]
# checks all names mentioned are set in ~/.config/mednet.libs.segmentation.toml, otherwise,
# skip the test
if rc_names:
missing = [k for k in rc_names if rc.get(k) is None]
if any(missing):
pytest.skip(
f"Test skipped because {', '.join(missing)} is **not** "
f"set in ~/.config/mednet.libs.segmentation.toml",
)
def rc_variable_set(name):
from mednet.libs.common.utils.rc import load_rc
rc = load_rc()
pytest.mark.skipif(
name not in rc,
reason=f"RC variable '{name}' is not set",
)
@pytest.fixture(scope="session")
def temporary_basedir(tmp_path_factory):
return tmp_path_factory.mktemp("test-cli")
class DatabaseCheckers:
"""Helpers for database tests."""
@staticmethod
def check_split(
split: DatabaseSplit,
lengths: dict[str, int],
prefixes: typing.Sequence[str] = None,
):
"""Run a simple consistency check on the data split.
Parameters
----------
split
An instance of DatabaseSplit.
lengths
A dictionary that contains keys matching those of the split (this
will be checked). The values of the dictionary should correspond
to the sizes of each of the datasets in the split.
prefixes
Each file named in a split should start with at least one of these
prefixes.
"""
assert len(split) == len(lengths)
for k in lengths.keys():
# dataset must have been declared
assert k in split
assert len(split[k]) == lengths[k]
for s in split[k]:
if prefixes is not None:
assert any([s[0].startswith(k) for k in prefixes]), (
f"Sample with name {s[0]} does not start with any of the "
f"prefixes in {prefixes}"
)
@staticmethod
def check_loaded_batch(
batch,
batch_size: int,
color_planes: int,
expected_num_targets: int,
prefixes: typing.Sequence[str] = None,
expected_image_shape: tuple[int, ...] | None = None,
):
"""Check the consistency of an individual (loaded) batch.
Parameters
----------
batch
The loaded batch to be checked.
batch_size
The mini-batch size.
color_planes
The number of color planes in the images.
expected_num_targets
The expected number of labels each sample should have.
prefixes
Each file named in a split should start with at least one of these
prefixes.
expected_image_shape
The expected shape of the image (num_channels, width, height).
"""
assert len(batch) == 2 # sample, metadata
assert isinstance(batch[0], torch.Tensor)
assert batch[0].shape[0] == batch_size # mini-batch size
assert batch[0].shape[1] == color_planes
assert all([isinstance(image, torch.Tensor) for image in batch[0]])
if expected_image_shape:
assert all(
[data.shape == expected_image_shape for data in batch[0]],
)
assert isinstance(batch[1], dict) # metadata
assert len(batch[1]) in [2, 3] # target, Optional(mask), name
assert "target" in batch[1]
assert all([isinstance(target, torch.Tensor) for target in batch[1]["target"]])
if expected_num_targets:
assert len(batch[1]["target"]) == expected_num_targets
if "mask" in batch[1]:
assert all([isinstance(mask, torch.Tensor) for mask in batch[1]["mask"]])
assert "name" in batch[1]
if prefixes is not None:
assert all(
[any([k.startswith(j) for j in prefixes]) for k in batch[1]["name"]],
)
@staticmethod
def check_image_quality(
datamodule,
reference_histogram_file,
compare_type="equal",
pearson_coeff_threshold=0.005,
):
ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file)
for split_name in ref_histogram_splits:
raw_samples = datamodule.splits[split_name][0][0]
# It is not possible to get a sample from a Dataset by name/path,
# only by index. This creates a dict of sample name to dataset
# index.
raw_samples_indices = {}
for idx, rs in enumerate(raw_samples):
raw_samples_indices[rs[0]] = idx
for ref_hist_path, ref_hist_data in ref_histogram_splits[split_name]:
# Get index in the dataset that will return the data
# corresponding to the specified sample name
dataset_sample_index = raw_samples_indices[ref_hist_path]
image_tensor = datamodule._datasets[split_name][ # noqa: SLF001
dataset_sample_index
][0]
histogram = []
for color_channel in image_tensor:
color_channel = numpy.multiply(
color_channel.numpy(),
255,
).astype(int)
histogram.extend(
numpy.histogram(
color_channel,
bins=256,
range=(0, 256),
)[0].tolist(),
)
if compare_type == "statistical":
# Compute pearson coefficients between histogram and
# reference and check the similarity within a certain
# threshold
pearson_coeffs = numpy.corrcoef(histogram, ref_hist_data)
assert 1 - pearson_coeff_threshold <= pearson_coeffs[0][1] <= 1
else:
assert histogram == ref_hist_data
@pytest.fixture
def database_checkers():
return DatabaseCheckers
......@@ -33,6 +33,8 @@ def test_protocol_consistency(
database_checkers.check_split(
make_split("mednet.libs.segmentation.config.data.chasedb1", f"{split}.json"),
lengths=lengths,
prefixes=["Image_"],
possible_labels=[],
)
......@@ -81,7 +83,10 @@ def test_loading(database_checkers, name: str, dataset: str):
batch,
batch_size=1,
color_planes=3,
expected_num_targets=1,
expected_num_labels=1,
expected_meta_size=3,
prefixes=["Image_"],
possible_labels=[],
)
limit -= 1
......
......@@ -32,6 +32,8 @@ def test_protocol_consistency(
database_checkers.check_split(
make_split("mednet.libs.segmentation.config.data.cxr8", f"{split}.json"),
lengths=lengths,
prefixes=[],
possible_labels=[],
)
......@@ -79,7 +81,8 @@ def test_loading(database_checkers, name: str, dataset: str):
batch,
batch_size=1,
color_planes=3,
expected_num_targets=1,
expected_num_labels=1,
expected_meta_size=3,
)
limit -= 1
......
......@@ -32,6 +32,8 @@ def test_protocol_consistency(
database_checkers.check_split(
make_split("mednet.libs.segmentation.config.data.drhagis", f"{split}.json"),
lengths=lengths,
prefixes=["Fundus_Images/"],
possible_labels=[],
)
......@@ -79,7 +81,10 @@ def test_loading(database_checkers, name: str, dataset: str):
batch,
batch_size=1,
color_planes=3,
expected_num_targets=1,
expected_num_labels=1,
expected_meta_size=3,
prefixes=["Fundus_Images/"],
possible_labels=[],
)
limit -= 1
......
......@@ -33,6 +33,8 @@ def test_protocol_consistency(
database_checkers.check_split(
make_split("mednet.libs.segmentation.config.data.drionsdb", f"{split}.json"),
lengths=lengths,
prefixes=["images/image_"],
possible_labels=[],
)
......@@ -81,7 +83,10 @@ def test_loading(database_checkers, name: str, dataset: str):
batch,
batch_size=1,
color_planes=3,
expected_num_targets=1,
expected_num_labels=1,
expected_meta_size=3,
prefixes=["images/image_"],
possible_labels=[],
)
limit -= 1
......
......@@ -33,6 +33,11 @@ def test_protocol_consistency(
database_checkers.check_split(
make_split("mednet.libs.segmentation.config.data.drishtigs1", f"{split}.json"),
lengths=lengths,
prefixes=[
"Drishti-GS1_files/Training/Images/drishtiGS_",
"Drishti-GS1_files/Test/Images/drishtiGS_",
],
possible_labels=[],
)
......@@ -83,7 +88,13 @@ def test_loading(database_checkers, name: str, dataset: str):
batch,
batch_size=1,
color_planes=3,
expected_num_targets=1,
expected_num_labels=1,
expected_meta_size=3,
prefixes=[
"Drishti-GS1_files/Training/Images/drishtiGS_",
"Drishti-GS1_files/Test/Images/drishtiGS_",
],
possible_labels=[],
)
limit -= 1
......
......@@ -33,6 +33,8 @@ def test_protocol_consistency(
database_checkers.check_split(
make_split("mednet.libs.segmentation.config.data.drive", f"{split}.json"),
lengths=lengths,
prefixes=["training/", "test/"],
possible_labels=[],
)
......@@ -80,7 +82,10 @@ def test_loading(database_checkers, name: str, dataset: str):
batch,
batch_size=1,
color_planes=3,
expected_num_targets=1,
expected_num_labels=1,
expected_meta_size=3,
possible_labels=[],
prefixes=["training/", "test/"],
)
limit -= 1
......
......@@ -32,6 +32,8 @@ def test_protocol_consistency(
database_checkers.check_split(
make_split("mednet.libs.segmentation.config.data.hrf", f"{split}.json"),
lengths=lengths,
prefixes=["images/"],
possible_labels=[],
)
......@@ -79,7 +81,10 @@ def test_loading(database_checkers, name: str, dataset: str):
batch,
batch_size=1,
color_planes=3,
expected_num_targets=1,
expected_num_labels=1,
expected_meta_size=3,
prefixes=["images/"],
possible_labels=[],
)
limit -= 1
......
......@@ -33,6 +33,8 @@ def test_protocol_consistency(
database_checkers.check_split(
make_split("mednet.libs.segmentation.config.data.iostar", f"{split}.json"),
lengths=lengths,
prefixes=["image/STAR "],
possible_labels=[],
)
......@@ -78,7 +80,10 @@ def test_loading(database_checkers, name: str, dataset: str):
batch,
batch_size=1,
color_planes=3,
expected_num_targets=1,
expected_num_labels=1,
expected_meta_size=3,
prefixes=["image/STAR "],
possible_labels=[],
)
limit -= 1
......
......@@ -32,6 +32,8 @@ def test_protocol_consistency(
database_checkers.check_split(
make_split("mednet.libs.segmentation.config.data.jsrt", f"{split}.json"),
lengths=lengths,
prefixes=["All247images/JPC"],
possible_labels=[],
)
......@@ -80,7 +82,10 @@ def test_loading(database_checkers, name: str, dataset: str):
batch,
batch_size=1,
color_planes=3,
expected_num_targets=1,
expected_num_labels=1,
expected_meta_size=3,
prefixes=["All247images/JPC"],
possible_labels=[],
)
limit -= 1
......
......@@ -32,6 +32,8 @@ def test_protocol_consistency(
database_checkers.check_split(
make_split("mednet.libs.segmentation.config.data.montgomery", f"{split}.json"),
lengths=lengths,
prefixes=["CXR_png"],
possible_labels=[],
)
......@@ -80,8 +82,10 @@ def test_loading(database_checkers, name: str, dataset: str):
batch,
batch_size=1,
color_planes=3,
expected_num_targets=1,
expected_num_labels=1,
expected_meta_size=3,
prefixes=["CXR_png"],
possible_labels=[],
)
limit -= 1
......
......@@ -32,6 +32,8 @@ def test_protocol_consistency(
database_checkers.check_split(
make_split("mednet.libs.segmentation.config.data.refuge", f"{split}.json"),
lengths=lengths,
prefixes=["Training400/", "REFUGE-Validation400/V", "Test400/T0"],
possible_labels=[],
)
......@@ -81,7 +83,10 @@ def test_loading(database_checkers, name: str, dataset: str):
batch,
batch_size=1,
color_planes=3,
expected_num_targets=1,
expected_num_labels=1,
expected_meta_size=3,
prefixes=["Training400/", "REFUGE-Validation400/V", "Test400/T0"],
possible_labels=[],
)
limit -= 1
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment