Skip to content
Snippets Groups Projects
Commit 147da72b authored by Daniel CARRON's avatar Daniel CARRON :b:
Browse files

Merge branch 'ci-db-paths' into 'update-tests'

Fix Issues when running tests on the CI

See merge request biosignal/software/mednet!19
parents f8a0d1dc f3f94d64
No related branches found
No related tags found
3 merge requests!19Fix Issues when running tests on the CI,!18Update tests,!16Make square centre-padding a model transform
Pipeline #84277 failed
......@@ -10,6 +10,7 @@ include:
variables:
GIT_SUBMODULE_STRATEGY: normal
GIT_SUBMODULE_DEPTH: 1
XDG_CONFIG_HOME: $CI_PROJECT_DIR/tests/data
documentation:
before_script:
......
......@@ -21,5 +21,6 @@ Files:
tests/data/*.csv
tests/data/*.json
tests/data/*.png
tests/*.toml
Copyright: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
License: GPL-3.0-or-later
......@@ -75,7 +75,11 @@ class Pasa(pl.LightningModule):
self.model_transforms = [
Grayscale(),
SquareCenterPad(),
torchvision.transforms.Resize(512, antialias=True),
torchvision.transforms.Resize(
512,
antialias=True,
interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
),
]
self._train_loss = train_loss
......
......@@ -2,18 +2,13 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later
import os
import pathlib
import tempfile
import typing
import zipfile
import numpy
import pytest
import tomli_w
import torch
from torchvision.transforms.functional import to_pil_image
from mednet.data.split import JSONDatabaseSplit
from mednet.data.typing import DatabaseSplit
......@@ -100,55 +95,6 @@ def temporary_basedir(tmp_path_factory):
return tmp_path_factory.mktemp("test-cli")
def pytest_sessionstart(session: pytest.Session) -> None:
"""Preset the session start to ensure the Montgomery dataset is always available.
Parameters
----------
session
The session to use.
"""
from mednet.utils.rc import load_rc
rc = load_rc()
database_dir = rc.get("datadir.montgomery")
if database_dir is not None:
# if the user downloaded it, use that copy
return
# else, we must extract the LFS component (we are likely on the CI)
archive = (
pathlib.Path(__file__).parents[0] / "data" / "lfs" / "test-database.zip"
)
assert archive.exists(), (
f"Neither datadir.montgomery is set on the global configuration, "
f"(typically ~/.config/mednet.toml), or it is possible to detect "
f"the presence of {archive}' (did you git submodule init --update "
f"this submodule?)"
)
montgomery_tempdir = tempfile.TemporaryDirectory()
rc.setdefault("datadir.montgomery", montgomery_tempdir.name)
with zipfile.ZipFile(archive) as zf:
zf.extractall(montgomery_tempdir.name)
config_filename = "mednet.toml"
with open(
os.path.join(montgomery_tempdir.name, config_filename), "wb"
) as f:
tomli_w.dump(rc.data, f)
f.flush()
os.environ["XDG_CONFIG_HOME"] = montgomery_tempdir.name
# stash the newly created temporary directory so we can erase it when the
key = pytest.StashKey[tempfile.TemporaryDirectory]()
session.stash[key] = montgomery_tempdir
class DatabaseCheckers:
"""Helpers for database tests."""
......@@ -256,7 +202,12 @@ class DatabaseCheckers:
# __import__("pdb").set_trace()
@staticmethod
def check_image_quality(datamodule, reference_histogram_file):
def check_image_quality(
datamodule,
reference_histogram_file,
compare_type="equal",
pearson_coeff_threshold=0.005,
):
ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file)
for split_name in ref_histogram_splits:
......@@ -277,10 +228,28 @@ class DatabaseCheckers:
image_tensor = datamodule._datasets[split_name][
dataset_sample_index
][0]
img = to_pil_image(image_tensor)
histogram = img.histogram()
assert histogram == ref_hist_data
histogram = []
for color_channel in image_tensor:
color_channel = numpy.multiply(
color_channel.numpy(), 255
).astype(int)
histogram.extend(
numpy.histogram(
color_channel, bins=256, range=(0, 256)
)[0].tolist()
)
if compare_type == "statistical":
# Compute pearson coefficients between histogram and reference
# and check the similarity within a certain threshold
pearson_coeffs = numpy.corrcoef(histogram, ref_hist_data)
assert (
1 - pearson_coeff_threshold <= pearson_coeffs[0][1] <= 1
)
else:
assert histogram == ref_hist_data
@pytest.fixture
......
[datadir]
montgomery = "/idiap/resource/database/MontgomeryXraySet"
shenzhen = "/idiap/resource/database/ShenzhenXraySet"
indian = "/idiap/resource/database/TBXpredict"
tbx11k = "/idiap/resource/database/tbx11k"
......@@ -95,7 +95,7 @@ def test_loading(database_checkers, name: str, dataset: str):
@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
def test_loaded_image_quality(database_checkers, datadir):
def test_raw_transforms_image_quality(database_checkers, datadir):
reference_histogram_file = str(
datadir / "histograms/raw_data/histograms_montgomery_default.json"
)
......@@ -142,4 +142,10 @@ def test_model_transforms_image_quality(database_checkers, datadir, model_name):
datamodule.model_transforms = model.model_transforms
datamodule.setup("predict")
database_checkers.check_image_quality(datamodule, reference_histogram_file)
database_checkers.check_image_quality(
datamodule,
reference_histogram_file,
compare_type="statistical",
pearson_coeff_threshold=0.005,
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment