# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

import os
import pathlib
import tempfile
import zipfile

import pytest
import tomli_w


@pytest.fixture
def datadir(request) -> pathlib.Path:
    """Return the ``data`` directory sitting next to the requesting test
    module."""

    return pathlib.Path(request.module.__file__).parents[0] / "data"


def pytest_configure(config) -> None:
    """Register the custom markers used by this test suite.

    This function is run once for pytest setup.
    """

    config.addinivalue_line(
        "markers",
        "skip_if_rc_var_not_set(name): this mark skips the test if a certain "
        "~/.config/ptbench.toml variable is not set",
    )
    config.addinivalue_line("markers", "slow: this mark indicates slow tests")


def pytest_runtest_setup(item) -> None:
    """Skip tests whose required RC variables are not set.

    This function is run for every test candidate in this directory.  The
    test is run if this function returns ``None``.  To skip a test, call
    ``pytest.skip()``, specifying a reason.
    """

    from ptbench.utils.rc import load_rc

    rc = load_rc()

    # iterates over all markers for the item being examined, get the first
    # argument and accumulate these names
    rc_names = [
        mark.args[0]
        for mark in item.iter_markers(name="skip_if_rc_var_not_set")
    ]

    # checks all names mentioned are set in ~/.config/ptbench.toml, otherwise,
    # skip the test.  Emptiness of the list is what matters here, not the
    # truthiness of the individual names.
    missing = [k for k in rc_names if rc.get(k) is None]
    if missing:
        pytest.skip(
            f"Test skipped because {', '.join(missing)} is **not** "
            f"set in ~/.config/ptbench.toml"
        )


def rc_variable_set(name):
    """Build a ``skipif`` marker conditioned on an RC variable.

    Parameters
    ----------
    name
        Name of the RC variable to look up in the loaded configuration.

    Returns
    -------
        A ``pytest.mark.skipif`` marker that skips the decorated test when
        ``name`` is not found in the loaded RC object.
    """

    from ptbench.utils.rc import load_rc

    rc = load_rc()

    # the marker must be returned: building it without returning it makes
    # this function evaluate to ``None``, which cannot decorate a test
    return pytest.mark.skipif(
        name not in rc,
        reason=f"RC variable '{name}' is not set",
    )


@pytest.fixture(scope="session")
def temporary_basedir(tmp_path_factory):
    """Return a session-wide temporary directory created with the
    ``test-cli`` base name."""

    return tmp_path_factory.mktemp("test-cli")


def pytest_sessionstart(session: pytest.Session) -> None:
    """Presets the session start to ensure the Montgomery dataset is always
    available.

    If ``datadir.montgomery`` is already set in the RC, nothing is done.
    Otherwise, the bundled ``test-database.zip`` archive is extracted to a
    temporary directory, an RC file pointing to it is written in that same
    directory, and ``XDG_CONFIG_HOME`` is redirected there.
    """

    from ptbench.utils.rc import load_rc

    rc = load_rc()

    database_dir = rc.get("datadir.montgomery")
    if database_dir is not None:
        # if the user downloaded it, use that copy
        return

    # else, we must extract the LFS component (we are likely on the CI)
    archive = (
        pathlib.Path(__file__).parents[0] / "data" / "lfs" / "test-database.zip"
    )
    assert archive.exists(), (
        f"Neither datadir.montgomery is set on the global configuration, "
        f"(typically ~/.config/ptbench.toml), or it is possible to detect "
        f"the presence of {archive}' (did you git submodule init --update "
        f"this submodule?)"
    )

    montgomery_tempdir = tempfile.TemporaryDirectory()
    rc.setdefault("datadir.montgomery", montgomery_tempdir.name)

    with zipfile.ZipFile(archive) as zf:
        zf.extractall(montgomery_tempdir.name)

    # writes a configuration file next to the extracted data and points
    # XDG_CONFIG_HOME at it
    # NOTE(review): presumably ``load_rc()`` honours XDG_CONFIG_HOME, so later
    # calls pick up this file -- confirm against ptbench.utils.rc
    config_filename = "ptbench.toml"
    with open(
        os.path.join(montgomery_tempdir.name, config_filename), "wb"
    ) as f:
        tomli_w.dump(rc.data, f)  # closing the file flushes it

    os.environ["XDG_CONFIG_HOME"] = montgomery_tempdir.name

    # stash the newly created temporary directory on the session so the
    # object stays referenced while the session exists
    # NOTE(review): the original comment was truncated; presumably the
    # directory is meant to be erased when the session ends -- confirm
    key = pytest.StashKey[tempfile.TemporaryDirectory]()
    session.stash[key] = montgomery_tempdir