Commit 40aa5aba authored by André Anjos

[libs.segmentation.scripts.evaluate] Major clean-up

parent 4eb47d19
1 merge request: !46 Create common library
......@@ -16,9 +16,15 @@ def save_images(tensors_dict, output_dir):
def extract_images_from_hdf5(hdf5_file):
tensors_dict = {}
with h5py.File(hdf5_file, "r") as f:
tensors_dict["image"] = torch.from_numpy(f.get("img")[:])
tensors_dict["target"] = torch.from_numpy(f.get("target")[:])
tensors_dict["mask"] = torch.from_numpy(f.get("mask")[:])
img = f["img"]
assert isinstance(img, h5py.Dataset)
tensors_dict["image"] = torch.from_numpy(img[:])
target = f["target"]
assert isinstance(target, h5py.Dataset)
tensors_dict["target"] = torch.from_numpy(target[:])
mask = f["mask"]
assert isinstance(mask, h5py.Dataset)
tensors_dict["mask"] = torch.from_numpy(mask[:])
return tensors_dict
......
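The change above replaces ``f.get(...)`` with item access plus an ``isinstance`` assertion. ``h5py.File.get`` returns ``None`` for a missing key (and its return type is a broad union), so the old code would only fail later inside ``torch.from_numpy``; indexing raises a clear ``KeyError`` immediately, and the assertion narrows the type for static checkers. A minimal sketch of the pattern (hypothetical file name and helper, not part of the commit):

import h5py
import numpy
import torch


def read_tensor(path: str, key: str) -> torch.Tensor:
    """Read a single HDF5 dataset as a torch tensor."""
    with h5py.File(path, "r") as f:
        ds = f[key]  # raises KeyError if ``key`` is absent; ``f.get(key)`` would return None
        assert isinstance(ds, h5py.Dataset)  # narrows Group | Dataset | Datatype for type checkers
        return torch.from_numpy(numpy.asarray(ds[:]))


# e.g. image = read_tensor("sample.hdf5", "img")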
......@@ -128,12 +128,11 @@ def evaluate(
tabulate_results,
)
with predictions.open("r") as f:
predict_data = json.load(f)
evaluation_file = output_folder / "evaluation.json"
# register metadata
save_json_metadata(
output_file=output_folder / "evaluation.meta.json",
output_file=evaluation_file.with_suffix(".meta.json"),
predictions=str(predictions),
output_folder=str(output_folder),
threshold=threshold,
......@@ -141,6 +140,9 @@ def evaluate(
plot=plot,
)
with predictions.open("r") as f:
predict_data = json.load(f)
if threshold in predict_data:
# it is the name of a split
# first run evaluation for reference dataset
......@@ -171,7 +173,6 @@ def evaluate(
)
# records full result analysis to a JSON file
evaluation_file = output_folder / "evaluation.json"
logger.info(f"Saving evaluation results at `{str(evaluation_file)}`...")
save_json_with_backup(evaluation_file, results)
......
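The metadata file name is now derived from ``evaluation.json`` instead of being spelled out separately, so the two stay in sync. ``pathlib.Path.with_suffix`` swaps only the final suffix; the same trick is reused further down for the ``.rst`` table and ``.pdf`` plots. A quick sketch:

import pathlib

evaluation_file = pathlib.Path("output") / "evaluation.json"
print(evaluation_file.with_suffix(".meta.json"))  # output/evaluation.meta.json
print(evaluation_file.with_suffix(".rst"))        # output/evaluation.rst
print(evaluation_file.with_suffix(".pdf"))        # output/evaluation.pdf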
......@@ -4,13 +4,17 @@
"""Defines functionality for the evaluation of predictions."""
import json
import logging
import pathlib
import typing
import credible.curves
import credible.plot
import h5py
import numpy
import numpy.typing
import tabulate
from tqdm import tqdm
logger = logging.getLogger(__name__)
......@@ -370,9 +374,9 @@ def load_count(
data = numpy.zeros((len(thresholds), 4), dtype=numpy.uint64)
for sample in tqdm(predictions, desc="sample"):
with h5py.File(prediction_path / sample[1], "r") as f:
pred = numpy.array(f.get("prediction")) # float32
gt = numpy.array(f.get("target")) # boolean
mask = numpy.array(f.get("mask")) # boolean
pred = numpy.array(f["prediction"]) # float32
gt = numpy.array(f["target"]) # boolean
mask = numpy.array(f["mask"]) # boolean
data += numpy.array(
[get_counts_for_threshold(pred, gt, mask, k) for k in thresholds],
dtype=numpy.uint64,
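``get_counts_for_threshold`` is not shown in this hunk, but the counts it accumulates are plain confusion-matrix entries restricted to the valid mask. A hedged sketch of what such a helper typically computes for one threshold ``k`` (an assumption about its behaviour, not the project's implementation):

import numpy


def counts_at_threshold(pred, gt, mask, k):
    """Return (TP, FP, TN, FN) for one threshold, ignoring pixels outside the mask."""
    binary = pred >= k
    gt = gt.astype(bool)
    valid = mask.astype(bool)
    tp = int(numpy.sum(binary & gt & valid))
    fp = int(numpy.sum(binary & ~gt & valid))
    tn = int(numpy.sum(~binary & ~gt & valid))
    fn = int(numpy.sum(~binary & gt & valid))
    return tp, fp, tn, fn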
......@@ -411,7 +415,8 @@ def load_predictions(
# peek at the prediction size and number of samples
with h5py.File(prediction_path / predictions[0][1], "r") as f:
elements = numpy.array(f.get("prediction").shape).prod()
data: h5py.Dataset = typing.cast(h5py.Dataset, f["prediction"])
elements = numpy.array(data.shape).prod()
size = len(predictions) * elements
logger.info(
f"Data loading will require ({elements} x {len(predictions)} x 5 =) "
......@@ -423,10 +428,10 @@ def load_predictions(
gt_array = numpy.empty((size,), dtype=numpy.bool_)
for i, sample in enumerate(tqdm(predictions, desc="sample")):
with h5py.File(prediction_path / sample[1], "r") as f:
mask = numpy.array(f.get("mask")) # boolean
pred = numpy.array(f.get("prediction")) # float32
mask = numpy.array(f["mask"]) # boolean
pred = numpy.array(f["prediction"]) # float32
pred *= mask.astype(numpy.float32)
gt = numpy.array(f.get("target")) # boolean
gt = numpy.array(f["target"]) # boolean
gt &= mask
pred_array[i * elements : (i + 1) * elements] = pred.flatten()
gt_array[i * elements : (i + 1) * elements] = gt.flatten()
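The ``x 5`` in the log message above is the per-pixel cost of this pre-allocation: a float32 prediction (4 bytes) plus a boolean target (1 byte) for every pixel of every sample. A worked example with hypothetical dimensions:

import numpy

height, width, num_samples = 544, 544, 20  # hypothetical prediction geometry
elements = height * width                  # pixels per prediction
size = num_samples * elements              # total pixels across the split

total = size * (numpy.dtype(numpy.float32).itemsize + numpy.dtype(numpy.bool_).itemsize)
print(f"~{total / 2**20:.1f} MiB")         # about 28.2 MiB for this example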
......@@ -461,6 +466,278 @@ def compute_metric(
return numpy.array([metric(*k) for k in counts], dtype=numpy.float64)
def validate_threshold(threshold: float | str, splits: list[str]):
"""Validate the user threshold selection and returns parsed threshold.
Parameters
----------
threshold
The threshold to validate.
splits
List of available splits.
Returns
-------
The validated threshold.
"""
try:
# we try to convert it to float first
threshold = float(threshold)
if threshold < 0.0 or threshold > 1.0:
raise ValueError("Float thresholds must be within range [0.0, 1.0]")
except ValueError:
if threshold not in splits:
raise ValueError(
f"Text thresholds should match dataset names, "
f"but {threshold} is not available among the datasets provided ("
f"({', '.join(splits)})"
)
return threshold
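# Example behaviour of ``validate_threshold`` (illustrative values, not part of
# the original commit):
#
#   validate_threshold(0.5, ["train", "validation"])           -> 0.5
#   validate_threshold("0.5", ["train", "validation"])         -> 0.5 (coerced to float)
#   validate_threshold("validation", ["train", "validation"])  -> "validation"
#   validate_threshold("unknown", ["train", "validation"])     -> ValueError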
def run(
predictions: pathlib.Path,
steps: int,
threshold: str | float,
metric: SUPPORTED_METRIC_TYPE,
) -> tuple[dict[str, dict[str, typing.Any]], float]:
"""Evaluate a segmentation model.
Parameters
----------
predictions
Path to the file ``predictions.json``, containing the list of
predictions to be evaluated.
steps
The number of steps between ``[0, 1]`` to build a threshold list
from. The thresholds are applied to the probability outputs, and
true/false positive/negative counts are generated from them.
threshold
Which threshold to apply when generating unary summaries of the
performance. This can be a value within ``[0, 1]``, or the name
of a split in ``predictions`` on which a threshold will be
calculated.
metric
The name of a supported metric that will be used to evaluate the
best threshold from a threshold list uniformly split in ``steps``,
and for which unary summaries are generated.
Returns
-------
A JSON-able summary with all figures of merit pre-calculated, for
all splits. This is a dictionary where keys are split-names contained
in ``predictions``, and values are dictionaries with the following
keys:
* ``counts``: dictionary where keys are thresholds, and values are
sequences of integers containing the TP, FP, TN, FN (in this order).
* ``auc_score``: a float indicating the area under the ROC curve
for the split. It is calculated using a trapezoidal rule.
* ``average_precision_score``: a float indicating the area under the
precision-recall curve, calculated using a rectangle rule.
* ``curves``: dictionary with 2 keys:
* ``roc``: dictionary with 3 keys:
* ``fpr``: a list of floats with the false-positive rate
* ``tpr``: a list of floats with the true-positive rate
* ``thresholds``: a list of thresholds uniformly separated by
``steps``, at which both ``fpr`` and ``tpr`` are evaluated.
* ``precision_recall``: a dictionary with 3 keys:
* ``precision``: a list of floats with the precision
* ``recall``: a list of floats with the recall
* ``thresholds``: a list of thresholds uniformly separated by
``steps``, at which both ``precision`` and ``recall`` are
evaluated.
* ``threshold_a_priori``: boolean indicating whether the threshold used
for unary metrics on this split was chosen a priori or a posteriori.
* ``<metric-name>``: a float with the value of the corresponding supported
metric at the chosen threshold. There is one entry of this type for
each of the :py:obj:`SUPPORTED_METRIC_TYPE`'s.
Also returns the threshold considered for all splits.
"""
with predictions.open("r") as f:
predict_data = json.load(f)
threshold = validate_threshold(threshold, predict_data)
threshold_list = numpy.arange(
0.0, (1.0 + 1 / steps), 1 / steps, dtype=numpy.float64
)
# Holds all computed data. Format <split-name: str> -> <split-data: dict>
eval_json_data: dict[str, dict[str, typing.Any]] = {}
# Compute counts for various splits.
for split_name, samples in predict_data.items():
logger.info(
f"Counting true/false positive/negatives at split `{split_name}`..."
)
counts = load_count(predictions.parent, samples, threshold_list)
logger.info(f"Evaluating performance curves/metrics at split `{split_name}`...")
fpr_curve = 1.0 - numpy.array([specificity(*k) for k in counts])
recall_curve = tpr_curve = numpy.array([recall(*k) for k in counts])
precision_curve = numpy.array([precision(*k) for k in counts])
# populates data to be recorded in JSON format
eval_json_data.setdefault(split_name, {})["counts"] = {
k: v for k, v in zip(threshold_list, counts)
}
eval_json_data.setdefault(split_name, {})["auc_score"] = (
credible.curves.area_under_the_curve((fpr_curve, tpr_curve))
)
eval_json_data.setdefault(split_name, {})["average_precision_score"] = (
credible.curves.average_metric((precision_curve, recall_curve))
)
eval_json_data.setdefault(split_name, {})["curves"] = dict(
roc=dict(fpr=fpr_curve, tpr=tpr_curve, thresholds=threshold_list),
precision_recall=dict(
precision=precision_curve,
recall=recall_curve,
thresholds=threshold_list,
),
)
# Find the index, within the designated split's "counts" (typically
# "validation"), at which the chosen metric reaches its maximum.
if isinstance(threshold, str):
# Compute threshold on specified split, if required
logger.info(f"Evaluating threshold on split `{threshold}` using " f"`{metric}`")
metric_list = compute_metric(
eval_json_data[threshold]["counts"].values(),
name2metric(typing.cast(SUPPORTED_METRIC_TYPE, metric)),
)
threshold_index = metric_list.argmax()
# Record how the threshold was obtained for each split: a posteriori on
# the split used to compute it, a priori everywhere else
for split_name in predict_data.keys():
if split_name == threshold:
eval_json_data[split_name]["threshold_a_priori"] = False
else:
eval_json_data[split_name]["threshold_a_priori"] = True
else:
# must figure out the closest threshold from the list we are using
threshold_index = (numpy.abs(threshold_list - threshold)).argmin()
# A fixed numerical threshold counts as a priori for every split
for split_name in predict_data.keys():
eval_json_data[split_name]["threshold_a_priori"] = True
logger.info(f"Set --threshold={threshold_list[threshold_index]:.4f}")
# Computes all available metrics on the designated threshold, across all
# splits
# Populates <split-name: str> -> <metric-name: SUPPORTED_METRIC_TYPE> ->
# float
metrics_available = list(typing.get_args(SUPPORTED_METRIC_TYPE))
for split_name in predict_data.keys():
logger.info(
f"Computing metrics on split `{split_name}` at "
f"threshold={threshold_list[threshold_index]:.4f}..."
)
base_metrics = all_metrics(
*(list(eval_json_data[split_name]["counts"].values())[threshold_index])
)
eval_json_data[split_name].update(
{k: v for k, v in zip(metrics_available, base_metrics)}
)
return eval_json_data, threshold_list[threshold_index]
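# Sketch of how ``run`` is meant to be called (paths, split names and the
# metric name are illustrative; the real wiring lives in the ``evaluate`` CLI):
#
#   eval_json_data, chosen_threshold = run(
#       predictions=pathlib.Path("results/predictions.json"),
#       steps=100,
#       threshold="validation",  # or a float within [0, 1]
#       metric="f1",             # assumed name; must be a SUPPORTED_METRIC_TYPE
#   )
#   eval_json_data["validation"]["auc_score"]           # float
#   eval_json_data["validation"]["threshold_a_priori"]  # bool
#   chosen_threshold                                    # float applied to every split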
def make_table(
eval_data: dict[str, dict[str, typing.Any]], threshold: float, format_: str
) -> str:
"""Extract and format table from pre-computed evaluation data.
Extract elements from ``eval_data`` that can be displayed in a
terminal-style table, format them, and return the result.
Parameters
----------
eval_data
Evaluation data as returned by :py:func:`run`.
threshold
The threshold value used to compute unary metrics on all splits.
format_
A supported tabulate format.
Returns
-------
A string representation of a table.
"""
# Extract the unary metrics and curve-derived scores for each split and lay
# them out as one table row per split.
metrics_available = list(typing.get_args(SUPPORTED_METRIC_TYPE))
table_headers = ["Dataset", "threshold"] + metrics_available + ["auroc", "avgprec"]
table_data = []
for split_name in eval_data.keys():
base_metrics = [eval_data[split_name][k] for k in metrics_available]
table_data.append(
[split_name, threshold]
+ base_metrics
+ [
eval_data[split_name]["auc_score"],
eval_data[split_name]["average_precision_score"],
]
)
return tabulate.tabulate(
table_data,
table_headers,
tablefmt=format_,
floatfmt=".3f",
stralign="right",
)
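# ``make_table`` only formats what ``run`` already computed, e.g. (continuing
# the sketch above):
#
#   table = make_table(eval_json_data, chosen_threshold, "rst")
#   print(table)  # one row per split: threshold, unary metrics, auroc, avgprec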
def make_plots(eval_data):
"""Create plots for all curves in ``eval_data``.
Parameters
----------
eval_data
Evaluation data as returned by :py:func:`run`.
Returns
-------
A list of figures to record to file.
"""
retval = []
with credible.plot.tight_layout(
("False Positive Rate", "True Positive Rate"), "ROC"
) as (fig, ax):
for split_name, data in eval_data.items():
ax.plot(
data["curves"]["roc"]["fpr"],
data["curves"]["roc"]["tpr"],
label=f"{split_name} (AUC: {data['auc_score']:.2f})",
)
ax.legend(loc="best", fancybox=True, framealpha=0.7)
retval.append(fig)
with credible.plot.tight_layout_f1iso(
("Recall", "Precision"), "Precison-Recall"
) as (fig, ax):
for split_name, data in eval_data.items():
ax.plot(
data["curves"]["precision_recall"]["precision"],
data["curves"]["precision_recall"]["recall"],
label=f"{split_name} (AP: {data['average_precision_score']:.2f})",
)
ax.legend(loc="best", fancybox=True, framealpha=0.7)
retval.append(fig)
return retval
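# ``make_plots`` only builds the figures; recording them is left to the caller.
# The CLI command further down writes them into a single PDF, roughly like this:
#
#   import matplotlib.backends.backend_pdf
#
#   with matplotlib.backends.backend_pdf.PdfPages("evaluation.pdf") as pdf:
#       for fig in make_plots(eval_json_data):
#           pdf.savefig(fig)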
# def _compare_annotators_worker(
# baseline_sample: tuple[str, str],
# other_sample: tuple[str, str],
......
......@@ -59,10 +59,10 @@ def view(
return torchvision.transforms.functional.to_pil_image(torch.Tensor(arr))
with h5py.File(basedir / stem, "r") as f:
image: numpy.typing.NDArray[numpy.float32] = numpy.array(f.get("image"))
pred: numpy.typing.NDArray[numpy.float32] = numpy.array(f.get("prediction"))
target: numpy.typing.NDArray[numpy.bool_] = numpy.array(f.get("target"))
mask: numpy.typing.NDArray[numpy.bool_] = numpy.array(f.get("mask"))
image: numpy.typing.NDArray[numpy.float32] = numpy.array(f["image"])
pred: numpy.typing.NDArray[numpy.float32] = numpy.array(f["prediction"])
target: numpy.typing.NDArray[numpy.bool_] = numpy.array(f["target"])
mask: numpy.typing.NDArray[numpy.bool_] = numpy.array(f["mask"])
image *= mask
pred *= mask
......
......@@ -8,7 +8,6 @@ import click
from clapper.click import AliasedGroup
from . import (
# analyze,
config,
database,
dump_annotations,
......
......@@ -2,7 +2,6 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later
import json
import pathlib
import typing
......@@ -18,36 +17,6 @@ logger = setup("mednet")
__import__("matplotlib").use("agg")
def validate_threshold(threshold: float | str, splits: list[str]):
"""Validate the user threshold selection and returns parsed threshold.
Parameters
----------
threshold
The threshold to validate.
splits
List of available splits.
Returns
-------
The validated threshold.
"""
try:
# we try to convert it to float first
threshold = float(threshold)
if threshold < 0.0 or threshold > 1.0:
raise ValueError("Float thresholds must be within range [0.0, 1.0]")
except ValueError:
if threshold not in splits:
raise ValueError(
f"Text thresholds should match dataset names, "
f"but {threshold} is not available among the datasets provided ("
f"({', '.join(splits)})"
)
return threshold
@click.command(
entry_point_group="mednet.libs.segmentation.config",
cls=ConfigCommand,
......@@ -160,7 +129,7 @@ def evaluate(
predictions: pathlib.Path,
output_folder: pathlib.Path,
threshold: str | float,
metric: str,
metric: SUPPORTED_METRIC_TYPE,
steps: int,
compare_annotator: pathlib.Path,
plot: bool,
......@@ -168,31 +137,17 @@ def evaluate(
): # numpydoc ignore=PR01
"""Evaluate predictions (from a model) on a segmentation task."""
import credible.curves
import credible.plot
import matplotlib.backends.backend_pdf
import numpy
import tabulate
from mednet.libs.common.scripts.utils import (
save_json_metadata,
save_json_with_backup,
)
from mednet.libs.segmentation.engine.evaluator import (
all_metrics,
compute_metric,
load_count,
name2metric,
precision,
recall,
specificity,
)
from mednet.libs.segmentation.engine.evaluator import make_plots, make_table, run
with predictions.open("r") as f:
predict_data = json.load(f)
evaluation_file = output_folder / "evaluation.json"
# register metadata
save_json_metadata(
output_file=output_folder / "evaluation.meta.json",
output_file=evaluation_file.with_suffix(".meta.json"),
predictions=str(predictions),
output_folder=str(output_folder),
threshold=threshold,
......@@ -202,130 +157,26 @@ def evaluate(
plot=plot,
)
threshold = validate_threshold(threshold, predict_data)
threshold_list = numpy.arange(
0.0, (1.0 + 1 / steps), 1 / steps, dtype=numpy.float64
)
# Compute counts for various splits
eval_json_data: dict[str, dict[str, typing.Any]] = {}
for split_name, samples in predict_data.items():
logger.info(
f"Counting true/false positive/negatives at split `{split_name}`..."
)
eval_json_data.setdefault(split_name, {})["counts"] = {
k: v
for k, v in zip(
threshold_list, load_count(predictions.parent, samples, threshold_list)
)
}
eval_json_data[split_name]["threshold_a_posteriori"] = True
if isinstance(threshold, str):
# Compute threshold on specified split, if required
logger.info(f"Evaluating threshold on split `{threshold}` using " f"`{metric}`")
metric_list = compute_metric(
eval_json_data[threshold]["counts"].values(),
name2metric(typing.cast(SUPPORTED_METRIC_TYPE, metric)),
)
threshold_index = metric_list.argmax()
logger.info(f"Set --threshold={threshold_list[threshold_index]:.4f}")
eval_json_data, threshold_value = run(predictions, steps, threshold, metric)
# Reset list of how thresholds are calculated on the recorded split
for split_name in predict_data.keys():
if split_name == threshold:
continue
eval_json_data[split_name]["threshold_a_posteriori"] = False
else:
# must figure out the closest threshold from the list we are using
threshold_index = (numpy.abs(threshold_list - threshold)).argmin()
logger.info(f"Set --threshold={threshold_list[threshold_index]:.4f}")
metrics_available = list(typing.get_args(SUPPORTED_METRIC_TYPE))
table_headers = ["Dataset", "threshold"] + metrics_available + ["auroc", "avgprec"]
table_data = []
for split_name in predict_data.keys():
logger.info("Computing performance on split `{split_name}`...")
counts = list(eval_json_data[split_name]["counts"].values())
base_metrics = all_metrics(*counts[threshold_index])
table_data.append([split_name, threshold_list[threshold_index]] + base_metrics)
eval_json_data[split_name].update(
{k: v for k, v in zip(metrics_available, base_metrics)}
)
fpr_curve = 1.0 - numpy.array([specificity(*k) for k in counts])
recall_curve = tpr_curve = numpy.array([recall(*k) for k in counts])
precision_curve = numpy.array([precision(*k) for k in counts])
table_data[-1] += [
credible.curves.area_under_the_curve((fpr_curve, tpr_curve)), # auc-roc
credible.curves.average_metric(
(precision_curve, recall_curve)
), # average precision
]
eval_json_data[split_name]["auc_score"] = table_data[-1][-2]
eval_json_data[split_name]["average_precision_score"] = table_data[-1][-1]
eval_json_data[split_name]["curves"] = dict(
roc=dict(fpr=fpr_curve, tpr=tpr_curve, thresholds=threshold_list),
precision_recall=dict(
precision=precision_curve,
recall=recall_curve,
thresholds=threshold_list,
),
)
# records full result analysis to a JSON file
evaluation_file = output_folder / "evaluation.json"
# Records full result analysis to a JSON file
logger.info(f"Saving evaluation results at `{str(evaluation_file)}`...")
save_json_with_backup(evaluation_file, eval_json_data)
table_format = "rst"
table = tabulate.tabulate(
table_data,
table_headers,
tablefmt=table_format,
floatfmt=".3f",
stralign="right",
)
# Produces and records table
table = make_table(eval_json_data, threshold_value, "rst")
click.echo(table)
output_table = output_folder / "evaluation.rst"
output_table = evaluation_file.with_suffix(".rst")
logger.info(f"Saving tabulated performance summary at `{str(output_table)}`...")
output_table.parent.mkdir(parents=True, exist_ok=True)
with output_table.open("w") as f:
f.write(table)
# Plots pre-calculated curves, if the user asked to do so.
if plot:
figure_path = evaluation_file.with_suffix(".pdf")
logger.info(f"Saving evaluation figures at `{str(figure_path)}`...")
with matplotlib.backends.backend_pdf.PdfPages(figure_path) as pdf:
with credible.plot.tight_layout(
("False Positive Rate", "True Positive Rate"), "ROC"
) as (
fig,
ax,
):
for split_name, data in eval_json_data.items():
ax.plot(
data["curves"]["roc"]["fpr"],
data["curves"]["roc"]["tpr"],
label=f"{split_name} (AUC: {data['auc_score']:.2f})",
)
ax.legend(loc="best", fancybox=True, framealpha=0.7)
pdf.savefig(fig)
with credible.plot.tight_layout_f1iso(
("Recall", "Precision"), "Precison-Recall"
) as (
fig,
ax,
):
for split_name, data in eval_json_data.items():
ax.plot(
data["curves"]["precision_recall"]["precision"],
data["curves"]["precision_recall"]["recall"],
label=f"{split_name} (AP: {data['average_precision_score']:.2f})",
)
ax.legend(loc="best", fancybox=True, framealpha=0.7)
for fig in make_plots(eval_json_data):
pdf.savefig(fig)
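After the clean-up the command leaves a small, predictable set of sibling files next to ``evaluation.json``, all derived from it with ``with_suffix``. An illustrative layout of the output folder (file names taken from the code above; the tree itself is only a sketch):

output_folder/
├── evaluation.json       # full figures of merit, one entry per split
├── evaluation.meta.json  # run metadata written by save_json_metadata
├── evaluation.rst        # tabulated summary produced by make_table
└── evaluation.pdf        # ROC and precision-recall figures (only when plotting is requested)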
......@@ -13,8 +13,6 @@ from clapper.logging import setup
from mednet.libs.common.scripts.click import ConfigCommand
from mednet.libs.segmentation.engine.evaluator import SUPPORTED_METRIC_TYPE
from .evaluate import validate_threshold
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
......@@ -146,6 +144,8 @@ def view(
)
from mednet.libs.segmentation.engine.viewer import view
from ..engine.evaluator import validate_threshold
view_filename = "view.json"
view_file = output_folder / view_filename
......
......@@ -327,7 +327,7 @@ def test_evaluate_lwnet_drive(session_tmp_path):
r"^Writing run metadata at .*$": 1,
r"^Counting true/false positive/negatives at split.*$": 2,
r"^Evaluating threshold on split .*$": 1,
r"^Computing performance on split .*...$": 2,
r"^Computing metrics on split .*...$": 2,
r"^Saving evaluation results at .*$": 1,
r"^Saving tabulated performance summary at .*$": 1,
r"^Saving evaluation figures at .*$": 1,
......