# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

import os
import pathlib
import tempfile
import typing
import zipfile

import pytest
import tomli_w
import torch

from mednet.data.typing import DatabaseSplit

# Stash key under which ``pytest_sessionstart`` stores the temporary directory
# holding the extracted test database, so that ``pytest_sessionfinish`` can
# find and erase it.  It must live at module level: a key created inside the
# hook itself could never be retrieved again.
_MONTGOMERY_TEMPDIR_KEY = pytest.StashKey[tempfile.TemporaryDirectory]()


@pytest.fixture
def datadir(request) -> pathlib.Path:
    """Return the ``data`` directory sitting next to the requesting test module."""
    return pathlib.Path(request.module.__file__).parents[0] / "data"


def pytest_configure(config):
    """Register the custom markers used by this test suite.

    This function is run once for pytest setup.
    """
    config.addinivalue_line(
        "markers",
        "skip_if_rc_var_not_set(name): this mark skips the test if a certain "
        "~/.config/mednet.toml variable is not set",
    )

    config.addinivalue_line("markers", "slow: this mark indicates slow tests")


def pytest_runtest_setup(item):
    """Skip test candidates whose required RC variables are not configured.

    This function is run for every test candidate in this directory.  The
    test is run if this function returns ``None``.  A test is skipped (via
    ``pytest.skip()``) when any ``skip_if_rc_var_not_set(name)`` marker
    attached to it names a variable that is not set in the configuration
    returned by ``mednet.utils.rc.load_rc``.
    """
    # iterates over all markers for the item being examined, get the first
    # argument and accumulate these names
    rc_names = [
        mark.args[0] for mark in item.iter_markers(name="skip_if_rc_var_not_set")
    ]

    # no marker: nothing to check, so do not bother loading the configuration
    if not rc_names:
        return

    from mednet.utils.rc import load_rc

    rc = load_rc()

    # checks all names mentioned are set in ~/.config/mednet.toml, otherwise,
    # skip the test
    missing = [k for k in rc_names if rc.get(k) is None]
    if missing:
        pytest.skip(
            f"Test skipped because {', '.join(missing)} is **not** "
            f"set in ~/.config/mednet.toml"
        )


def rc_variable_set(name):
    """Build a mark that skips a test when an RC variable is not set.

    Parameters
    ----------
    name
        Name of the variable to look up in the configuration returned by
        ``mednet.utils.rc.load_rc``.

    Returns
    -------
        A ``pytest.mark.skipif`` mark, usable as a test decorator
        (``@rc_variable_set("datadir.montgomery")``).
    """
    from mednet.utils.rc import load_rc

    rc = load_rc()
    # the mark must be returned: building it alone has no effect
    return pytest.mark.skipif(
        name not in rc,
        reason=f"RC variable '{name}' is not set",
    )


@pytest.fixture(scope="session")
def temporary_basedir(tmp_path_factory):
    """Session-wide temporary directory created with the ``test-cli`` prefix."""
    return tmp_path_factory.mktemp("test-cli")


def pytest_sessionstart(session: pytest.Session) -> None:
    """Preset the session start to ensure the Montgomery dataset is always available.

    If ``datadir.montgomery`` is already set in the configuration, nothing is
    done.  Otherwise, the test database archive at ``data/lfs/test-database.zip``
    (next to this file) is extracted into a fresh temporary directory, a
    ``mednet.toml`` with ``datadir.montgomery`` pointing to it is written
    there, and ``XDG_CONFIG_HOME`` is redirected to that directory (presumably
    so that later ``load_rc`` calls pick up the generated file — confirm
    against ``mednet.utils.rc``).  The temporary directory is removed by
    ``pytest_sessionfinish``.
    """
    from mednet.utils.rc import load_rc

    rc = load_rc()

    if rc.get("datadir.montgomery") is not None:
        # if the user downloaded it, use that copy
        return

    # else, we must extract the LFS component (we are likely on the CI)
    archive = (
        pathlib.Path(__file__).parents[0] / "data" / "lfs" / "test-database.zip"
    )
    assert archive.exists(), (
        "Neither datadir.montgomery is set on the global configuration "
        "(typically ~/.config/mednet.toml), nor is it possible to detect "
        f"the presence of '{archive}' (did you run `git submodule update "
        "--init` for this submodule?)"
    )

    montgomery_tempdir = tempfile.TemporaryDirectory()
    rc.setdefault("datadir.montgomery", montgomery_tempdir.name)

    with zipfile.ZipFile(archive) as zf:
        zf.extractall(montgomery_tempdir.name)

    config_filename = "mednet.toml"
    with open(os.path.join(montgomery_tempdir.name, config_filename), "wb") as f:
        tomli_w.dump(rc.data, f)

    os.environ["XDG_CONFIG_HOME"] = montgomery_tempdir.name

    # stash the newly created temporary directory: this keeps it alive for the
    # whole session and lets ``pytest_sessionfinish`` erase it at the end
    session.stash[_MONTGOMERY_TEMPDIR_KEY] = montgomery_tempdir


def pytest_sessionfinish(session: pytest.Session) -> None:
    """Erase the temporary Montgomery copy created by ``pytest_sessionstart``.

    Does nothing when no temporary copy was created (i.e. when
    ``datadir.montgomery`` was already configured).
    """
    montgomery_tempdir = session.stash.get(_MONTGOMERY_TEMPDIR_KEY, None)
    if montgomery_tempdir is not None:
        montgomery_tempdir.cleanup()


class DatabaseCheckers:
    """Helpers for database tests."""

    @staticmethod
    def check_split(
        split: DatabaseSplit,
        lengths: dict[str, int],
        prefixes: typing.Sequence[str],
        possible_labels: typing.Sequence[int],
    ):
        """Run a simple consistency check on the data split.

        Parameters
        ----------
        split
            The loaded database split to check.  It is indexed by dataset
            name, and each dataset is a sequence of samples indexed as
            ``(name, label, ...)``.
        lengths
            A dictionary whose keys must match exactly those of ``split``
            (this will be checked).  Each value is the expected number of
            samples in the dataset with that name.
        prefixes
            Each file named in a split should start with at least one of
            these prefixes.
        possible_labels
            The list of possible labels contained in any split.  A sample
            label may be a single value or a list of values; in the latter
            case every value must be in this list.
        """
        # same number of datasets, and (below) every expected one is present
        assert len(split) == len(lengths)

        for dataset_name, expected_length in lengths.items():
            # dataset must have been declared
            assert dataset_name in split

            assert len(split[dataset_name]) == expected_length

            for sample in split[dataset_name]:
                assert any(sample[0].startswith(p) for p in prefixes), (
                    f"Sample with name {sample[0]} does not start with any of the "
                    f"prefixes in {prefixes}"
                )
                if isinstance(sample[1], list):
                    assert all(label in possible_labels for label in sample[1])
                else:
                    assert sample[1] in possible_labels

    @staticmethod
    def check_loaded_batch(
        batch,
        batch_size: int,
        color_planes: int,
        prefixes: typing.Sequence[str],
        possible_labels: typing.Sequence[int],
    ):
        """Check the consistency of an individual (loaded) batch.

        Parameters
        ----------
        batch
            The loaded batch to be checked: a pair ``(data, metadata)``.
            ``data`` must be a tensor with at least four dimensions
            (mini-batch, color planes, then two equal spatial dimensions),
            and ``metadata`` a dictionary with exactly the keys ``label`` and
            ``name``.
        batch_size
            The expected mini-batch size (first dimension of ``data``).
        color_planes
            The expected number of color planes (second dimension of
            ``data``).
        prefixes
            Each file named in a split should start with at least one of
            these prefixes.
        possible_labels
            The list of possible labels contained in any split.
        """
        assert len(batch) == 2  # data, metadata

        assert isinstance(batch[0], torch.Tensor)
        assert batch[0].shape[0] == batch_size  # mini-batch size
        assert batch[0].shape[1] == color_planes  # number of color planes
        assert batch[0].shape[2] == batch[0].shape[3]  # image is square

        assert isinstance(batch[1], dict)  # metadata
        assert len(batch[1]) == 2  # label and name

        assert "label" in batch[1]
        assert all(k in possible_labels for k in batch[1]["label"])

        assert "name" in batch[1]
        assert all(
            any(k.startswith(j) for j in prefixes) for k in batch[1]["name"]
        )

        # use the code below to view generated images
        # from torchvision.transforms.functional import to_pil_image
        # to_pil_image(batch[0][0]).show()
        # __import__("pdb").set_trace()


@pytest.fixture
def database_checkers():
    """Provide the ``DatabaseCheckers`` helper class to tests."""
    return DatabaseCheckers