diff --git a/.gitignore b/.gitignore index 55e6cc84408f4a6e466147637b1d368506eef40b..d3edf61edcee2368bc7a7b18261da8a35c1bcb56 100644 --- a/.gitignore +++ b/.gitignore @@ -8,7 +8,7 @@ *.pyc *.egg-info .nfs* -.coverage +.coverage* *.DS_Store .envrc coverage.xml diff --git a/src/ptbench/data/montgomery/__init__.py b/src/ptbench/data/montgomery/__init__.py index 4334180592ffca7cbbcd299e901d472418e5cebd..fd65f24dbde588d9958b90a488ea56cc4111496e 100644 --- a/src/ptbench/data/montgomery/__init__.py +++ b/src/ptbench/data/montgomery/__init__.py @@ -40,12 +40,18 @@ _protocols = [ importlib.resources.files(__name__).joinpath("fold_9.json.bz2"), ] -_root_path = load_rc().get("datadir.montgomery", os.path.realpath(os.curdir)) +_root_path = None def _raw_data_loader(sample): + # hack to allow tests to change "datadir.montgomery" + global _root_path + _root_path = _root_path or load_rc().get( + "datadir.montgomery", os.path.realpath(os.curdir) + ) + return dict( - data=load_pil_baw(os.path.join(_root_path, sample["data"])), + data=load_pil_baw(os.path.join(_root_path, sample["data"])), # type: ignore label=sample["label"], )