# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import contextlib
import os
import tempfile
import numpy as np
import pytest
import tomli_w
import torch
from torch.utils.data import ConcatDataset
from ptbench.configs.datasets import get_positive_weights, get_samples_weights
from . import mock_dataset
# Download test data and get their location if needed
montgomery_datadir = mock_dataset()
# we only iterate over the first N elements at most - dataset loading has
# already been checked on the individual datset tests. Here, we are only
# testing for the extra tools wrapping the dataset
N = 10
@contextlib.contextmanager
def rc_context(**new_config):
    """Temporarily point ``XDG_CONFIG_HOME`` at a fresh directory holding a
    ``ptbench.toml`` built from ``new_config``.

    Parameters
    ----------
    **new_config
        Key/value pairs dumped (via ``tomli_w``) into the temporary
        ``ptbench.toml`` configuration file.

    The previous value of ``XDG_CONFIG_HOME`` is restored (or the variable
    removed, if it was unset) when the context exits — even if the caller's
    block raises.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        config_filename = "ptbench.toml"
        with open(os.path.join(tmpdir, config_filename), "wb") as f:
            tomli_w.dump(new_config, f)
            f.flush()
        old_config_home = os.environ.get("XDG_CONFIG_HOME")
        os.environ["XDG_CONFIG_HOME"] = tmpdir
        try:
            yield
        finally:
            # Always undo the environment mutation; the original code
            # skipped this cleanup when the caller's block raised, leaking
            # a stale XDG_CONFIG_HOME into subsequent tests.
            if old_config_home is None:
                del os.environ["XDG_CONFIG_HOME"]
            else:
                os.environ["XDG_CONFIG_HOME"] = old_config_home
@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
def test_montgomery():
    """Check Montgomery default split sizes and per-sample layout."""

    def _verify(subset, expected_size):
        assert len(subset) == expected_size
        # Only inspect the first N entries — dataset loading itself is
        # covered by the individual dataset tests.
        for sample in subset[:N]:
            assert len(sample) == 3
            key, image, label = sample
            assert isinstance(key, str)
            assert image.shape == (1, 512, 512)  # planes, height, width
            assert image.dtype == torch.float32
            assert isinstance(label, int)
            assert image.max() <= 1.0
            assert image.min() >= 0.0

    from ptbench.configs.datasets.montgomery.default import dataset

    assert len(dataset) == 5
    expected_sizes = {
        "__train__": 88,
        "__valid__": 22,
        "train": 88,
        "validation": 22,
        "test": 28,
    }
    for split_name, size in expected_sizes.items():
        _verify(dataset[split_name], size)
def test_get_samples_weights():
    """Binary-label case: each class is weighted by its inverse frequency."""
    # Point the Montgomery datadir at the mocked download location.
    with rc_context(**{"datadir.montgomery": montgomery_datadir}):
        from ptbench.configs.datasets.montgomery.default import dataset

        weights = get_samples_weights(dataset["__train__"]).numpy()
        unique, counts = np.unique(weights, return_counts=True)
        # 51 negatives / 37 positives in the Montgomery training split.
        np.testing.assert_equal(counts, np.array([51, 37]))
        np.testing.assert_equal(unique, np.array(1 / counts, dtype=np.float32))
@pytest.mark.skip_if_rc_var_not_set("datadir.nih_cxr14_re")
def test_get_samples_weights_multi():
    """Multi-label case: sample weights degrade to a uniform all-ones vector."""
    from ptbench.configs.datasets.nih_cxr14_re.default import dataset

    train_subset = dataset["__train__"]
    weights = get_samples_weights(train_subset).numpy()
    np.testing.assert_equal(weights, np.ones(len(train_subset)))
def test_get_samples_weights_concat():
    """Binary-label case over a doubled (concatenated) training set."""
    # Point the Montgomery datadir at the mocked download location.
    with rc_context(**{"datadir.montgomery": montgomery_datadir}):
        from ptbench.configs.datasets.montgomery.default import dataset

        doubled = ConcatDataset((dataset["__train__"], dataset["__train__"]))
        weights = get_samples_weights(doubled).numpy()
        unique, counts = np.unique(weights, return_counts=True)
        # Class counts double with the dataset; weights scale accordingly.
        np.testing.assert_equal(counts, np.array([102, 74]))
        np.testing.assert_equal(unique, np.array(2 / counts, dtype=np.float32))
@pytest.mark.skip_if_rc_var_not_set("datadir.nih_cxr14_re")
def test_get_samples_weights_multi_concat():
    """Multi-label concatenation: uniform weight ``1/len`` within each copy."""
    from ptbench.configs.datasets.nih_cxr14_re.default import dataset

    subset = dataset["__train__"]
    doubled = ConcatDataset((subset, subset))
    weights = get_samples_weights(doubled).numpy()

    per_copy = torch.full((len(subset),), 1.0 / len(subset))
    expected = np.concatenate((per_copy, per_copy))
    np.testing.assert_equal(weights, expected)
def test_get_positive_weights():
    """Binary-label case: positive weight is the negative/positive ratio."""
    # Point the Montgomery datadir at the mocked download location.
    with rc_context(**{"datadir.montgomery": montgomery_datadir}):
        from ptbench.configs.datasets.montgomery.default import dataset

        weights = get_positive_weights(dataset["__train__"]).numpy()
        # 51 negatives / 37 positives in the Montgomery training split.
        expected = np.array([51.0 / 37.0], dtype=np.float32)
        np.testing.assert_equal(weights, expected)
@pytest.mark.skip_if_rc_var_not_set("datadir.nih_cxr14_re")
def test_get_positive_weights_multi():
    """Per-class positive weights for the 14 NIH-CXR14 labels."""
    from ptbench.configs.datasets.nih_cxr14_re.default import dataset

    def _assert_rounded_equal(actual, expected):
        # Reference values were recorded with limited precision, so both
        # sides are rounded to 4 decimals before the exact comparison.
        lhs = torch.FloatTensor(np.around(actual, 4))
        rhs = torch.FloatTensor(np.around(expected, 4))
        assert torch.all(torch.eq(lhs, rhs))

    _assert_rounded_equal(
        get_positive_weights(dataset["__train__"]).numpy(),
        [
            0.9195434,
            0.9462068,
            0.8070095,
            0.94879204,
            0.767055,
            0.8944615,
            0.88212335,
            0.8227136,
            0.8943905,
            0.8864118,
            0.90026057,
            0.8888551,
            0.884739,
            0.84540284,
        ],
    )
    _assert_rounded_equal(
        get_positive_weights(dataset["__valid__"]).numpy(),
        [
            0.9366929,
            0.9535433,
            0.79543304,
            0.9530709,
            0.74834645,
            0.88708663,
            0.86661416,
            0.81496066,
            0.89480317,
            0.8888189,
            0.8933858,
            0.89795274,
            0.87181103,
            0.8266142,
        ],
    )
def test_get_positive_weights_concat():
    """Binary-label case: doubling the dataset leaves the ratio unchanged."""
    # Point the Montgomery datadir at the mocked download location.
    with rc_context(**{"datadir.montgomery": montgomery_datadir}):
        from ptbench.configs.datasets.montgomery.default import dataset

        doubled = ConcatDataset((dataset["__train__"], dataset["__train__"]))
        weights = get_positive_weights(doubled).numpy()
        # The negative/positive ratio is invariant under duplication.
        expected = np.array([51.0 / 37.0], dtype=np.float32)
        np.testing.assert_equal(weights, expected)
@pytest.mark.skip_if_rc_var_not_set("datadir.nih_cxr14_re")
def test_get_positive_weights_multi_concat():
    """Multi-label case: duplication leaves per-class weights unchanged."""
    from ptbench.configs.datasets.nih_cxr14_re.default import dataset

    def _assert_rounded_equal(actual, expected):
        # Reference values were recorded with limited precision, so both
        # sides are rounded to 4 decimals before the exact comparison.
        lhs = torch.FloatTensor(np.around(actual, 4))
        rhs = torch.FloatTensor(np.around(expected, 4))
        assert torch.all(torch.eq(lhs, rhs))

    doubled_train = ConcatDataset((dataset["__train__"], dataset["__train__"]))
    doubled_valid = ConcatDataset((dataset["__valid__"], dataset["__valid__"]))

    _assert_rounded_equal(
        get_positive_weights(doubled_train).numpy(),
        [
            0.9195434,
            0.9462068,
            0.8070095,
            0.94879204,
            0.767055,
            0.8944615,
            0.88212335,
            0.8227136,
            0.8943905,
            0.8864118,
            0.90026057,
            0.8888551,
            0.884739,
            0.84540284,
        ],
    )
    _assert_rounded_equal(
        get_positive_weights(doubled_valid).numpy(),
        [
            0.9366929,
            0.9535433,
            0.79543304,
            0.9530709,
            0.74834645,
            0.88708663,
            0.86661416,
            0.81496066,
            0.89480317,
            0.8888189,
            0.8933858,
            0.89795274,
            0.87181103,
            0.8266142,
        ],
    )