diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index c9ae0016c90b88658cbab481bf880ef729a38ce1..25425691057398ffc6c65bda57575ade1d40c97e 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -10,6 +10,7 @@ include: variables: GIT_SUBMODULE_STRATEGY: normal GIT_SUBMODULE_DEPTH: 1 + XDG_CONFIG_HOME: $CI_PROJECT_DIR/tests/data documentation: before_script: diff --git a/.reuse/dep5 b/.reuse/dep5 index 5314ea6e597528ba571ecdaa3e5686b76102efe5..fb27b959e3ce20f1d74e7f560f752367575d2924 100644 --- a/.reuse/dep5 +++ b/.reuse/dep5 @@ -21,5 +21,6 @@ Files: tests/data/*.csv tests/data/*.json tests/data/*.png + tests/*.toml Copyright: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> License: GPL-3.0-or-later diff --git a/src/mednet/models/pasa.py b/src/mednet/models/pasa.py index 38ad0218ffd8a0a35573896175d32d9ff6a521af..e285bf7baa36a9ab548de01451622ee61ae1c111 100644 --- a/src/mednet/models/pasa.py +++ b/src/mednet/models/pasa.py @@ -75,7 +75,11 @@ class Pasa(pl.LightningModule): self.model_transforms = [ Grayscale(), SquareCenterPad(), - torchvision.transforms.Resize(512, antialias=True), + torchvision.transforms.Resize( + 512, + antialias=True, + interpolation=torchvision.transforms.InterpolationMode.BILINEAR, + ), ] self._train_loss = train_loss diff --git a/tests/conftest.py b/tests/conftest.py index ef45a883f7b7797facac85b4446574c64b806ebe..b51eafa681ca9c95be145a2c5eedea0b1fe5a5b8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,18 +2,13 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -import os import pathlib -import tempfile import typing -import zipfile +import numpy import pytest -import tomli_w import torch -from torchvision.transforms.functional import to_pil_image - from mednet.data.split import JSONDatabaseSplit from mednet.data.typing import DatabaseSplit @@ -100,55 +95,6 @@ def temporary_basedir(tmp_path_factory): return tmp_path_factory.mktemp("test-cli") -def pytest_sessionstart(session: pytest.Session) -> None: - """Preset the session start to ensure the Montgomery dataset is always available. - - Parameters - ---------- - session - The session to use. - """ - - from mednet.utils.rc import load_rc - - rc = load_rc() - - database_dir = rc.get("datadir.montgomery") - if database_dir is not None: - # if the user downloaded it, use that copy - return - - # else, we must extract the LFS component (we are likely on the CI) - archive = ( - pathlib.Path(__file__).parents[0] / "data" / "lfs" / "test-database.zip" - ) - assert archive.exists(), ( - f"Neither datadir.montgomery is set on the global configuration, " - f"(typically ~/.config/mednet.toml), or it is possible to detect " - f"the presence of {archive}' (did you git submodule init --update " - f"this submodule?)" - ) - - montgomery_tempdir = tempfile.TemporaryDirectory() - rc.setdefault("datadir.montgomery", montgomery_tempdir.name) - - with zipfile.ZipFile(archive) as zf: - zf.extractall(montgomery_tempdir.name) - - config_filename = "mednet.toml" - with open( - os.path.join(montgomery_tempdir.name, config_filename), "wb" - ) as f: - tomli_w.dump(rc.data, f) - f.flush() - - os.environ["XDG_CONFIG_HOME"] = montgomery_tempdir.name - - # stash the newly created temporary directory so we can erase it when the - key = pytest.StashKey[tempfile.TemporaryDirectory]() - session.stash[key] = montgomery_tempdir - - class DatabaseCheckers: """Helpers for database tests.""" @@ -256,7 +202,12 @@ class DatabaseCheckers: # __import__("pdb").set_trace() @staticmethod - def check_image_quality(datamodule, reference_histogram_file): + def check_image_quality( + datamodule, + reference_histogram_file, + compare_type="equal", + pearson_coeff_threshold=0.005, + ): ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file) for split_name in ref_histogram_splits: @@ -277,10 +228,28 @@ class DatabaseCheckers: image_tensor = datamodule._datasets[split_name][ dataset_sample_index ][0] - img = to_pil_image(image_tensor) - histogram = img.histogram() - assert histogram == ref_hist_data + histogram = [] + for color_channel in image_tensor: + color_channel = numpy.multiply( + color_channel.numpy(), 255 + ).astype(int) + histogram.extend( + numpy.histogram( + color_channel, bins=256, range=(0, 256) + )[0].tolist() + ) + + if compare_type == "statistical": + # Compute pearson coefficients between histogram and reference + # and check the similarity within a certain threshold + pearson_coeffs = numpy.corrcoef(histogram, ref_hist_data) + assert ( + 1 - pearson_coeff_threshold <= pearson_coeffs[0][1] <= 1 + ) + + else: + assert histogram == ref_hist_data @pytest.fixture diff --git a/tests/data/mednet.toml b/tests/data/mednet.toml new file mode 100644 index 0000000000000000000000000000000000000000..1eccba371435f50739c65227ae46bbae4d33375a --- /dev/null +++ b/tests/data/mednet.toml @@ -0,0 +1,5 @@ +[datadir] +montgomery = "/idiap/resource/database/MontgomeryXraySet" +shenzhen = "/idiap/resource/database/ShenzhenXraySet" +indian = "/idiap/resource/database/TBXpredict" +tbx11k = "/idiap/resource/database/tbx11k" diff --git a/tests/test_montgomery.py b/tests/test_montgomery.py index 44ff9e5141dcd88091c84d3a012dc9e57a2c9645..4ce6258064a64507ccdc2714b9aa16c957961a36 100644 --- a/tests/test_montgomery.py +++ b/tests/test_montgomery.py @@ -95,7 +95,7 @@ def test_loading(database_checkers, name: str, dataset: str): @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") -def test_loaded_image_quality(database_checkers, datadir): +def test_raw_transforms_image_quality(database_checkers, datadir): reference_histogram_file = str( datadir / "histograms/raw_data/histograms_montgomery_default.json" ) @@ -142,4 +142,10 @@ def test_model_transforms_image_quality(database_checkers, datadir, model_name): datamodule.model_transforms = model.model_transforms datamodule.setup("predict") - database_checkers.check_image_quality(datamodule, reference_histogram_file) + + database_checkers.check_image_quality( + datamodule, + reference_histogram_file, + compare_type="statistical", + pearson_coeff_threshold=0.005, + )