Commit bd6da859 authored by André Anjos

[tests] Improve access to datadir.montgomery rc variable throughout the test units

parent 0b18bd96
Pipeline #70399 failed
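The change standardizes how the data-set modules resolve their root directory from the runtime configuration: look the variable up once, at import time, and fall back to the current working directory. A minimal sketch of that lookup, using the same call that appears in the diffs below (load_rc() reads the ptbench.toml configuration, typically under ~/.config):

import os

from ptbench.utils.rc import load_rc

# "datadir.montgomery" from the ptbench configuration if set,
# otherwise fall back to the current working directory
_datadir = load_rc().get("datadir.montgomery", os.path.realpath(os.curdir))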
 {
   "clapp": {
     "versions": {
-      "stable": "https://www.idiap.ch/software/biosignal/docs/software/clapp/stable/sphinx/",
-      "latest": "https://www.idiap.ch/software/biosignal/docs/software/clapp/main/sphinx/"
+      "stable": "https://clapp.readthedocs.io/en/stable/",
+      "latest": "https://clapp.readthedocs.io/en/latest/"
     },
     "sources": {}
   }
 }
...
@@ -34,12 +34,12 @@ _protocols = [
     importlib.resources.files(__name__).joinpath("fold_9.json.bz2"),
 ]

-_root_path = load_rc().get("datadir.hivtb", os.path.realpath(os.curdir))
+_datadir = load_rc().get("datadir.hivtb", os.path.realpath(os.curdir))


 def _raw_data_loader(sample):
     return dict(
-        data=load_pil_baw(os.path.join(_root_path, sample["data"])),
+        data=load_pil_baw(os.path.join(_datadir, sample["data"])),
         label=sample["label"],
     )
...
@@ -34,12 +34,12 @@ _protocols = [
     importlib.resources.files(__name__).joinpath("fold_9.json.bz2"),
 ]

-_root_path = load_rc().get("datadir.indian", os.path.realpath(os.curdir))
+_datadir = load_rc().get("datadir.indian", os.path.realpath(os.curdir))


 def _raw_data_loader(sample):
     return dict(
-        data=load_pil_baw(os.path.join(_root_path, sample["data"])),
+        data=load_pil_baw(os.path.join(_datadir, sample["data"])),
         label=sample["label"],
     )
...
@@ -40,25 +40,19 @@ _protocols = [
     importlib.resources.files(__name__).joinpath("fold_9.json.bz2"),
 ]

-_root_path = None
+_datadir = load_rc().get("datadir.montgomery", os.path.realpath(os.curdir))


 def _raw_data_loader(sample):
-    # hack to allow tests to change "datadir.montgomery"
-    global _root_path
-    _root_path = _root_path or load_rc().get(
-        "datadir.montgomery", os.path.realpath(os.curdir)
-    )
-
     return dict(
-        data=load_pil_baw(os.path.join(_root_path, sample["data"])),  # type: ignore
+        data=load_pil_baw(os.path.join(_datadir, sample["data"])),  # type: ignore
         label=sample["label"],
     )


 def _loader(context, sample):
     # "context" is ignored in this case - database is homogeneous
-    # we returned delayed samples to avoid loading all images at once
+    # we return delayed samples to avoid loading all images at once
     return make_delayed(sample, _raw_data_loader)
...
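One consequence of the hunk above: the Montgomery module no longer re-reads the configuration lazily inside the loader, so any override of datadir.montgomery must be in place before the module is first imported. A hedged sketch of that ordering (the configuration path is a placeholder; the import is the same one used by the tests further down):

import os

# point XDG_CONFIG_HOME at a directory holding a ptbench.toml with the
# desired "datadir.montgomery" value *before* the first import...
os.environ["XDG_CONFIG_HOME"] = "/path/to/alternate/config"  # placeholder

# ...so the module-level load_rc() call picks up the override
from ptbench.configs.datasets.montgomery.default import dataset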
@@ -38,7 +38,7 @@ _protocols = [
     importlib.resources.files(__name__).joinpath("cardiomegaly.json.bz2"),
 ]

-_root_path = load_rc().get("datadir.nih_cxr14_re", os.path.realpath(os.curdir))
+_datadir = load_rc().get("datadir.nih_cxr14_re", os.path.realpath(os.curdir))
 _idiap_folders = load_rc().get("nih_cxr14_re.idiap_folder_structure", False)

@@ -51,7 +51,7 @@ def _raw_data_loader(sample):
         return dict(
             data=load_pil_rgb(
                 os.path.join(
-                    _root_path,
+                    _datadir,
                     os.path.dirname(sample["data"]),
                     basename[:5],
                     basename,
@@ -61,7 +61,7 @@ def _raw_data_loader(sample):
         )
     else:
         return dict(
-            data=load_pil_rgb(os.path.join(_root_path, sample["data"])),
+            data=load_pil_rgb(os.path.join(_datadir, sample["data"])),
             label=sample["label"],
         )
...
@@ -236,12 +236,12 @@ _protocols = [
     importlib.resources.files(__name__).joinpath("cardiomegaly_idiap.json.bz2"),
 ]

-_root_path = load_rc().get("datadir.padchest", os.path.realpath(os.curdir))
+_datadir = load_rc().get("datadir.padchest", os.path.realpath(os.curdir))


 def _raw_data_loader(sample):
     return dict(
-        data=load_pil(os.path.join(_root_path, sample["data"])),
+        data=load_pil(os.path.join(_datadir, sample["data"])),
         label=sample["label"],
     )
...
@@ -41,12 +41,12 @@ _protocols = [
     importlib.resources.files(__name__).joinpath("fold_9.json.bz2"),
 ]

-_root_path = load_rc().get("datadir.shenzhen", os.path.realpath(os.curdir))
+_datadir = load_rc().get("datadir.shenzhen", os.path.realpath(os.curdir))


 def _raw_data_loader(sample):
     return dict(
-        data=load_pil_baw(os.path.join(_root_path, sample["data"])),
+        data=load_pil_baw(os.path.join(_datadir, sample["data"])),
         label=sample["label"],
     )
...
@@ -34,12 +34,12 @@ _protocols = [
     importlib.resources.files(__name__).joinpath("fold_9.json.bz2"),
 ]

-_root_path = load_rc().get("datadir.tbpoc", os.path.realpath(os.curdir))
+_datadir = load_rc().get("datadir.tbpoc", os.path.realpath(os.curdir))


 def _raw_data_loader(sample):
     return dict(
-        data=load_pil_baw(os.path.join(_root_path, sample["data"])),
+        data=load_pil_baw(os.path.join(_datadir, sample["data"])),
         label=sample["label"],
     )
...
@@ -2,9 +2,12 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later

+import os
 import pathlib
+import typing

 import pytest
+import tomli_w


 @pytest.fixture
@@ -67,30 +70,56 @@ def temporary_basedir(tmp_path_factory):
     return tmp_path_factory.mktemp("test-cli")


-@pytest.fixture(scope="session")
-def montgomery_datadir(tmp_path_factory) -> pathlib.Path:
-    from ptbench.utils.rc import load_rc
-
-    database_dir = load_rc().get("datadir.montgomery")
-    if database_dir is not None:
-        return pathlib.Path(database_dir)
-
-    # else, we must extract the LFS component
-    archive = (
-        pathlib.Path(__file__).parents[0] / "data" / "lfs" / "test-database.zip"
-    )
-    assert archive.exists(), (
-        f"Neither datadir.montgomery is set on the global configuration, "
-        f"(typically ~/.config/ptbench.toml), or it is possible to detect "
-        f"the presence of {archive}' (did you git submodule init --update "
-        f"this submodule?)"
-    )
-
-    database_dir = tmp_path_factory.mktemp("montgomery_datadir")
-
+@pytest.fixture(scope="session", autouse=True)
+def ensure_montgomery(
+    tmp_path_factory,
+) -> typing.Generator[None, None, None]:
+    """A pytest fixture that ensures that datadir.montgomery is always
+    available."""
+
+    import tempfile
     import zipfile

-    with zipfile.ZipFile(archive) as zf:
-        zf.extractall(database_dir)
-
-    return database_dir
+    from ptbench.utils.rc import load_rc
+
+    rc = load_rc()
+    database_dir = rc.get("datadir.montgomery")
+    if database_dir is not None:
+        # if the user downloaded it, use that copy
+        yield
+    else:
+        # else, we must extract the LFS component (we are likely on the CI)
+        archive = (
+            pathlib.Path(__file__).parents[0]
+            / "data"
+            / "lfs"
+            / "test-database.zip"
+        )
+        assert archive.exists(), (
+            f"Neither datadir.montgomery is set on the global configuration, "
+            f"(typically ~/.config/ptbench.toml), or it is possible to detect "
+            f"the presence of {archive}' (did you git submodule init --update "
+            f"this submodule?)"
+        )
+
+        database_dir = tmp_path_factory.mktemp("montgomery_datadir")
+        rc.setdefault("datadir.montgomery", str(database_dir))
+
+        with zipfile.ZipFile(archive) as zf:
+            zf.extractall(database_dir)
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            config_filename = "ptbench.toml"
+            with open(os.path.join(tmpdir, config_filename), "wb") as f:
+                tomli_w.dump(rc.data, f)
+                f.flush()

+            old_config_home = os.environ.get("XDG_CONFIG_HOME")
+            os.environ["XDG_CONFIG_HOME"] = tmpdir
+            yield
+
+            if old_config_home is None:
+                del os.environ["XDG_CONFIG_HOME"]
+            else:
+                os.environ["XDG_CONFIG_HOME"] = old_config_home
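A sketch of how tests consume this fixture (the test name is hypothetical and not part of the commit): because ensure_montgomery is session-scoped and autouse, a test needs neither a montgomery_datadir argument nor a skip mark, and can import the dataset configuration directly inside its body.

def test_uses_montgomery():  # hypothetical example
    from ptbench.configs.datasets.montgomery.default import dataset

    # 51 + 37 training samples, matching the weight tests further down
    assert len(dataset["__train__"]) == 88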
This diff is collapsed.
@@ -2,13 +2,8 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later

-import contextlib
-import os
-import tempfile
-
 import numpy as np
 import pytest
-import tomli_w
 import torch
 from torch.utils.data import ConcatDataset

@@ -21,23 +16,6 @@ from ptbench.configs.datasets import get_positive_weights, get_samples_weights
 N = 10


-@contextlib.contextmanager
-def rc_context(**new_config):
-    with tempfile.TemporaryDirectory() as tmpdir:
-        config_filename = "ptbench.toml"
-        with open(os.path.join(tmpdir, config_filename), "wb") as f:
-            tomli_w.dump(new_config, f)
-            f.flush()
-
-            old_config_home = os.environ.get("XDG_CONFIG_HOME")
-            os.environ["XDG_CONFIG_HOME"] = tmpdir
-            yield
-
-            if old_config_home is None:
-                del os.environ["XDG_CONFIG_HOME"]
-            else:
-                os.environ["XDG_CONFIG_HOME"] = old_config_home
-
-
-@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
 def test_montgomery():
     def _check_subset(samples, size):
         assert len(samples) == size
@@ -60,20 +38,15 @@ def test_montgomery():
     _check_subset(dataset["test"], 28)


-def test_get_samples_weights(montgomery_datadir):
-    # Temporarily modify Montgomery datadir
-    new_value = {"datadir.montgomery": str(montgomery_datadir)}
-    with rc_context(**new_value):
-        from ptbench.configs.datasets.montgomery.default import dataset
+def test_get_samples_weights():
+    from ptbench.configs.datasets.montgomery.default import dataset

-        train_samples_weights = get_samples_weights(
-            dataset["__train__"]
-        ).numpy()
+    train_samples_weights = get_samples_weights(dataset["__train__"]).numpy()

-        unique, counts = np.unique(train_samples_weights, return_counts=True)
+    unique, counts = np.unique(train_samples_weights, return_counts=True)

-        np.testing.assert_equal(counts, np.array([51, 37]))
-        np.testing.assert_equal(unique, np.array(1 / counts, dtype=np.float32))
+    np.testing.assert_equal(counts, np.array([51, 37]))
+    np.testing.assert_equal(unique, np.array(1 / counts, dtype=np.float32))


 @pytest.mark.skip_if_rc_var_not_set("datadir.nih_cxr14_re")
@@ -87,22 +60,17 @@ def test_get_samples_weights_multi():
     )


-def test_get_samples_weights_concat(montgomery_datadir):
-    # Temporarily modify Montgomery datadir
-    new_value = {"datadir.montgomery": str(montgomery_datadir)}
-    with rc_context(**new_value):
-        from ptbench.configs.datasets.montgomery.default import dataset
+def test_get_samples_weights_concat():
+    from ptbench.configs.datasets.montgomery.default import dataset

-        train_dataset = ConcatDataset(
-            (dataset["__train__"], dataset["__train__"])
-        )
+    train_dataset = ConcatDataset((dataset["__train__"], dataset["__train__"]))

-        train_samples_weights = get_samples_weights(train_dataset).numpy()
+    train_samples_weights = get_samples_weights(train_dataset).numpy()

-        unique, counts = np.unique(train_samples_weights, return_counts=True)
+    unique, counts = np.unique(train_samples_weights, return_counts=True)

-        np.testing.assert_equal(counts, np.array([102, 74]))
-        np.testing.assert_equal(unique, np.array(2 / counts, dtype=np.float32))
+    np.testing.assert_equal(counts, np.array([102, 74]))
+    np.testing.assert_equal(unique, np.array(2 / counts, dtype=np.float32))


 @pytest.mark.skip_if_rc_var_not_set("datadir.nih_cxr14_re")
@@ -127,19 +95,14 @@ def test_get_samples_weights_multi_concat():
     np.testing.assert_equal(train_samples_weights, ref_samples_weights)


-def test_get_positive_weights(montgomery_datadir):
-    # Temporarily modify Montgomery datadir
-    new_value = {"datadir.montgomery": str(montgomery_datadir)}
-    with rc_context(**new_value):
-        from ptbench.configs.datasets.montgomery.default import dataset
+def test_get_positive_weights():
+    from ptbench.configs.datasets.montgomery.default import dataset

-        train_positive_weights = get_positive_weights(
-            dataset["__train__"]
-        ).numpy()
+    train_positive_weights = get_positive_weights(dataset["__train__"]).numpy()

-        np.testing.assert_equal(
-            train_positive_weights, np.array([51.0 / 37.0], dtype=np.float32)
-        )
+    np.testing.assert_equal(
+        train_positive_weights, np.array([51.0 / 37.0], dtype=np.float32)
+    )


 @pytest.mark.skip_if_rc_var_not_set("datadir.nih_cxr14_re")
@@ -204,21 +167,16 @@ def test_get_positive_weights_multi():
     )


-def test_get_positive_weights_concat(montgomery_datadir):
-    # Temporarily modify Montgomery datadir
-    new_value = {"datadir.montgomery": str(montgomery_datadir)}
-    with rc_context(**new_value):
-        from ptbench.configs.datasets.montgomery.default import dataset
+def test_get_positive_weights_concat():
+    from ptbench.configs.datasets.montgomery.default import dataset

-        train_dataset = ConcatDataset(
-            (dataset["__train__"], dataset["__train__"])
-        )
+    train_dataset = ConcatDataset((dataset["__train__"], dataset["__train__"]))

-        train_positive_weights = get_positive_weights(train_dataset).numpy()
+    train_positive_weights = get_positive_weights(train_dataset).numpy()

-        np.testing.assert_equal(
-            train_positive_weights, np.array([51.0 / 37.0], dtype=np.float32)
-        )
+    np.testing.assert_equal(
+        train_positive_weights, np.array([51.0 / 37.0], dtype=np.float32)
+    )


 @pytest.mark.skip_if_rc_var_not_set("datadir.nih_cxr14_re")
...