Skip to content
Snippets Groups Projects
Commit bd6da859 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[tests] Improve access to datadir.montgomery rc variable throughout the test units

parent 0b18bd96
No related branches found
No related tags found
No related merge requests found
Pipeline #70399 failed
{
"clapp": {
"versions": {
"stable": "https://www.idiap.ch/software/biosignal/docs/software/clapp/stable/sphinx/",
"latest": "https://www.idiap.ch/software/biosignal/docs/software/clapp/main/sphinx/"
"stable": "https://clapp.readthedocs.io/en/stable/",
"latest": "https://clapp.readthedocs.io/en/latest/"
},
"sources": {}
}
......
......@@ -34,12 +34,12 @@ _protocols = [
importlib.resources.files(__name__).joinpath("fold_9.json.bz2"),
]
_root_path = load_rc().get("datadir.hivtb", os.path.realpath(os.curdir))
_datadir = load_rc().get("datadir.hivtb", os.path.realpath(os.curdir))
def _raw_data_loader(sample):
return dict(
data=load_pil_baw(os.path.join(_root_path, sample["data"])),
data=load_pil_baw(os.path.join(_datadir, sample["data"])),
label=sample["label"],
)
......
......@@ -34,12 +34,12 @@ _protocols = [
importlib.resources.files(__name__).joinpath("fold_9.json.bz2"),
]
_root_path = load_rc().get("datadir.indian", os.path.realpath(os.curdir))
_datadir = load_rc().get("datadir.indian", os.path.realpath(os.curdir))
def _raw_data_loader(sample):
return dict(
data=load_pil_baw(os.path.join(_root_path, sample["data"])),
data=load_pil_baw(os.path.join(_datadir, sample["data"])),
label=sample["label"],
)
......
......@@ -40,25 +40,19 @@ _protocols = [
importlib.resources.files(__name__).joinpath("fold_9.json.bz2"),
]
# Base directory holding the raw Montgomery dataset files; configurable
# through the "datadir.montgomery" rc variable, defaulting to the current
# directory.  Resolved once at import time - tests that need to override it
# must set the rc variable before this module is imported.
_datadir = load_rc().get("datadir.montgomery", os.path.realpath(os.curdir))


def _raw_data_loader(sample):
    """Load a single dataset sample from disk.

    Parameters
    ----------
    sample
        Mapping with at least ``data`` (path of the image, relative to the
        dataset root) and ``label`` (the sample's classification target).

    Returns
    -------
    dict
        A dictionary with the loaded black-and-white PIL image under
        ``data`` and the unchanged ``label``.
    """
    return dict(
        data=load_pil_baw(os.path.join(_datadir, sample["data"])),  # type: ignore
        label=sample["label"],
    )
def _loader(context, sample):
    """Wrap a sample into a delayed-loading container.

    ``context`` is ignored in this case - the database is homogeneous.  We
    return delayed samples to avoid loading all images at once.
    """
    return make_delayed(sample, _raw_data_loader)
......
......@@ -38,7 +38,7 @@ _protocols = [
importlib.resources.files(__name__).joinpath("cardiomegaly.json.bz2"),
]
_root_path = load_rc().get("datadir.nih_cxr14_re", os.path.realpath(os.curdir))
_datadir = load_rc().get("datadir.nih_cxr14_re", os.path.realpath(os.curdir))
_idiap_folders = load_rc().get("nih_cxr14_re.idiap_folder_structure", False)
......@@ -51,7 +51,7 @@ def _raw_data_loader(sample):
return dict(
data=load_pil_rgb(
os.path.join(
_root_path,
_datadir,
os.path.dirname(sample["data"]),
basename[:5],
basename,
......@@ -61,7 +61,7 @@ def _raw_data_loader(sample):
)
else:
return dict(
data=load_pil_rgb(os.path.join(_root_path, sample["data"])),
data=load_pil_rgb(os.path.join(_datadir, sample["data"])),
label=sample["label"],
)
......
......@@ -236,12 +236,12 @@ _protocols = [
importlib.resources.files(__name__).joinpath("cardiomegaly_idiap.json.bz2"),
]
_root_path = load_rc().get("datadir.padchest", os.path.realpath(os.curdir))
_datadir = load_rc().get("datadir.padchest", os.path.realpath(os.curdir))
def _raw_data_loader(sample):
return dict(
data=load_pil(os.path.join(_root_path, sample["data"])),
data=load_pil(os.path.join(_datadir, sample["data"])),
label=sample["label"],
)
......
......@@ -41,12 +41,12 @@ _protocols = [
importlib.resources.files(__name__).joinpath("fold_9.json.bz2"),
]
_root_path = load_rc().get("datadir.shenzhen", os.path.realpath(os.curdir))
_datadir = load_rc().get("datadir.shenzhen", os.path.realpath(os.curdir))
def _raw_data_loader(sample):
return dict(
data=load_pil_baw(os.path.join(_root_path, sample["data"])),
data=load_pil_baw(os.path.join(_datadir, sample["data"])),
label=sample["label"],
)
......
......@@ -34,12 +34,12 @@ _protocols = [
importlib.resources.files(__name__).joinpath("fold_9.json.bz2"),
]
_root_path = load_rc().get("datadir.tbpoc", os.path.realpath(os.curdir))
_datadir = load_rc().get("datadir.tbpoc", os.path.realpath(os.curdir))
def _raw_data_loader(sample):
return dict(
data=load_pil_baw(os.path.join(_root_path, sample["data"])),
data=load_pil_baw(os.path.join(_datadir, sample["data"])),
label=sample["label"],
)
......
......@@ -2,9 +2,12 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later
import os
import pathlib
import typing
import pytest
import tomli_w
@pytest.fixture
......@@ -67,30 +70,56 @@ def temporary_basedir(tmp_path_factory):
return tmp_path_factory.mktemp("test-cli")
@pytest.fixture(scope="session")
def montgomery_datadir(tmp_path_factory) -> pathlib.Path:
from ptbench.utils.rc import load_rc
database_dir = load_rc().get("datadir.montgomery")
if database_dir is not None:
return pathlib.Path(database_dir)
# else, we must extract the LFS component
archive = (
pathlib.Path(__file__).parents[0] / "data" / "lfs" / "test-database.zip"
)
assert archive.exists(), (
f"Neither datadir.montgomery is set on the global configuration, "
f"(typically ~/.config/ptbench.toml), or it is possible to detect "
f"the presence of {archive}' (did you git submodule init --update "
f"this submodule?)"
)
database_dir = tmp_path_factory.mktemp("montgomery_datadir")
@pytest.fixture(scope="session", autouse=True)
def ensure_montgomery(
tmp_path_factory,
) -> typing.Generator[None, None, None]:
"""A pytest fixture that ensures that datadir.montgomery is always
available."""
import tempfile
import zipfile
with zipfile.ZipFile(archive) as zf:
zf.extractall(database_dir)
from ptbench.utils.rc import load_rc
rc = load_rc()
return database_dir
database_dir = rc.get("datadir.montgomery")
if database_dir is not None:
# if the user downloaded it, use that copy
yield
else:
# else, we must extract the LFS component (we are likely on the CI)
archive = (
pathlib.Path(__file__).parents[0]
/ "data"
/ "lfs"
/ "test-database.zip"
)
assert archive.exists(), (
f"Neither datadir.montgomery is set on the global configuration, "
f"(typically ~/.config/ptbench.toml), or it is possible to detect "
f"the presence of {archive}' (did you git submodule init --update "
f"this submodule?)"
)
database_dir = tmp_path_factory.mktemp("montgomery_datadir")
rc.setdefault("datadir.montgomery", str(database_dir))
with zipfile.ZipFile(archive) as zf:
zf.extractall(database_dir)
with tempfile.TemporaryDirectory() as tmpdir:
config_filename = "ptbench.toml"
with open(os.path.join(tmpdir, config_filename), "wb") as f:
tomli_w.dump(rc.data, f)
f.flush()
old_config_home = os.environ.get("XDG_CONFIG_HOME")
os.environ["XDG_CONFIG_HOME"] = tmpdir
yield
if old_config_home is None:
del os.environ["XDG_CONFIG_HOME"]
else:
os.environ["XDG_CONFIG_HOME"] = old_config_home
This diff is collapsed.
......@@ -2,13 +2,8 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later
import contextlib
import os
import tempfile
import numpy as np
import pytest
import tomli_w
import torch
from torch.utils.data import ConcatDataset
......@@ -21,23 +16,6 @@ from ptbench.configs.datasets import get_positive_weights, get_samples_weights
N = 10
@contextlib.contextmanager
def rc_context(**new_config):
    """Temporarily install *new_config* as the ptbench TOML configuration.

    Writes the given key/value pairs to a ``ptbench.toml`` file in a fresh
    temporary directory and points ``XDG_CONFIG_HOME`` at it for the
    duration of the ``with`` block, restoring the previous value afterwards.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        config_filename = "ptbench.toml"
        with open(os.path.join(tmpdir, config_filename), "wb") as f:
            tomli_w.dump(new_config, f)
            f.flush()
        old_config_home = os.environ.get("XDG_CONFIG_HOME")
        os.environ["XDG_CONFIG_HOME"] = tmpdir
        try:
            yield
        finally:
            # restore even if the managed block raises - otherwise a failing
            # test would leak the temporary XDG_CONFIG_HOME into later tests
            if old_config_home is None:
                del os.environ["XDG_CONFIG_HOME"]
            else:
                os.environ["XDG_CONFIG_HOME"] = old_config_home
@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
def test_montgomery():
def _check_subset(samples, size):
assert len(samples) == size
......@@ -60,20 +38,15 @@ def test_montgomery():
_check_subset(dataset["test"], 28)
def test_get_samples_weights():
    """Checks per-sample weights on the Montgomery default dataset.

    Relies on the session-wide fixture that guarantees
    ``datadir.montgomery`` is available, so the dataset import works.
    """
    from ptbench.configs.datasets.montgomery.default import dataset

    train_samples_weights = get_samples_weights(dataset["__train__"]).numpy()

    unique, counts = np.unique(train_samples_weights, return_counts=True)

    # 51 negative / 37 positive training samples; each class' weight is the
    # inverse of its frequency
    np.testing.assert_equal(counts, np.array([51, 37]))
    np.testing.assert_equal(unique, np.array(1 / counts, dtype=np.float32))
@pytest.mark.skip_if_rc_var_not_set("datadir.nih_cxr14_re")
......@@ -87,22 +60,17 @@ def test_get_samples_weights_multi():
)
def test_get_samples_weights_concat():
    """Checks per-sample weights on a concatenated (doubled) Montgomery
    training set.

    Relies on the session-wide fixture that guarantees
    ``datadir.montgomery`` is available, so the dataset import works.
    """
    from ptbench.configs.datasets.montgomery.default import dataset

    train_dataset = ConcatDataset((dataset["__train__"], dataset["__train__"]))

    train_samples_weights = get_samples_weights(train_dataset).numpy()

    unique, counts = np.unique(train_samples_weights, return_counts=True)

    # doubling the dataset doubles the per-class counts (102 = 2*51,
    # 74 = 2*37); weights scale accordingly
    np.testing.assert_equal(counts, np.array([102, 74]))
    np.testing.assert_equal(unique, np.array(2 / counts, dtype=np.float32))
@pytest.mark.skip_if_rc_var_not_set("datadir.nih_cxr14_re")
......@@ -127,19 +95,14 @@ def test_get_samples_weights_multi_concat():
np.testing.assert_equal(train_samples_weights, ref_samples_weights)
def test_get_positive_weights():
    """Checks the positive-class weight on the Montgomery default dataset.

    Relies on the session-wide fixture that guarantees
    ``datadir.montgomery`` is available, so the dataset import works.
    """
    from ptbench.configs.datasets.montgomery.default import dataset

    train_positive_weights = get_positive_weights(dataset["__train__"]).numpy()

    # negative/positive ratio of the training set (51 negatives, 37 positives)
    np.testing.assert_equal(
        train_positive_weights, np.array([51.0 / 37.0], dtype=np.float32)
    )
@pytest.mark.skip_if_rc_var_not_set("datadir.nih_cxr14_re")
......@@ -204,21 +167,16 @@ def test_get_positive_weights_multi():
)
def test_get_positive_weights_concat():
    """Checks the positive-class weight on a concatenated (doubled)
    Montgomery training set.

    Relies on the session-wide fixture that guarantees
    ``datadir.montgomery`` is available, so the dataset import works.
    """
    from ptbench.configs.datasets.montgomery.default import dataset

    train_dataset = ConcatDataset((dataset["__train__"], dataset["__train__"]))

    train_positive_weights = get_positive_weights(train_dataset).numpy()

    # doubling the dataset leaves the negative/positive ratio unchanged
    np.testing.assert_equal(
        train_positive_weights, np.array([51.0 / 37.0], dtype=np.float32)
    )
@pytest.mark.skip_if_rc_var_not_set("datadir.nih_cxr14_re")
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment