diff --git a/helpers/extract_hdf5_images.py b/helpers/extract_hdf5_images.py
new file mode 100644
index 0000000000000000000000000000000000000000..a093412d6d6a8e2a50b45ac620fc5b65d95e525e
--- /dev/null
+++ b/helpers/extract_hdf5_images.py
@@ -0,0 +1,56 @@
+import argparse
+import pathlib
+
+import h5py
+import torch
+import torchvision.transforms.functional as tf
+
+
+def save_images(tensors_dict, output_dir):
+    for key, value in tensors_dict.items():
+        output_file = output_dir / pathlib.Path(key + ".png")
+        tf.to_pil_image(value).save(output_file)
+        print(f"Saved file {output_file}")
+
+
+def extract_images_from_hdf5(hdf5_file):
+    tensors_dict = {}
+    with h5py.File(hdf5_file, "r") as f:
+        tensors_dict["image"] = torch.from_numpy(f.get("img")[:])
+        tensors_dict["target"] = torch.from_numpy(f.get("target")[:])
+        tensors_dict["mask"] = torch.from_numpy(f.get("mask")[:])
+
+    return tensors_dict
+
+
+def get_hdf5_files(directory, recursive=False) -> list[pathlib.Path]:
+    return list(directory.glob("**/*.hdf5" if recursive else "*.hdf5"))
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "input_dir",
+        type=pathlib.Path,
+        help="Directory in which hdf5 files are located.",
+    )
+    parser.add_argument(
+        "--recursive",
+        "-r",
+        action="store_true",
+        help="Search recursively in the input directory.",
+    )
+
+    args = parser.parse_args()
+
+    hdf5_files = get_hdf5_files(args.input_dir, recursive=args.recursive)
+    for hdf5_file in hdf5_files:
+        tensors_dict = extract_images_from_hdf5(hdf5_file)
+
+        save_dir = pathlib.Path(hdf5_file).with_suffix("").with_suffix("")
+        save_dir.mkdir(parents=True, exist_ok=True)
+        save_images(tensors_dict, save_dir)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/mednet/libs/classification/tests/test_cli_classification.py b/src/mednet/libs/classification/tests/test_cli_classification.py
index 99e70d8fbe19677582a5cb649d29081724fcc8d3..8c76cda2b0f8146d5a1e5c9a2024e25dc2a8857a 100644
--- a/src/mednet/libs/classification/tests/test_cli_classification.py
+++ b/src/mednet/libs/classification/tests/test_cli_classification.py
@@ -389,7 +389,7 @@ def test_evaluate_pasa_montgomery(temporary_basedir):
                 "-vv",
                 "montgomery",
                 f"--predictions={predictions_file}",
-                f"--output-folder={str(temporary_basedir)}",
+                f"--output-folder={temporary_basedir}",
                 "--threshold=test",
             ],
         )
diff --git a/src/mednet/libs/segmentation/scripts/cli.py b/src/mednet/libs/segmentation/scripts/cli.py
index 49ba9821e80452d21428fd9b5d04dfbd4ee91150..a4a7730a39efe838b2a36f147e32f95cd8ffdc2d 100644
--- a/src/mednet/libs/segmentation/scripts/cli.py
+++ b/src/mednet/libs/segmentation/scripts/cli.py
@@ -12,9 +12,8 @@ from . import (
     # compare,
     config,
     database,
+    evaluate,
     predict,
-    # evaluate,
-    # experiment,
     # mkmask,
     # significance,
     train,
@@ -34,12 +33,11 @@ def segmentation():
 # segmentation.add_command(compare.compare)
 segmentation.add_command(config.config)
 segmentation.add_command(database.database)
-# segmentation.add_command(evaluate.evaluate)
-# segmentation.add_command(experiment.experiment)
 # segmentation.add_command(mkmask.mkmask)
 # segmentation.add_command(significance.significance)
 segmentation.add_command(train.train)
 segmentation.add_command(predict.predict)
+segmentation.add_command(evaluate.evaluate)
 segmentation.add_command(
     importlib.import_module(
         "mednet.libs.common.scripts.train_analysis",
diff --git a/src/mednet/libs/segmentation/scripts/evaluate.py b/src/mednet/libs/segmentation/scripts/evaluate.py
index e5b85b96a8dafe091222ea00a4cf2504f1db8f00..84ba542c6dff4a4547846c80513be9735201b386 100644
--- a/src/mednet/libs/segmentation/scripts/evaluate.py
+++ b/src/mednet/libs/segmentation/scripts/evaluate.py
@@ -19,7 +19,7 @@ from mednet.libs.segmentation.engine.evaluator import run
 
 
 @click.command(
-    entry_point_group="deepdraw.config",
+    entry_point_group="mednet.libs.segmentation.config",
     cls=ConfigCommand,
     epilog="""Examples:
 
@@ -63,19 +63,13 @@
     "-o",
     help="Directory in which to store results (created if does not exist)",
     required=True,
+    type=click.Path(
+        file_okay=False,
+        dir_okay=True,
+        writable=True,
+        path_type=pathlib.Path,
+    ),
     default="results",
-    type=click.Path(),
-    cls=ResourceOption,
-)
-@click.option(
-    "--dataset",
-    "-d",
-    help="A torch.utils.data.dataset.Dataset instance implementing a dataset "
-    "to be used for evaluation purposes, possibly including all pre-processing "
-    "pipelines required or, optionally, a dictionary mapping string keys to "
-    "torch.utils.data.dataset.Dataset instances. All keys that do not start "
-    "with an underscore (_) will be processed.",
-    required=True,
     cls=ResourceOption,
 )
 @click.option(
@@ -144,17 +138,15 @@ from mednet.libs.segmentation.engine.evaluator import run
     cls=ResourceOption,
 )
 @verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
-@click.pass_context
 def evaluate(
-    ctx,
     predictions: pathlib.Path,
     output_folder: pathlib.Path,
     threshold: str | float,
-    second_annotator,
+    # second_annotator,
     overlayed,
     steps,
     parallel,
-    **kwargs,
+    **_,  # ignored
 ):  # numpydoc ignore=PR01
     """Evaluate predictions (from a model) on a segmentation task."""
 
diff --git a/src/mednet/libs/segmentation/tests/test_cli_segmentation.py b/src/mednet/libs/segmentation/tests/test_cli_segmentation.py
new file mode 100644
index 0000000000000000000000000000000000000000..412509c8d98d5938f0f359732c69ddaea305a1d3
--- /dev/null
+++ b/src/mednet/libs/segmentation/tests/test_cli_segmentation.py
@@ -0,0 +1,446 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+"""Tests for our CLI applications."""
+
+import contextlib
+import re
+
+import pytest
+from click.testing import CliRunner
+
+
+@contextlib.contextmanager
+def stdout_logging():
+    # copy logging messages to std out
+
+    import io
+    import logging
+
+    buf = io.StringIO()
+    ch = logging.StreamHandler(buf)
+    ch.setFormatter(logging.Formatter("%(message)s"))
+    ch.setLevel(logging.INFO)
+    logger = logging.getLogger("mednet")
+    logger.addHandler(ch)
+    yield buf
+    logger.removeHandler(ch)
+
+
+def _assert_exit_0(result):
+    assert (
+        result.exit_code == 0
+    ), f"Exit code {result.exit_code} != 0 -- Output:\n{result.output}"
+
+
+def _check_help(entry_point):
+    runner = CliRunner()
+    result = runner.invoke(entry_point, ["--help"])
+    _assert_exit_0(result)
+    assert result.output.startswith("Usage:")
+
+
+def test_config_help():
+    from mednet.libs.segmentation.scripts.config import config
+
+    _check_help(config)
+
+
+def test_config_list_help():
+    from mednet.libs.segmentation.scripts.config import list_
+
+    _check_help(list_)
+
+
+def test_config_list():
+    from mednet.libs.segmentation.scripts.config import list_
+
+    runner = CliRunner()
+    result = runner.invoke(list_)
+    _assert_exit_0(result)
+    assert "module: mednet.libs.segmentation.config.data" in result.output
+    assert "module: mednet.libs.segmentation.config.models" in result.output
+
+
+def test_config_list_v():
+    from mednet.libs.segmentation.scripts.config import list_
+
+    result = CliRunner().invoke(list_, ["--verbose"])
+    _assert_exit_0(result)
+    assert "module: mednet.libs.segmentation.config.data" in result.output
+    assert "module: mednet.libs.segmentation.config.models" in result.output
+
+
+def test_config_describe_help():
+    from mednet.libs.segmentation.scripts.config import describe
+
+    _check_help(describe)
+
+
+@pytest.mark.skip_if_rc_var_not_set("datadir.drive")
+def test_config_describe_drive():
+    from mednet.libs.segmentation.scripts.config import describe
+
+    runner = CliRunner()
+    result = runner.invoke(describe, ["drive"])
+    _assert_exit_0(result)
+    assert (
+        "DRIVE dataset for Vessel Segmentation (default protocol)."
+ in result.output + ) + + +def test_database_help(): + from mednet.libs.segmentation.scripts.database import database + + _check_help(database) + + +def test_datamodule_list_help(): + from mednet.libs.segmentation.scripts.database import list_ + + _check_help(list_) + + +def test_datamodule_list(): + from mednet.libs.segmentation.scripts.database import list_ + + runner = CliRunner() + result = runner.invoke(list_) + _assert_exit_0(result) + assert result.output.startswith("Available databases:") + + +def test_datamodule_check_help(): + from mednet.libs.segmentation.scripts.database import check + + _check_help(check) + + +@pytest.mark.skip_if_rc_var_not_set("datadir.drive") +def test_database_check(): + from mednet.libs.segmentation.scripts.database import check + + runner = CliRunner() + result = runner.invoke(check, ["--verbose", "--limit=1", "drive"]) + _assert_exit_0(result) + + +def test_main_help(): + from mednet.libs.segmentation.scripts.cli import segmentation + + _check_help(segmentation) + + +def test_train_help(): + from mednet.libs.segmentation.scripts.train import train + + _check_help(train) + + +def _str_counter(substr, s): + return sum(1 for _ in re.finditer(substr, s, re.MULTILINE)) + + +def test_predict_help(): + from mednet.libs.segmentation.scripts.predict import predict + + _check_help(predict) + + +def test_evaluate_help(): + from mednet.libs.segmentation.scripts.evaluate import evaluate + + _check_help(evaluate) + + +@pytest.mark.skip_if_rc_var_not_set("datadir.drive") +def test_train_lwnet_drive(temporary_basedir): + from mednet.libs.common.utils.checkpointer import ( + CHECKPOINT_EXTENSION, + _get_checkpoint_from_alias, + ) + from mednet.libs.segmentation.scripts.train import train + + runner = CliRunner() + + with stdout_logging() as buf: + output_folder = temporary_basedir / "results" + result = runner.invoke( + train, + [ + "lwnet", + "drive", + "-vv", + "--epochs=1", + "--batch-size=1", + f"--output-folder={str(output_folder)}", + ], + ) + _assert_exit_0(result) + + # asserts checkpoints are there, or raises FileNotFoundError + last = _get_checkpoint_from_alias(output_folder, "periodic") + assert last.name.endswith("epoch=0" + CHECKPOINT_EXTENSION) + best = _get_checkpoint_from_alias(output_folder, "best") + assert best.name.endswith("epoch=0" + CHECKPOINT_EXTENSION) + + assert ( + len(list((output_folder / "logs").glob("events.out.tfevents.*"))) + == 1 + ) + assert (output_folder / "meta.json").exists() + + keywords = { + r"^Loading dataset:`train` without caching. Trade-off: CPU RAM usage: less | Disk I/O: more.$": 1, + r"^Training for at most 1 epochs.$": 1, + r"^Uninitialised lwnet model - computing z-norm factors from train dataloader.$": 1, + r"^Writing run metadata at.*$": 1, + r"^Dataset `train` is already setup. 
Not re-instantiating it.$": 3, + } + buf.seek(0) + logging_output = buf.read() + + for k, v in keywords.items(): + assert _str_counter(k, logging_output) == v, ( + f"Count for string '{k}' appeared " + f"({_str_counter(k, logging_output)}) " + f"instead of the expected {v}:\nOutput:\n{logging_output}" + ) + + +@pytest.mark.skip_if_rc_var_not_set("datadir.drive") +def test_train_lwnet_drive_from_checkpoint(temporary_basedir): + from mednet.libs.common.utils.checkpointer import ( + CHECKPOINT_EXTENSION, + _get_checkpoint_from_alias, + ) + from mednet.libs.segmentation.scripts.train import train + + runner = CliRunner() + + output_folder = temporary_basedir / "results" / "lwnet_checkpoint" + result0 = runner.invoke( + train, + [ + "lwnet", + "drive", + "-vv", + "--epochs=1", + "--batch-size=1", + f"--output-folder={str(output_folder)}", + ], + ) + _assert_exit_0(result0) + + # asserts checkpoints are there, or raises FileNotFoundError + last = _get_checkpoint_from_alias(output_folder, "periodic") + assert last.name.endswith("epoch=0" + CHECKPOINT_EXTENSION) + best = _get_checkpoint_from_alias(output_folder, "best") + assert best.name.endswith("epoch=0" + CHECKPOINT_EXTENSION) + + assert (output_folder / "meta.json").exists() + assert ( + len(list((output_folder / "logs").glob("events.out.tfevents.*"))) == 1 + ) + + with stdout_logging() as buf: + result = runner.invoke( + train, + [ + "lwnet", + "drive", + "-vv", + "--epochs=2", + "--batch-size=1", + f"--output-folder={output_folder}", + ], + ) + _assert_exit_0(result) + + # asserts checkpoints are there, or raises FileNotFoundError + last = _get_checkpoint_from_alias(output_folder, "periodic") + assert last.name.endswith("epoch=1" + CHECKPOINT_EXTENSION) + best = _get_checkpoint_from_alias(output_folder, "best") + + assert (output_folder / "meta.json").exists() + assert ( + len(list((output_folder / "logs").glob("events.out.tfevents.*"))) + == 2 + ) + + keywords = { + r"^Loading dataset:`train` without caching. Trade-off: CPU RAM usage: less | Disk I/O: more.$": 1, + r"^Training for at most 2 epochs.$": 1, + r"^Resuming from epoch 0 \(checkpoint file: .*$": 1, + r"^Writing run metadata at.*$": 1, + r"^Dataset `train` is already setup. Not re-instantiating it.$": 3, + r"^Restoring normalizer from checkpoint.$": 1, + } + buf.seek(0) + logging_output = buf.read() + + for k, v in keywords.items(): + assert _str_counter(k, logging_output) == v, ( + f"Count for string '{k}' appeared " + f"({_str_counter(k, logging_output)}) " + f"instead of the expected {v}:\nOutput:\n{logging_output}" + ) + + +@pytest.mark.skip_if_rc_var_not_set("datadir.drive") +def test_predict_lwnet_drive(temporary_basedir, datadir): + from mednet.libs.common.utils.checkpointer import ( + CHECKPOINT_EXTENSION, + _get_checkpoint_from_alias, + ) + from mednet.libs.segmentation.scripts.predict import predict + + runner = CliRunner() + + with stdout_logging() as buf: + output = temporary_basedir / "predictions" + last = _get_checkpoint_from_alias( + temporary_basedir / "results", + "periodic", + ) + assert last.name.endswith("epoch=0" + CHECKPOINT_EXTENSION) + result = runner.invoke( + predict, + [ + "lwnet", + "drive", + "-vv", + "--batch-size=1", + f"--weight={str(last)}", + f"--output-folder={str(output)}", + ], + ) + _assert_exit_0(result) + + assert output.exists() + + keywords = { + r"^Loading dataset: * without caching. 
Trade-off: CPU RAM usage: less | Disk I/O: more$": 2, + r"^Loading checkpoint from .*$": 1, + r"^Restoring normalizer from checkpoint.$": 1, + r"^Running prediction on `train` split...$": 1, + r"^Running prediction on `test` split...$": 1, + r"^Predictions saved to .*$": 1, + } + + buf.seek(0) + logging_output = buf.read() + + for k, v in keywords.items(): + assert _str_counter(k, logging_output) == v, ( + f"Count for string '{k}' appeared " + f"({_str_counter(k, logging_output)}) " + f"instead of the expected {v}:\nOutput:\n{logging_output}" + ) + + +@pytest.mark.skip_if_rc_var_not_set("datadir.drive") +def test_evaluate_lwnet_drive(temporary_basedir): + from mednet.libs.segmentation.scripts.evaluate import evaluate + + runner = CliRunner() + + with stdout_logging() as buf: + prediction_path = temporary_basedir / "predictions" + predictions_file = prediction_path / "predictions.json" + evaluation_path = temporary_basedir / "evaluations" + result = runner.invoke( + evaluate, + [ + "-vv", + "drive", + f"--predictions={predictions_file}", + f"--output-folder={evaluation_path}", + "--threshold=test", + ], + ) + _assert_exit_0(result) + + assert (evaluation_path / "evaluation.meta.json").exists() + assert (evaluation_path / "comparison.pdf").exists() + assert (evaluation_path / "comparison.rst").exists() + assert (evaluation_path / "train.csv").exists() + assert (evaluation_path / "train").exists() + assert (evaluation_path / "test.csv").exists() + assert (evaluation_path / "test").exists() + + keywords = { + r"^Analyzing split `train`...$": 1, + r"^Analyzing split `test`...$": 1, + r"^Creating and saving plot at .*$": 1, + r"^Tabulating performance summary...": 1, + r"^Saving table at .*$": 1, + } + buf.seek(0) + logging_output = buf.read() + + for k, v in keywords.items(): + assert _str_counter(k, logging_output) == v, ( + f"Count for string '{k}' appeared " + f"({_str_counter(k, logging_output)}) " + f"instead of the expected {v}:\nOutput:\n{logging_output}" + ) + + +@pytest.mark.skip_if_rc_var_not_set("datadir.drive") +def test_experiment(temporary_basedir): + from mednet.libs.segmentation.scripts.experiment import experiment + + runner = CliRunner() + + output_folder = temporary_basedir / "experiment" + num_epochs = 2 + result = runner.invoke( + experiment, + [ + "-vv", + "lwnet", + "drive", + f"--epochs={num_epochs}", + f"--output-folder={str(output_folder)}", + ], + ) + _assert_exit_0(result) + + assert (output_folder / "model" / "meta.json").exists() + assert ( + output_folder / "model" / f"model-at-epoch={num_epochs-1}.ckpt" + ).exists() + assert (output_folder / "predictions" / "predictions.json").exists() + assert (output_folder / "predictions" / "predictions.meta.json").exists() + + # Need to glob because we cannot be sure of the checkpoint with lowest validation loss + assert ( + len( + list( + (output_folder / "model").glob( + "model-at-lowest-validation-loss-epoch=*.ckpt", + ), + ), + ) + == 1 + ) + assert (output_folder / "model" / "trainlog.pdf").exists() + assert ( + len( + list( + (output_folder / "model" / "logs").glob( + "events.out.tfevents.*", + ), + ), + ) + == 1 + ) + assert (output_folder / "evaluation" / "evaluation.meta.json").exists() + assert (output_folder / "evaluation" / "comparison.pdf").exists() + assert (output_folder / "evaluation" / "comparison.rst").exists() + assert (output_folder / "evaluation" / "train.csv").exists() + assert (output_folder / "evaluation" / "train").exists() + assert (output_folder / "evaluation" / "test.csv").exists() + assert 
(output_folder / "evaluation" / "test").exists() diff --git a/src/mednet/libs/segmentation/tests/test_measures.py b/src/mednet/libs/segmentation/tests/test_measures.py new file mode 100644 index 0000000000000000000000000000000000000000..5d1b40699250368664a2dc1ec1c2f0fe25db05fb --- /dev/null +++ b/src/mednet/libs/segmentation/tests/test_measures.py @@ -0,0 +1,305 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import math +import random +import unittest + +import numpy +import pytest +import torch +from mednet.libs.segmentation.engine.evaluator import ( + sample_measures_for_threshold, +) +from mednet.libs.segmentation.utils.measure import ( + auc, + base_measures, + bayesian_measures, + beta_credible_region, +) + + +class TestFrequentist(unittest.TestCase): + """Unit test for frequentist base measures.""" + + def setUp(self): + self.tp = random.randint(1, 100) + self.fp = random.randint(1, 100) + self.tn = random.randint(1, 100) + self.fn = random.randint(1, 100) + + def test_precision(self): + precision = base_measures(self.tp, self.fp, self.tn, self.fn)[0] + self.assertEqual((self.tp) / (self.tp + self.fp), precision) + + def test_recall(self): + recall = base_measures(self.tp, self.fp, self.tn, self.fn)[1] + self.assertEqual((self.tp) / (self.tp + self.fn), recall) + + def test_specificity(self): + specificity = base_measures(self.tp, self.fp, self.tn, self.fn)[2] + self.assertEqual((self.tn) / (self.tn + self.fp), specificity) + + def test_accuracy(self): + accuracy = base_measures(self.tp, self.fp, self.tn, self.fn)[3] + self.assertEqual( + (self.tp + self.tn) / (self.tp + self.tn + self.fp + self.fn), + accuracy, + ) + + def test_jaccard(self): + jaccard = base_measures(self.tp, self.fp, self.tn, self.fn)[4] + self.assertEqual(self.tp / (self.tp + self.fp + self.fn), jaccard) + + def test_f1(self): + p, r, s, a, j, f1 = base_measures(self.tp, self.fp, self.tn, self.fn) + self.assertEqual( + (2.0 * self.tp) / (2.0 * self.tp + self.fp + self.fn), f1 + ) + self.assertAlmostEqual((2 * p * r) / (p + r), f1) # base definition + + +class TestBayesian: + """Unit test for bayesian base measures.""" + + def mean(self, k, i, lambda_): + return (k + lambda_) / (k + i + 2 * lambda_) + + def mode1(self, k, i, lambda_): # (k+lambda_), (i+lambda_) > 1 + return (k + lambda_ - 1) / (k + i + 2 * lambda_ - 2) + + def test_beta_credible_region_base(self): + k = 40 + i = 10 + lambda_ = 0.5 + cover = 0.95 + got = beta_credible_region(k, i, lambda_, cover) + # mean, mode, lower, upper + exp = ( + self.mean(k, i, lambda_), + self.mode1(k, i, lambda_), + 0.6741731038857685, + 0.8922659692341358, + ) + assert numpy.isclose(got, exp).all(), f"{got} <> {exp}" + + def test_beta_credible_region_small_k(self): + k = 4 + i = 1 + lambda_ = 0.5 + cover = 0.95 + got = beta_credible_region(k, i, lambda_, cover) + # mean, mode, lower, upper + exp = ( + self.mean(k, i, lambda_), + self.mode1(k, i, lambda_), + 0.37137359936800574, + 0.9774872340008449, + ) + assert numpy.isclose(got, exp).all(), f"{got} <> {exp}" + + def test_beta_credible_region_precision_jeffrey(self): + # simulation of situation for precision TP == FP == 0, Jeffrey's prior + k = 0 + i = 0 + lambda_ = 0.5 + cover = 0.95 + got = beta_credible_region(k, i, lambda_, cover) + # mean, mode, lower, upper + exp = ( + self.mean(k, i, lambda_), + 0.0, + 0.0015413331334360135, + 0.998458666866564, + ) + assert numpy.isclose(got, exp).all(), f"{got} <> {exp}" + + def 
test_beta_credible_region_precision_flat(self): + # simulation of situation for precision TP == FP == 0, flat prior + k = 0 + i = 0 + lambda_ = 1.0 + cover = 0.95 + got = beta_credible_region(k, i, lambda_, cover) + # mean, mode, lower, upper + exp = (self.mean(k, i, lambda_), 0.0, 0.025000000000000022, 0.975) + assert numpy.isclose(got, exp).all(), f"{got} <> {exp}" + + def test_bayesian_measures(self): + tp = random.randint(100000, 1000000) + fp = random.randint(100000, 1000000) + tn = random.randint(100000, 1000000) + fn = random.randint(100000, 1000000) + + _prec, _rec, _spec, _acc, _jac, _f1 = base_measures(tp, fp, tn, fn) + prec, rec, spec, acc, jac, f1 = bayesian_measures( + tp, fp, tn, fn, 0.5, 0.95 + ) + + # Notice that for very large k and l, the base frequentist measures + # should be approximately the same as the bayesian mean and mode + # extracted from the beta posterior. We test that here. + assert numpy.isclose( + _prec, prec[0] + ), f"freq: {_prec} <> bays: {prec[0]}" + assert numpy.isclose( + _prec, prec[1] + ), f"freq: {_prec} <> bays: {prec[1]}" + assert numpy.isclose(_rec, rec[0]), f"freq: {_rec} <> bays: {rec[0]}" + assert numpy.isclose(_rec, rec[1]), f"freq: {_rec} <> bays: {rec[1]}" + assert numpy.isclose( + _spec, spec[0] + ), f"freq: {_spec} <> bays: {spec[0]}" + assert numpy.isclose( + _spec, spec[1] + ), f"freq: {_spec} <> bays: {spec[1]}" + assert numpy.isclose(_acc, acc[0]), f"freq: {_acc} <> bays: {acc[0]}" + assert numpy.isclose(_acc, acc[1]), f"freq: {_acc} <> bays: {acc[1]}" + assert numpy.isclose(_jac, jac[0]), f"freq: {_jac} <> bays: {jac[0]}" + assert numpy.isclose(_jac, jac[1]), f"freq: {_jac} <> bays: {jac[1]}" + assert numpy.isclose(_f1, f1[0]), f"freq: {_f1} <> bays: {f1[0]}" + assert numpy.isclose(_f1, f1[1]), f"freq: {_f1} <> bays: {f1[1]}" + + # We also test that the interval in question includes the mode and the + # mean in this case. + assert (prec[2] < prec[1]) and ( + prec[1] < prec[3] + ), f"precision is out of bounds {_prec[2]} < {_prec[1]} < {_prec[3]}" + assert (rec[2] < rec[1]) and ( + rec[1] < rec[3] + ), f"recall is out of bounds {_rec[2]} < {_rec[1]} < {_rec[3]}" + assert (spec[2] < spec[1]) and ( + spec[1] < spec[3] + ), f"specif. 
is out of bounds {_spec[2]} < {_spec[1]} < {_spec[3]}" + assert (acc[2] < acc[1]) and ( + acc[1] < acc[3] + ), f"accuracy is out of bounds {_acc[2]} < {_acc[1]} < {_acc[3]}" + assert (jac[2] < jac[1]) and ( + jac[1] < jac[3] + ), f"jaccard is out of bounds {_jac[2]} < {_jac[1]} < {_jac[3]}" + assert (f1[2] < f1[1]) and ( + f1[1] < f1[3] + ), f"f1-score is out of bounds {_f1[2]} < {_f1[1]} < {_f1[3]}" + + +def test_auc(): + # basic tests + assert math.isclose(auc([0.0, 0.5, 1.0], [1.0, 1.0, 1.0]), 1.0) + assert math.isclose( + auc([0.0, 0.5, 1.0], [1.0, 0.5, 0.0]), 0.5, rel_tol=0.001 + ) + assert math.isclose( + auc([0.0, 0.5, 1.0], [0.0, 0.0, 0.0]), 0.0, rel_tol=0.001 + ) + assert math.isclose( + auc([0.0, 0.5, 1.0], [0.0, 1.0, 0.0]), 0.5, rel_tol=0.001 + ) + assert math.isclose( + auc([0.0, 0.5, 1.0], [0.0, 0.5, 0.0]), 0.25, rel_tol=0.001 + ) + assert math.isclose( + auc([0.0, 0.5, 1.0], [0.0, 0.5, 0.0]), 0.25, rel_tol=0.001 + ) + + # reversing tht is also true + assert math.isclose(auc([0.0, 0.5, 1.0][::-1], [1.0, 1.0, 1.0][::-1]), 1.0) + assert math.isclose( + auc([0.0, 0.5, 1.0][::-1], [1.0, 0.5, 0.0][::-1]), 0.5, rel_tol=0.001 + ) + assert math.isclose( + auc([0.0, 0.5, 1.0][::-1], [0.0, 0.0, 0.0][::-1]), 0.0, rel_tol=0.001 + ) + assert math.isclose( + auc([0.0, 0.5, 1.0][::-1], [0.0, 1.0, 0.0][::-1]), 0.5, rel_tol=0.001 + ) + assert math.isclose( + auc([0.0, 0.5, 1.0][::-1], [0.0, 0.5, 0.0][::-1]), 0.25, rel_tol=0.001 + ) + assert math.isclose( + auc([0.0, 0.5, 1.0][::-1], [0.0, 0.5, 0.0][::-1]), 0.25, rel_tol=0.001 + ) + + +def test_auc_raises_value_error(): + with pytest.raises( + ValueError, match=r".*neither increasing nor decreasing.*" + ): + # x is **not** monotonically increasing or decreasing + assert math.isclose(auc([0.0, 0.5, 0.0], [1.0, 1.0, 1.0]), 1.0) + + +def test_auc_raises_assertion_error(): + with pytest.raises(AssertionError, match=r".*must have the same length.*"): + # x is **not** the same size as y + assert math.isclose(auc([0.0, 0.5, 1.0], [1.0, 1.0]), 1.0) + + +def test_sample_measures_mask_checkerbox(): + prediction = torch.ones((4, 4), dtype=float) + ground_truth = torch.ones((4, 4), dtype=float) + ground_truth[2:, :2] = 0.0 + ground_truth[:2, 2:] = 0.0 + mask = torch.zeros((4, 4), dtype=float) + mask[1:3, 1:3] = 1.0 + threshold = 0.5 + + # with this configuration, this should be the correct count + tp = 2 + fp = 2 + tn = 0 + fn = 0 + + assert (tp, fp, tn, fn) == sample_measures_for_threshold( + prediction, ground_truth, mask, threshold + ) + + +def test_sample_measures_mask_cross(): + prediction = torch.ones((10, 10), dtype=float) + prediction[0, :] = 0.0 + prediction[9, :] = 0.0 + ground_truth = torch.ones((10, 10), dtype=float) + ground_truth[:5,] = 0.0 # lower part is not to be set + mask = torch.zeros((10, 10), dtype=float) + mask[(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)] = 1.0 + mask[(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), (9, 8, 7, 6, 5, 4, 3, 2, 1, 0)] = 1.0 + threshold = 0.5 + + # with this configuration, this should be the correct count + tp = 8 + fp = 8 + tn = 2 + fn = 2 + + assert (tp, fp, tn, fn) == sample_measures_for_threshold( + prediction, ground_truth, mask, threshold + ) + + +def test_sample_measures_mask_border(): + prediction = torch.zeros((10, 10), dtype=float) + prediction[:, 4] = 1.0 + prediction[:, 5] = 1.0 + prediction[0, 4] = 0.0 + prediction[8, 4] = 0.0 + prediction[1, 6] = 1.0 + ground_truth = torch.zeros((10, 10), dtype=float) + ground_truth[:, 4] = 1.0 + ground_truth[:, 5] = 1.0 + mask = torch.ones((10, 10), 
dtype=float) + mask[:, 0] = 0.0 + mask[0, :] = 0.0 + mask[:, 9] = 0.0 + mask[9, :] = 0.0 + threshold = 0.5 + + # with this configuration, this should be the correct count + tp = 15 + fp = 1 + tn = 47 + fn = 1 + + assert (tp, fp, tn, fn) == sample_measures_for_threshold( + prediction, ground_truth, mask, threshold + )