Skip to content
Snippets Groups Projects
Commit 147da72b authored by Daniel CARRON's avatar Daniel CARRON :b:
Browse files

Merge branch 'ci-db-paths' into 'update-tests'

Fix Issues when running tests on the CI

See merge request biosignal/software/mednet!19
parents f8a0d1dc f3f94d64
No related branches found
No related tags found
3 merge requests!19Fix Issues when running tests on the CI,!18Update tests,!16Make square centre-padding a model transform
Pipeline #84277 failed
...@@ -10,6 +10,7 @@ include: ...@@ -10,6 +10,7 @@ include:
variables: variables:
GIT_SUBMODULE_STRATEGY: normal GIT_SUBMODULE_STRATEGY: normal
GIT_SUBMODULE_DEPTH: 1 GIT_SUBMODULE_DEPTH: 1
XDG_CONFIG_HOME: $CI_PROJECT_DIR/tests/data
documentation: documentation:
before_script: before_script:
......
...@@ -21,5 +21,6 @@ Files: ...@@ -21,5 +21,6 @@ Files:
tests/data/*.csv tests/data/*.csv
tests/data/*.json tests/data/*.json
tests/data/*.png tests/data/*.png
tests/*.toml
Copyright: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> Copyright: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
License: GPL-3.0-or-later License: GPL-3.0-or-later
...@@ -75,7 +75,11 @@ class Pasa(pl.LightningModule): ...@@ -75,7 +75,11 @@ class Pasa(pl.LightningModule):
self.model_transforms = [ self.model_transforms = [
Grayscale(), Grayscale(),
SquareCenterPad(), SquareCenterPad(),
torchvision.transforms.Resize(512, antialias=True), torchvision.transforms.Resize(
512,
antialias=True,
interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
),
] ]
self._train_loss = train_loss self._train_loss = train_loss
......
...@@ -2,18 +2,13 @@ ...@@ -2,18 +2,13 @@
# #
# SPDX-License-Identifier: GPL-3.0-or-later # SPDX-License-Identifier: GPL-3.0-or-later
import os
import pathlib import pathlib
import tempfile
import typing import typing
import zipfile
import numpy
import pytest import pytest
import tomli_w
import torch import torch
from torchvision.transforms.functional import to_pil_image
from mednet.data.split import JSONDatabaseSplit from mednet.data.split import JSONDatabaseSplit
from mednet.data.typing import DatabaseSplit from mednet.data.typing import DatabaseSplit
...@@ -100,55 +95,6 @@ def temporary_basedir(tmp_path_factory): ...@@ -100,55 +95,6 @@ def temporary_basedir(tmp_path_factory):
return tmp_path_factory.mktemp("test-cli") return tmp_path_factory.mktemp("test-cli")
def pytest_sessionstart(session: pytest.Session) -> None:
"""Preset the session start to ensure the Montgomery dataset is always available.
Parameters
----------
session
The session to use.
"""
from mednet.utils.rc import load_rc
rc = load_rc()
database_dir = rc.get("datadir.montgomery")
if database_dir is not None:
# if the user downloaded it, use that copy
return
# else, we must extract the LFS component (we are likely on the CI)
archive = (
pathlib.Path(__file__).parents[0] / "data" / "lfs" / "test-database.zip"
)
assert archive.exists(), (
f"Neither datadir.montgomery is set on the global configuration, "
f"(typically ~/.config/mednet.toml), or it is possible to detect "
f"the presence of {archive}' (did you git submodule init --update "
f"this submodule?)"
)
montgomery_tempdir = tempfile.TemporaryDirectory()
rc.setdefault("datadir.montgomery", montgomery_tempdir.name)
with zipfile.ZipFile(archive) as zf:
zf.extractall(montgomery_tempdir.name)
config_filename = "mednet.toml"
with open(
os.path.join(montgomery_tempdir.name, config_filename), "wb"
) as f:
tomli_w.dump(rc.data, f)
f.flush()
os.environ["XDG_CONFIG_HOME"] = montgomery_tempdir.name
# stash the newly created temporary directory so we can erase it when the
key = pytest.StashKey[tempfile.TemporaryDirectory]()
session.stash[key] = montgomery_tempdir
class DatabaseCheckers: class DatabaseCheckers:
"""Helpers for database tests.""" """Helpers for database tests."""
...@@ -256,7 +202,12 @@ class DatabaseCheckers: ...@@ -256,7 +202,12 @@ class DatabaseCheckers:
# __import__("pdb").set_trace() # __import__("pdb").set_trace()
@staticmethod @staticmethod
def check_image_quality(datamodule, reference_histogram_file): def check_image_quality(
datamodule,
reference_histogram_file,
compare_type="equal",
pearson_coeff_threshold=0.005,
):
ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file) ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file)
for split_name in ref_histogram_splits: for split_name in ref_histogram_splits:
...@@ -277,10 +228,28 @@ class DatabaseCheckers: ...@@ -277,10 +228,28 @@ class DatabaseCheckers:
image_tensor = datamodule._datasets[split_name][ image_tensor = datamodule._datasets[split_name][
dataset_sample_index dataset_sample_index
][0] ][0]
img = to_pil_image(image_tensor)
histogram = img.histogram()
assert histogram == ref_hist_data histogram = []
for color_channel in image_tensor:
color_channel = numpy.multiply(
color_channel.numpy(), 255
).astype(int)
histogram.extend(
numpy.histogram(
color_channel, bins=256, range=(0, 256)
)[0].tolist()
)
if compare_type == "statistical":
# Compute pearson coefficients between histogram and reference
# and check the similarity within a certain threshold
pearson_coeffs = numpy.corrcoef(histogram, ref_hist_data)
assert (
1 - pearson_coeff_threshold <= pearson_coeffs[0][1] <= 1
)
else:
assert histogram == ref_hist_data
@pytest.fixture @pytest.fixture
......
[datadir]
montgomery = "/idiap/resource/database/MontgomeryXraySet"
shenzhen = "/idiap/resource/database/ShenzhenXraySet"
indian = "/idiap/resource/database/TBXpredict"
tbx11k = "/idiap/resource/database/tbx11k"
...@@ -95,7 +95,7 @@ def test_loading(database_checkers, name: str, dataset: str): ...@@ -95,7 +95,7 @@ def test_loading(database_checkers, name: str, dataset: str):
@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
def test_loaded_image_quality(database_checkers, datadir): def test_raw_transforms_image_quality(database_checkers, datadir):
reference_histogram_file = str( reference_histogram_file = str(
datadir / "histograms/raw_data/histograms_montgomery_default.json" datadir / "histograms/raw_data/histograms_montgomery_default.json"
) )
...@@ -142,4 +142,10 @@ def test_model_transforms_image_quality(database_checkers, datadir, model_name): ...@@ -142,4 +142,10 @@ def test_model_transforms_image_quality(database_checkers, datadir, model_name):
datamodule.model_transforms = model.model_transforms datamodule.model_transforms = model.model_transforms
datamodule.setup("predict") datamodule.setup("predict")
database_checkers.check_image_quality(datamodule, reference_histogram_file)
database_checkers.check_image_quality(
datamodule,
reference_histogram_file,
compare_type="statistical",
pearson_coeff_threshold=0.005,
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment