Commit b4b8c620 authored by Daniel CARRON, committed by André Anjos
[segmentation.tests] Add more tests

parent 229dd23c
1 merge request: !46 Create common library
import argparse
import pathlib
import h5py
import torch
import torchvision.transforms.functional as tf
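# Utility script: walk a directory of HDF5 files, read the "img", "target" and
# "mask" datasets from each one, and dump them as PNG images into a sibling
# directory named after the file, for visual inspection.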
def save_images(tensors_dict, output_dir):
for key, value in tensors_dict.items():
output_file = output_dir / pathlib.Path(key + ".png")
tf.to_pil_image(value).save(output_file)
print(f"Saved file {output_file}")
def extract_images_from_hdf5(hdf5_file):
tensors_dict = {}
with h5py.File(hdf5_file, "r") as f:
tensors_dict["image"] = torch.from_numpy(f.get("img")[:])
tensors_dict["target"] = torch.from_numpy(f.get("target")[:])
tensors_dict["mask"] = torch.from_numpy(f.get("mask")[:])
return tensors_dict
def get_hdf5_files(directory, recursive=False) -> list[pathlib.Path]:
    # pathlib.Path.glob() has no "recursive" keyword; choose the pattern instead.
    return list(directory.glob("**/*.hdf5" if recursive else "*.hdf5"))
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"input_dir",
type=pathlib.Path,
help="Directory in which hdf5 files are located.",
)
parser.add_argument(
"--recursive",
"-r",
action="store_true",
help="Set to true to search recursively in the input directory.",
)
args = parser.parse_args()
hdf5_files = get_hdf5_files(args.input_dir, recursive=args.recursive)
for hdf5_file in hdf5_files:
tensors_dict = extract_images_from_hdf5(hdf5_file)
save_dir = pathlib.Path(hdf5_file).with_suffix("").with_suffix("")
save_dir.mkdir(parents=True, exist_ok=True)
save_images(tensors_dict, save_dir)
if __name__ == "__main__":
main()
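# Example invocation (the script name and paths are illustrative only):
#   python dump_hdf5_images.py /path/to/hdf5/outputs --recursive
# A file named sample.hdf5 then yields a sibling directory "sample/" holding
# image.png, target.png and mask.png.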
@@ -389,7 +389,7 @@ def test_evaluate_pasa_montgomery(temporary_basedir):
"-vv",
"montgomery",
f"--predictions={predictions_file}",
f"--output-folder={str(temporary_basedir)}",
f"--output-folder={temporary_basedir}",
"--threshold=test",
],
)
@@ -12,9 +12,8 @@ from . import (
# compare,
config,
database,
evaluate,
predict,
# evaluate,
# experiment,
# mkmask,
# significance,
train,
@@ -34,12 +33,11 @@ def segmentation():
# segmentation.add_command(compare.compare)
segmentation.add_command(config.config)
segmentation.add_command(database.database)
# segmentation.add_command(evaluate.evaluate)
# segmentation.add_command(experiment.experiment)
# segmentation.add_command(mkmask.mkmask)
# segmentation.add_command(significance.significance)
segmentation.add_command(train.train)
segmentation.add_command(predict.predict)
segmentation.add_command(evaluate.evaluate)
segmentation.add_command(
importlib.import_module(
"mednet.libs.common.scripts.train_analysis",
@@ -19,7 +19,7 @@ from mednet.libs.segmentation.engine.evaluator import run
@click.command(
entry_point_group="deepdraw.config",
entry_point_group="mednet.libs.segmentation.config",
cls=ConfigCommand,
epilog="""Examples:
@@ -63,19 +63,13 @@ from mednet.libs.segmentation.engine.evaluator import run
"-o",
help="Directory in which to store results (created if does not exist)",
required=True,
type=click.Path(
file_okay=False,
dir_okay=True,
writable=True,
path_type=pathlib.Path,
),
default="results",
type=click.Path(),
cls=ResourceOption,
)
@click.option(
"--dataset",
"-d",
help="A torch.utils.data.dataset.Dataset instance implementing a dataset "
"to be used for evaluation purposes, possibly including all pre-processing "
"pipelines required or, optionally, a dictionary mapping string keys to "
"torch.utils.data.dataset.Dataset instances. All keys that do not start "
"with an underscore (_) will be processed.",
required=True,
cls=ResourceOption,
)
@click.option(
@@ -144,17 +138,15 @@ from mednet.libs.segmentation.engine.evaluator import run
cls=ResourceOption,
)
@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
@click.pass_context
def evaluate(
ctx,
predictions: pathlib.Path,
output_folder: pathlib.Path,
threshold: str | float,
second_annotator,
# second_annotator,
overlayed,
steps,
parallel,
**kwargs,
**_, # ignored
): # numpydoc ignore=PR01
"""Evaluate predictions (from a model) on a segmentation task."""
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Tests for our CLI applications."""
import contextlib
import re
import pytest
from click.testing import CliRunner
@contextlib.contextmanager
def stdout_logging():
# copy logging messages to std out
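    # A StreamHandler backed by an in-memory StringIO buffer is attached to the
    # "mednet" logger, so tests can assert on the log messages emitted by the CLI.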
import io
import logging
buf = io.StringIO()
ch = logging.StreamHandler(buf)
ch.setFormatter(logging.Formatter("%(message)s"))
ch.setLevel(logging.INFO)
logger = logging.getLogger("mednet")
logger.addHandler(ch)
yield buf
logger.removeHandler(ch)
def _assert_exit_0(result):
assert (
result.exit_code == 0
), f"Exit code {result.exit_code} != 0 -- Output:\n{result.output}"
def _check_help(entry_point):
runner = CliRunner()
result = runner.invoke(entry_point, ["--help"])
_assert_exit_0(result)
assert result.output.startswith("Usage:")
def test_config_help():
from mednet.libs.segmentation.scripts.config import config
_check_help(config)
def test_config_list_help():
from mednet.libs.segmentation.scripts.config import list_
_check_help(list_)
def test_config_list():
from mednet.libs.segmentation.scripts.config import list_
runner = CliRunner()
result = runner.invoke(list_)
_assert_exit_0(result)
assert "module: mednet.libs.segmentation.config.data" in result.output
assert "module: mednet.libs.segmentation.config.models" in result.output
def test_config_list_v():
from mednet.libs.segmentation.scripts.config import list_
result = CliRunner().invoke(list_, ["--verbose"])
_assert_exit_0(result)
assert "module: mednet.libs.segmentation.config.data" in result.output
assert "module: mednet.libs.segmentation.config.models" in result.output
def test_config_describe_help():
from mednet.libs.segmentation.scripts.config import describe
_check_help(describe)
@pytest.mark.skip_if_rc_var_not_set("datadir.drive")
def test_config_describe_drive():
from mednet.libs.segmentation.scripts.config import describe
runner = CliRunner()
result = runner.invoke(describe, ["drive"])
_assert_exit_0(result)
assert (
"DRIVE dataset for Vessel Segmentation (default protocol)."
in result.output
)
def test_database_help():
from mednet.libs.segmentation.scripts.database import database
_check_help(database)
def test_datamodule_list_help():
from mednet.libs.segmentation.scripts.database import list_
_check_help(list_)
def test_datamodule_list():
from mednet.libs.segmentation.scripts.database import list_
runner = CliRunner()
result = runner.invoke(list_)
_assert_exit_0(result)
assert result.output.startswith("Available databases:")
def test_datamodule_check_help():
from mednet.libs.segmentation.scripts.database import check
_check_help(check)
@pytest.mark.skip_if_rc_var_not_set("datadir.drive")
def test_database_check():
from mednet.libs.segmentation.scripts.database import check
runner = CliRunner()
result = runner.invoke(check, ["--verbose", "--limit=1", "drive"])
_assert_exit_0(result)
def test_main_help():
from mednet.libs.segmentation.scripts.cli import segmentation
_check_help(segmentation)
def test_train_help():
from mednet.libs.segmentation.scripts.train import train
_check_help(train)
def _str_counter(substr, s):
return sum(1 for _ in re.finditer(substr, s, re.MULTILINE))
def test_predict_help():
from mednet.libs.segmentation.scripts.predict import predict
_check_help(predict)
def test_evaluate_help():
from mednet.libs.segmentation.scripts.evaluate import evaluate
_check_help(evaluate)
@pytest.mark.skip_if_rc_var_not_set("datadir.drive")
def test_train_lwnet_drive(temporary_basedir):
from mednet.libs.common.utils.checkpointer import (
CHECKPOINT_EXTENSION,
_get_checkpoint_from_alias,
)
from mednet.libs.segmentation.scripts.train import train
runner = CliRunner()
with stdout_logging() as buf:
output_folder = temporary_basedir / "results"
result = runner.invoke(
train,
[
"lwnet",
"drive",
"-vv",
"--epochs=1",
"--batch-size=1",
f"--output-folder={str(output_folder)}",
],
)
_assert_exit_0(result)
# asserts checkpoints are there, or raises FileNotFoundError
last = _get_checkpoint_from_alias(output_folder, "periodic")
assert last.name.endswith("epoch=0" + CHECKPOINT_EXTENSION)
best = _get_checkpoint_from_alias(output_folder, "best")
assert best.name.endswith("epoch=0" + CHECKPOINT_EXTENSION)
assert (
len(list((output_folder / "logs").glob("events.out.tfevents.*")))
== 1
)
assert (output_folder / "meta.json").exists()
keywords = {
r"^Loading dataset:`train` without caching. Trade-off: CPU RAM usage: less | Disk I/O: more.$": 1,
r"^Training for at most 1 epochs.$": 1,
r"^Uninitialised lwnet model - computing z-norm factors from train dataloader.$": 1,
r"^Writing run metadata at.*$": 1,
r"^Dataset `train` is already setup. Not re-instantiating it.$": 3,
}
buf.seek(0)
logging_output = buf.read()
for k, v in keywords.items():
assert _str_counter(k, logging_output) == v, (
f"Count for string '{k}' appeared "
f"({_str_counter(k, logging_output)}) "
f"instead of the expected {v}:\nOutput:\n{logging_output}"
)
@pytest.mark.skip_if_rc_var_not_set("datadir.drive")
def test_train_lwnet_drive_from_checkpoint(temporary_basedir):
from mednet.libs.common.utils.checkpointer import (
CHECKPOINT_EXTENSION,
_get_checkpoint_from_alias,
)
from mednet.libs.segmentation.scripts.train import train
runner = CliRunner()
output_folder = temporary_basedir / "results" / "lwnet_checkpoint"
result0 = runner.invoke(
train,
[
"lwnet",
"drive",
"-vv",
"--epochs=1",
"--batch-size=1",
f"--output-folder={str(output_folder)}",
],
)
_assert_exit_0(result0)
# asserts checkpoints are there, or raises FileNotFoundError
last = _get_checkpoint_from_alias(output_folder, "periodic")
assert last.name.endswith("epoch=0" + CHECKPOINT_EXTENSION)
best = _get_checkpoint_from_alias(output_folder, "best")
assert best.name.endswith("epoch=0" + CHECKPOINT_EXTENSION)
assert (output_folder / "meta.json").exists()
assert (
len(list((output_folder / "logs").glob("events.out.tfevents.*"))) == 1
)
with stdout_logging() as buf:
result = runner.invoke(
train,
[
"lwnet",
"drive",
"-vv",
"--epochs=2",
"--batch-size=1",
f"--output-folder={output_folder}",
],
)
_assert_exit_0(result)
# asserts checkpoints are there, or raises FileNotFoundError
last = _get_checkpoint_from_alias(output_folder, "periodic")
assert last.name.endswith("epoch=1" + CHECKPOINT_EXTENSION)
best = _get_checkpoint_from_alias(output_folder, "best")
assert (output_folder / "meta.json").exists()
assert (
len(list((output_folder / "logs").glob("events.out.tfevents.*")))
== 2
)
keywords = {
r"^Loading dataset:`train` without caching. Trade-off: CPU RAM usage: less | Disk I/O: more.$": 1,
r"^Training for at most 2 epochs.$": 1,
r"^Resuming from epoch 0 \(checkpoint file: .*$": 1,
r"^Writing run metadata at.*$": 1,
r"^Dataset `train` is already setup. Not re-instantiating it.$": 3,
r"^Restoring normalizer from checkpoint.$": 1,
}
buf.seek(0)
logging_output = buf.read()
for k, v in keywords.items():
assert _str_counter(k, logging_output) == v, (
f"Count for string '{k}' appeared "
f"({_str_counter(k, logging_output)}) "
f"instead of the expected {v}:\nOutput:\n{logging_output}"
)
@pytest.mark.skip_if_rc_var_not_set("datadir.drive")
def test_predict_lwnet_drive(temporary_basedir, datadir):
from mednet.libs.common.utils.checkpointer import (
CHECKPOINT_EXTENSION,
_get_checkpoint_from_alias,
)
from mednet.libs.segmentation.scripts.predict import predict
runner = CliRunner()
with stdout_logging() as buf:
output = temporary_basedir / "predictions"
last = _get_checkpoint_from_alias(
temporary_basedir / "results",
"periodic",
)
assert last.name.endswith("epoch=0" + CHECKPOINT_EXTENSION)
result = runner.invoke(
predict,
[
"lwnet",
"drive",
"-vv",
"--batch-size=1",
f"--weight={str(last)}",
f"--output-folder={str(output)}",
],
)
_assert_exit_0(result)
assert output.exists()
keywords = {
r"^Loading dataset: * without caching. Trade-off: CPU RAM usage: less | Disk I/O: more$": 2,
r"^Loading checkpoint from .*$": 1,
r"^Restoring normalizer from checkpoint.$": 1,
r"^Running prediction on `train` split...$": 1,
r"^Running prediction on `test` split...$": 1,
r"^Predictions saved to .*$": 1,
}
buf.seek(0)
logging_output = buf.read()
for k, v in keywords.items():
assert _str_counter(k, logging_output) == v, (
f"Count for string '{k}' appeared "
f"({_str_counter(k, logging_output)}) "
f"instead of the expected {v}:\nOutput:\n{logging_output}"
)
@pytest.mark.skip_if_rc_var_not_set("datadir.drive")
def test_evaluate_lwnet_drive(temporary_basedir):
from mednet.libs.segmentation.scripts.evaluate import evaluate
runner = CliRunner()
with stdout_logging() as buf:
prediction_path = temporary_basedir / "predictions"
predictions_file = prediction_path / "predictions.json"
evaluation_path = temporary_basedir / "evaluations"
result = runner.invoke(
evaluate,
[
"-vv",
"drive",
f"--predictions={predictions_file}",
f"--output-folder={evaluation_path}",
"--threshold=test",
],
)
_assert_exit_0(result)
assert (evaluation_path / "evaluation.meta.json").exists()
assert (evaluation_path / "comparison.pdf").exists()
assert (evaluation_path / "comparison.rst").exists()
assert (evaluation_path / "train.csv").exists()
assert (evaluation_path / "train").exists()
assert (evaluation_path / "test.csv").exists()
assert (evaluation_path / "test").exists()
keywords = {
r"^Analyzing split `train`...$": 1,
r"^Analyzing split `test`...$": 1,
r"^Creating and saving plot at .*$": 1,
r"^Tabulating performance summary...": 1,
r"^Saving table at .*$": 1,
}
buf.seek(0)
logging_output = buf.read()
for k, v in keywords.items():
assert _str_counter(k, logging_output) == v, (
f"Count for string '{k}' appeared "
f"({_str_counter(k, logging_output)}) "
f"instead of the expected {v}:\nOutput:\n{logging_output}"
)
@pytest.mark.skip_if_rc_var_not_set("datadir.drive")
def test_experiment(temporary_basedir):
from mednet.libs.segmentation.scripts.experiment import experiment
runner = CliRunner()
output_folder = temporary_basedir / "experiment"
num_epochs = 2
result = runner.invoke(
experiment,
[
"-vv",
"lwnet",
"drive",
f"--epochs={num_epochs}",
f"--output-folder={str(output_folder)}",
],
)
_assert_exit_0(result)
assert (output_folder / "model" / "meta.json").exists()
assert (
output_folder / "model" / f"model-at-epoch={num_epochs-1}.ckpt"
).exists()
assert (output_folder / "predictions" / "predictions.json").exists()
assert (output_folder / "predictions" / "predictions.meta.json").exists()
# Need to glob because we cannot be sure of the checkpoint with lowest validation loss
assert (
len(
list(
(output_folder / "model").glob(
"model-at-lowest-validation-loss-epoch=*.ckpt",
),
),
)
== 1
)
assert (output_folder / "model" / "trainlog.pdf").exists()
assert (
len(
list(
(output_folder / "model" / "logs").glob(
"events.out.tfevents.*",
),
),
)
== 1
)
assert (output_folder / "evaluation" / "evaluation.meta.json").exists()
assert (output_folder / "evaluation" / "comparison.pdf").exists()
assert (output_folder / "evaluation" / "comparison.rst").exists()
assert (output_folder / "evaluation" / "train.csv").exists()
assert (output_folder / "evaluation" / "train").exists()
assert (output_folder / "evaluation" / "test.csv").exists()
assert (output_folder / "evaluation" / "test").exists()
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import math
import random
import unittest
import numpy
import pytest
import torch
from mednet.libs.segmentation.engine.evaluator import (
sample_measures_for_threshold,
)
from mednet.libs.segmentation.utils.measure import (
auc,
base_measures,
bayesian_measures,
beta_credible_region,
)
class TestFrequentist(unittest.TestCase):
"""Unit test for frequentist base measures."""
def setUp(self):
self.tp = random.randint(1, 100)
self.fp = random.randint(1, 100)
self.tn = random.randint(1, 100)
self.fn = random.randint(1, 100)
def test_precision(self):
precision = base_measures(self.tp, self.fp, self.tn, self.fn)[0]
self.assertEqual((self.tp) / (self.tp + self.fp), precision)
def test_recall(self):
recall = base_measures(self.tp, self.fp, self.tn, self.fn)[1]
self.assertEqual((self.tp) / (self.tp + self.fn), recall)
def test_specificity(self):
specificity = base_measures(self.tp, self.fp, self.tn, self.fn)[2]
self.assertEqual((self.tn) / (self.tn + self.fp), specificity)
def test_accuracy(self):
accuracy = base_measures(self.tp, self.fp, self.tn, self.fn)[3]
self.assertEqual(
(self.tp + self.tn) / (self.tp + self.tn + self.fp + self.fn),
accuracy,
)
def test_jaccard(self):
jaccard = base_measures(self.tp, self.fp, self.tn, self.fn)[4]
self.assertEqual(self.tp / (self.tp + self.fp + self.fn), jaccard)
def test_f1(self):
p, r, s, a, j, f1 = base_measures(self.tp, self.fp, self.tn, self.fn)
self.assertEqual(
(2.0 * self.tp) / (2.0 * self.tp + self.fp + self.fn), f1
)
self.assertAlmostEqual((2 * p * r) / (p + r), f1) # base definition
class TestBayesian:
"""Unit test for bayesian base measures."""
def mean(self, k, i, lambda_):
return (k + lambda_) / (k + i + 2 * lambda_)
def mode1(self, k, i, lambda_): # (k+lambda_), (i+lambda_) > 1
return (k + lambda_ - 1) / (k + i + 2 * lambda_ - 2)
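    # mean() and mode1() above are the closed-form mean and mode of a
    # Beta(k + lambda_, i + lambda_) distribution: a Beta(a, b) has mean
    # a / (a + b) and, for a, b > 1, mode (a - 1) / (a + b - 2).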
def test_beta_credible_region_base(self):
k = 40
i = 10
lambda_ = 0.5
cover = 0.95
got = beta_credible_region(k, i, lambda_, cover)
# mean, mode, lower, upper
exp = (
self.mean(k, i, lambda_),
self.mode1(k, i, lambda_),
0.6741731038857685,
0.8922659692341358,
)
assert numpy.isclose(got, exp).all(), f"{got} <> {exp}"
def test_beta_credible_region_small_k(self):
k = 4
i = 1
lambda_ = 0.5
cover = 0.95
got = beta_credible_region(k, i, lambda_, cover)
# mean, mode, lower, upper
exp = (
self.mean(k, i, lambda_),
self.mode1(k, i, lambda_),
0.37137359936800574,
0.9774872340008449,
)
assert numpy.isclose(got, exp).all(), f"{got} <> {exp}"
def test_beta_credible_region_precision_jeffrey(self):
# simulation of situation for precision TP == FP == 0, Jeffrey's prior
k = 0
i = 0
lambda_ = 0.5
cover = 0.95
got = beta_credible_region(k, i, lambda_, cover)
# mean, mode, lower, upper
exp = (
self.mean(k, i, lambda_),
0.0,
0.0015413331334360135,
0.998458666866564,
)
assert numpy.isclose(got, exp).all(), f"{got} <> {exp}"
def test_beta_credible_region_precision_flat(self):
# simulation of situation for precision TP == FP == 0, flat prior
k = 0
i = 0
lambda_ = 1.0
cover = 0.95
got = beta_credible_region(k, i, lambda_, cover)
# mean, mode, lower, upper
exp = (self.mean(k, i, lambda_), 0.0, 0.025000000000000022, 0.975)
assert numpy.isclose(got, exp).all(), f"{got} <> {exp}"
def test_bayesian_measures(self):
tp = random.randint(100000, 1000000)
fp = random.randint(100000, 1000000)
tn = random.randint(100000, 1000000)
fn = random.randint(100000, 1000000)
_prec, _rec, _spec, _acc, _jac, _f1 = base_measures(tp, fp, tn, fn)
prec, rec, spec, acc, jac, f1 = bayesian_measures(
tp, fp, tn, fn, 0.5, 0.95
)
# Notice that for very large k and l, the base frequentist measures
# should be approximately the same as the bayesian mean and mode
# extracted from the beta posterior. We test that here.
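        # (With counts of the order of 10^5-10^6 the Beta posterior is sharply
        # peaked, so mean and mode both converge to the frequentist ratio.)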
assert numpy.isclose(
_prec, prec[0]
), f"freq: {_prec} <> bays: {prec[0]}"
assert numpy.isclose(
_prec, prec[1]
), f"freq: {_prec} <> bays: {prec[1]}"
assert numpy.isclose(_rec, rec[0]), f"freq: {_rec} <> bays: {rec[0]}"
assert numpy.isclose(_rec, rec[1]), f"freq: {_rec} <> bays: {rec[1]}"
assert numpy.isclose(
_spec, spec[0]
), f"freq: {_spec} <> bays: {spec[0]}"
assert numpy.isclose(
_spec, spec[1]
), f"freq: {_spec} <> bays: {spec[1]}"
assert numpy.isclose(_acc, acc[0]), f"freq: {_acc} <> bays: {acc[0]}"
assert numpy.isclose(_acc, acc[1]), f"freq: {_acc} <> bays: {acc[1]}"
assert numpy.isclose(_jac, jac[0]), f"freq: {_jac} <> bays: {jac[0]}"
assert numpy.isclose(_jac, jac[1]), f"freq: {_jac} <> bays: {jac[1]}"
assert numpy.isclose(_f1, f1[0]), f"freq: {_f1} <> bays: {f1[0]}"
assert numpy.isclose(_f1, f1[1]), f"freq: {_f1} <> bays: {f1[1]}"
# We also test that the interval in question includes the mode and the
# mean in this case.
        assert (prec[2] < prec[1]) and (
            prec[1] < prec[3]
        ), f"precision is out of bounds {prec[2]} < {prec[1]} < {prec[3]}"
        assert (rec[2] < rec[1]) and (
            rec[1] < rec[3]
        ), f"recall is out of bounds {rec[2]} < {rec[1]} < {rec[3]}"
        assert (spec[2] < spec[1]) and (
            spec[1] < spec[3]
        ), f"specif. is out of bounds {spec[2]} < {spec[1]} < {spec[3]}"
        assert (acc[2] < acc[1]) and (
            acc[1] < acc[3]
        ), f"accuracy is out of bounds {acc[2]} < {acc[1]} < {acc[3]}"
        assert (jac[2] < jac[1]) and (
            jac[1] < jac[3]
        ), f"jaccard is out of bounds {jac[2]} < {jac[1]} < {jac[3]}"
        assert (f1[2] < f1[1]) and (
            f1[1] < f1[3]
        ), f"f1-score is out of bounds {f1[2]} < {f1[1]} < {f1[3]}"
def test_auc():
# basic tests
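    # The expected values below are consistent with trapezoidal integration of
    # y over x: constant y=1 on [0, 1] integrates to 1.0, a straight line from
    # 1 down to 0 to 0.5, and a triangle peaking at y=0.5 to 0.25.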
assert math.isclose(auc([0.0, 0.5, 1.0], [1.0, 1.0, 1.0]), 1.0)
assert math.isclose(
auc([0.0, 0.5, 1.0], [1.0, 0.5, 0.0]), 0.5, rel_tol=0.001
)
assert math.isclose(
auc([0.0, 0.5, 1.0], [0.0, 0.0, 0.0]), 0.0, rel_tol=0.001
)
assert math.isclose(
auc([0.0, 0.5, 1.0], [0.0, 1.0, 0.0]), 0.5, rel_tol=0.001
)
assert math.isclose(
auc([0.0, 0.5, 1.0], [0.0, 0.5, 0.0]), 0.25, rel_tol=0.001
)
assert math.isclose(
auc([0.0, 0.5, 1.0], [0.0, 0.5, 0.0]), 0.25, rel_tol=0.001
)
    # the same results hold when the input order is reversed
assert math.isclose(auc([0.0, 0.5, 1.0][::-1], [1.0, 1.0, 1.0][::-1]), 1.0)
assert math.isclose(
auc([0.0, 0.5, 1.0][::-1], [1.0, 0.5, 0.0][::-1]), 0.5, rel_tol=0.001
)
assert math.isclose(
auc([0.0, 0.5, 1.0][::-1], [0.0, 0.0, 0.0][::-1]), 0.0, rel_tol=0.001
)
assert math.isclose(
auc([0.0, 0.5, 1.0][::-1], [0.0, 1.0, 0.0][::-1]), 0.5, rel_tol=0.001
)
assert math.isclose(
auc([0.0, 0.5, 1.0][::-1], [0.0, 0.5, 0.0][::-1]), 0.25, rel_tol=0.001
)
assert math.isclose(
auc([0.0, 0.5, 1.0][::-1], [0.0, 0.5, 0.0][::-1]), 0.25, rel_tol=0.001
)
def test_auc_raises_value_error():
with pytest.raises(
ValueError, match=r".*neither increasing nor decreasing.*"
):
# x is **not** monotonically increasing or decreasing
assert math.isclose(auc([0.0, 0.5, 0.0], [1.0, 1.0, 1.0]), 1.0)
def test_auc_raises_assertion_error():
with pytest.raises(AssertionError, match=r".*must have the same length.*"):
# x is **not** the same size as y
assert math.isclose(auc([0.0, 0.5, 1.0], [1.0, 1.0]), 1.0)
def test_sample_measures_mask_checkerbox():
prediction = torch.ones((4, 4), dtype=float)
ground_truth = torch.ones((4, 4), dtype=float)
ground_truth[2:, :2] = 0.0
ground_truth[:2, 2:] = 0.0
mask = torch.zeros((4, 4), dtype=float)
mask[1:3, 1:3] = 1.0
threshold = 0.5
# with this configuration, this should be the correct count
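    # The mask keeps only the central 2x2 block; the prediction is 1 everywhere
    # and the ground truth is 1 at two of those four pixels and 0 at the other
    # two, hence tp=2, fp=2 and no negatives at all.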
tp = 2
fp = 2
tn = 0
fn = 0
assert (tp, fp, tn, fn) == sample_measures_for_threshold(
prediction, ground_truth, mask, threshold
)
def test_sample_measures_mask_cross():
prediction = torch.ones((10, 10), dtype=float)
prediction[0, :] = 0.0
prediction[9, :] = 0.0
ground_truth = torch.ones((10, 10), dtype=float)
    ground_truth[:5,] = 0.0  # rows 0-4 (the top half) have no ground-truth foreground
mask = torch.zeros((10, 10), dtype=float)
mask[(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)] = 1.0
mask[(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), (9, 8, 7, 6, 5, 4, 3, 2, 1, 0)] = 1.0
threshold = 0.5
# with this configuration, this should be the correct count
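    # The mask keeps the two image diagonals (20 pixels). The prediction is 1
    # everywhere except rows 0 and 9, and the ground truth is 1 only on rows
    # 5-9, so the four corner pixels give 2 tn (row 0) and 2 fn (row 9) while
    # the rest of the diagonals give 8 fp (rows 1-4) and 8 tp (rows 5-8).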
tp = 8
fp = 8
tn = 2
fn = 2
assert (tp, fp, tn, fn) == sample_measures_for_threshold(
prediction, ground_truth, mask, threshold
)
def test_sample_measures_mask_border():
prediction = torch.zeros((10, 10), dtype=float)
prediction[:, 4] = 1.0
prediction[:, 5] = 1.0
prediction[0, 4] = 0.0
prediction[8, 4] = 0.0
prediction[1, 6] = 1.0
ground_truth = torch.zeros((10, 10), dtype=float)
ground_truth[:, 4] = 1.0
ground_truth[:, 5] = 1.0
mask = torch.ones((10, 10), dtype=float)
mask[:, 0] = 0.0
mask[0, :] = 0.0
mask[:, 9] = 0.0
mask[9, :] = 0.0
threshold = 0.5
# with this configuration, this should be the correct count
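    # The mask keeps the 8x8 interior (64 pixels). Columns 4 and 5 of that
    # interior are ground-truth foreground (16 pixels): the prediction misses
    # one of them at (8, 4) (fn=1) and hits the other 15 (tp=15); the spurious
    # prediction at (1, 6) adds fp=1, and the remaining 47 pixels are tn.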
tp = 15
fp = 1
tn = 47
fn = 1
assert (tp, fp, tn, fn) == sample_measures_for_threshold(
prediction, ground_truth, mask, threshold
)