Skip to content
Snippets Groups Projects
Commit d0a13cb3 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[evaluation] Homogenize segmentation/classification code; Use more credible...

[evaluation] Homogenize segmentation/classification code; Use more credible constructs (DRY); Fix test units
parent 1c03b05b
No related branches found
No related tags found
1 merge request!46Create common library
Pipeline #89287 passed
......@@ -3,13 +3,13 @@
# SPDX-License-Identifier: GPL-3.0-or-later
"""Defines functionality for the evaluation of predictions."""
import contextlib
import itertools
import logging
import typing
from collections.abc import Iterable, Iterator
from collections.abc import Iterable
import credible.curves
import credible.plot
import matplotlib.axes
import matplotlib.figure
import numpy
import numpy.typing
......@@ -119,92 +119,6 @@ def maxf1_threshold(predictions: Iterable[BinaryPrediction]) -> float:
return maxf1_threshold
def score_plot(
    histograms: dict[str, dict[str, numpy.typing.NDArray]],
    title: str,
    threshold: float | None,
) -> matplotlib.figure.Figure:
    """Plot the normalized score distributions for all systems.

    Parameters
    ----------
    histograms
        A dictionary containing all histograms that should be inserted into
        the plot.  Each histogram should itself be setup as another
        dictionary containing the keys ``hist`` and ``bin_edges`` as
        returned by :py:func:`numpy.histogram`.
    title
        Title of the plot.
    threshold
        Shows where the threshold is in the figure.  If set to ``None``,
        then does not show the threshold line.

    Returns
    -------
    matplotlib.figure.Figure
        A single (matplotlib) plot containing the score distribution, ready
        to be saved to disk or displayed.
    """
    from matplotlib.ticker import MaxNLocator

    fig, ax = plt.subplots(1, 1)
    assert isinstance(fig, matplotlib.figure.Figure)
    ax = typing.cast(plt.Axes, ax)  # gets editor to behave

    # Basic plot style: scores are expected within [0, 1].
    ax.set_xlim((0, 1))
    ax.set_title(title)
    ax.set_xlabel("Score")
    ax.set_ylabel("Count")

    # Keep ticks only on the left and bottom spines; counts on the y-axis
    # are integers, so force an integer major locator there.
    ax.spines.right.set_visible(False)
    ax.spines.top.set_visible(False)
    ax.get_xaxis().tick_bottom()
    ax.get_yaxis().tick_left()
    ax.get_yaxis().set_major_locator(MaxNLocator(integer=True))

    # Light dashed grid, horizontal lines only.
    ax.grid(linestyle="--", linewidth=1, color="gray", alpha=0.2)
    ax.get_xaxis().grid(False)

    tallest_bar = 0
    for system_name, histogram in histograms.items():
        counts = histogram["hist"]
        edges = histogram["bin_edges"]
        bar_width = 0.7 * (edges[1] - edges[0])
        bin_centers = (edges[:-1] + edges[1:]) / 2
        ax.bar(
            bin_centers,
            counts,
            align="center",
            width=bar_width,
            label=system_name,
            alpha=0.7,
        )
        tallest_bar = max(tallest_bar, counts.max())

    # Detach axes from the plot (offset spines slightly outwards).
    ax.spines["left"].set_position(("data", -0.015))
    ax.spines["bottom"].set_position(("data", -0.015 * tallest_bar))

    if threshold is not None:
        # Adds threshold line (dotted red)
        ax.axvline(
            threshold,  # type: ignore
            color="red",
            lw=2,
            alpha=0.75,
            ls="dotted",
            label="threshold",
        )

    # Adds a nice legend
    ax.legend(fancybox=True, framealpha=0.7)

    # Makes sure the figure occupies most of the possible space
    fig.tight_layout()

    return fig
def run_binary(
name: str,
predictions: Iterable[BinaryPrediction],
......@@ -345,7 +259,7 @@ def run_binary(
return summary
def tabulate_results(
def make_table(
data: typing.Mapping[str, typing.Mapping[str, typing.Any]],
fmt: str,
) -> str:
......@@ -368,243 +282,162 @@ def tabulate_results(
A string containing the tabulated information.
"""
example = next(iter(data.values()))
# dump evaluation results in RST format to screen and file
table_data = {}
for k, v in data.items():
table_data[k] = {
kk: vv for kk, vv in v.items() if kk not in ("curves", "score-histograms")
}
example = next(iter(table_data.values()))
headers = list(example.keys())
table = [[k[h] for h in headers] for k in data.values()]
table = [[k[h] for h in headers] for k in table_data.values()]
# add subset names
headers = ["subset"] + headers
table = [[name] + k for name, k in zip(data.keys(), table)]
table = [[name] + k for name, k in zip(table_data.keys(), table)]
return tabulate.tabulate(table, headers, tablefmt=fmt, floatfmt=".3f")
def aggregate_roc(
data: typing.Mapping[str, typing.Any],
title: str = "ROC",
def _score_plot(
histograms: dict[str, dict[str, numpy.typing.NDArray]],
title: str,
threshold: float | None,
) -> matplotlib.figure.Figure:
"""Aggregate ROC curves from multiple splits.
This function produces a single ROC plot for multiple curves generated per
split.
"""Plot the normalized score distributions for all systems.
Parameters
----------
data
A dictionary mapping split names to ROC curve data produced by
:py:func:`sklearn.metrics.roc_curve`.
histograms
A dictionary containing all histograms that should be inserted into the
plot. Each histogram should itself be setup as another dictionary
containing the keys ``hist`` and ``bin_edges`` as returned by
:py:func:`numpy.histogram`.
title
The title of the plot.
Title of the plot.
threshold
Shows where the threshold is in the figure. If set to ``None``, then
does not show the threshold line.
Returns
-------
matplotlib.figure.Figure
A figure, containing the aggregated ROC plot.
A single (matplotlib) plot containing the score distribution, ready to
be saved to disk or displayed.
"""
from matplotlib.ticker import MaxNLocator
fig, ax = plt.subplots(1, 1)
assert isinstance(fig, matplotlib.figure.Figure)
ax = typing.cast(matplotlib.axes.Axes, ax) # gets editor to behave
# Names and bounds
ax.set_xlabel("1 - specificity")
ax.set_ylabel("Sensitivity")
ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.0])
# Here, we configure the "style" of our plot
ax.set_xlim((0, 1))
ax.set_title(title)
ax.set_xlabel("Score")
ax.set_ylabel("Count")
# we should see some of ax 1 ax
ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
ax.spines["left"].set_position(("data", -0.015))
ax.spines["bottom"].set_position(("data", -0.015))
# Only show ticks on the left and bottom spines
ax.spines.right.set_visible(False)
ax.spines.top.set_visible(False)
ax.get_xaxis().tick_bottom()
ax.get_yaxis().tick_left()
ax.get_yaxis().set_major_locator(MaxNLocator(integer=True))
# Setup the grid
ax.grid(linestyle="--", linewidth=1, color="gray", alpha=0.2)
ax.get_xaxis().grid(False)
plt.tight_layout()
lines = ["-", "--", "-.", ":"]
colors = [
"#1f77b4",
"#ff7f0e",
"#2ca02c",
"#d62728",
"#9467bd",
"#8c564b",
"#e377c2",
"#7f7f7f",
"#bcbd22",
"#17becf",
]
colorcycler = itertools.cycle(colors)
linecycler = itertools.cycle(lines)
legend = []
for name, elements in data.items():
# plots roc curve
_auc = sklearn.metrics.auc(elements["fpr"], elements["tpr"])
label = f"{name} (AUC={_auc:.2f})"
color = next(colorcycler)
style = next(linecycler)
(line,) = ax.plot(
elements["fpr"],
elements["tpr"],
color=color,
linestyle=style,
)
legend.append((line, label))
if len(legend) > 1:
ax.legend(
[k[0] for k in legend],
[k[1] for k in legend],
loc="lower right",
fancybox=True,
framealpha=0.7,
)
max_hist = 0
for name in histograms.keys():
hist = histograms[name]["hist"]
bin_edges = histograms[name]["bin_edges"]
width = 0.7 * (bin_edges[1] - bin_edges[0])
center = (bin_edges[:-1] + bin_edges[1:]) / 2
ax.bar(center, hist, align="center", width=width, label=name, alpha=0.7)
max_hist = max(max_hist, hist.max())
return fig
# Detach axes from the plot
ax.spines["left"].set_position(("data", -0.015))
ax.spines["bottom"].set_position(("data", -0.015 * max_hist))
if threshold is not None:
# Adds threshold line (dotted red)
ax.axvline(
threshold, # type: ignore
color="red",
lw=2,
alpha=0.75,
ls="dotted",
label="threshold",
)
@contextlib.contextmanager
def _precision_recall_canvas() -> (
Iterator[tuple[matplotlib.figure.Figure, matplotlib.figure.Axes]]
):
"""Generate a canvas to draw precision-recall curves.
# Adds a nice legend
ax.legend(
fancybox=True,
framealpha=0.7,
)
Works like a context manager, yielding a figure and an axes set in which
the precision-recall curves should be added to. The figure already
contains F1-ISO lines and is preset to a 0-1 square region. Once the
context is finished, ``fig.tight_layout()`` is called.
# Makes sure the figure occupies most of the possible space
fig.tight_layout()
Yields
------
figure
The figure that should be finally returned to the user.
axes
An axis set where to precision-recall plots should be added to.
"""
return fig
fig, axes1 = plt.subplots(1)
assert isinstance(fig, matplotlib.figure.Figure)
assert isinstance(axes1, matplotlib.figure.Axes)
# Names and bounds
axes1.set_xlabel("Recall")
axes1.set_ylabel("Precision")
axes1.set_xlim([0.0, 1.0])
axes1.set_ylim([0.0, 1.0])
axes1.grid(linestyle="--", linewidth=1, color="gray", alpha=0.2)
axes2 = axes1.twinx()
# Annotates plot with F1-score iso-lines
f_scores = numpy.linspace(0.1, 0.9, num=9)
tick_locs = []
tick_labels = []
for f_score in f_scores:
x = numpy.linspace(0.01, 1)
y = f_score * x / (2 * x - f_score)
plt.plot(x[y >= 0], y[y >= 0], color="green", alpha=0.1)
tick_locs.append(y[-1])
tick_labels.append(f"{f_score:.1f}")
axes2.tick_params(axis="y", which="both", pad=0, right=False, left=False)
axes2.set_ylabel("iso-F", color="green", alpha=0.3)
axes2.set_ylim([0.0, 1.0])
axes2.yaxis.set_label_coords(1.015, 0.97)
axes2.set_yticks(tick_locs) # notice these are invisible
for k in axes2.set_yticklabels(tick_labels):
k.set_color("green")
k.set_alpha(0.3)
k.set_size(8)
# we should see some of axes 1 axes
axes1.spines["right"].set_visible(False)
axes1.spines["top"].set_visible(False)
axes1.spines["left"].set_position(("data", -0.015))
axes1.spines["bottom"].set_position(("data", -0.015))
# we shouldn't see any of axes 2 axes
axes2.spines["right"].set_visible(False)
axes2.spines["top"].set_visible(False)
axes2.spines["left"].set_visible(False)
axes2.spines["bottom"].set_visible(False)
# yield execution, lets user draw precision-recall plots, and the legend
# before tighteneing the layout
yield fig, axes1
plt.tight_layout()
def aggregate_pr(
data: typing.Mapping[str, typing.Any],
title: str = "Precision-Recall Curve",
) -> matplotlib.figure.Figure:
"""Aggregate PR curves from multiple splits.
This function produces a single Precision-Recall plot for multiple curves
generated per split. The plot will be annotated with F1-score iso-lines (in
which the F1-score maintains the same value).
def make_plots(results: dict[str, dict[str, typing.Any]]) -> list:
"""Create plots for all curves and score distributions in ``results``.
Parameters
----------
data
A dictionary mapping split names to Precision-Recall curve data produced by
:py:func:`sklearn.metrics.precision_recall_curve`.
title
The title of the plot.
results
Evaluation data as returned by :py:func:`run_binary`.
Returns
-------
matplotlib.figure.Figure
A figure, containing the aggregated PR plot.
A list of figures to record to file
"""
lines = ["-", "--", "-.", ":"]
colors = [
"#1f77b4",
"#ff7f0e",
"#2ca02c",
"#d62728",
"#9467bd",
"#8c564b",
"#e377c2",
"#7f7f7f",
"#bcbd22",
"#17becf",
]
colorcycler = itertools.cycle(colors)
linecycler = itertools.cycle(lines)
with _precision_recall_canvas() as (fig, axes):
axes.set_title(title)
legend = []
for name, elements in data.items():
_ap = credible.curves.average_metric(
(elements["precision"], elements["recall"]),
retval = []
with credible.plot.tight_layout(
("False Positive Rate", "True Positive Rate"), "ROC"
) as (fig, ax):
for split_name, data in results.items():
_auroc = credible.curves.area_under_the_curve(
(data["curves"]["roc"]["fpr"], data["curves"]["roc"]["tpr"]),
)
label = f"{name} (AP={_ap:.2f})"
color = next(colorcycler)
style = next(linecycler)
(line,) = axes.plot(
elements["recall"],
elements["precision"],
color=color,
linestyle=style,
ax.plot(
data["curves"]["roc"]["fpr"],
data["curves"]["roc"]["tpr"],
label=f"{split_name} (AUC: {_auroc:.2f})",
)
legend.append((line, label))
if len(legend) > 1:
axes.legend(
[k[0] for k in legend],
[k[1] for k in legend],
loc="lower left",
fancybox=True,
framealpha=0.7,
ax.legend(loc="best", fancybox=True, framealpha=0.7)
retval.append(fig)
with credible.plot.tight_layout_f1iso(
("Recall", "Precision"), "Precison-Recall"
) as (fig, ax):
for split_name, data in results.items():
_ap = credible.curves.average_metric(
(data["precision"], data["recall"]),
)
ax.plot(
data["curves"]["precision_recall"]["recall"],
data["curves"]["precision_recall"]["precision"],
label=f"{split_name} (AP: {_ap:.2f})",
)
ax.legend(loc="best", fancybox=True, framealpha=0.7)
retval.append(fig)
# score plots
for split_name, data in results.items():
score_fig = _score_plot(
data["score-histograms"],
f"Score distribution (split: {split_name})",
data["threshold"],
)
retval.append(score_fig)
return fig
return retval
......@@ -114,18 +114,16 @@ def evaluate(
import json
import typing
from matplotlib.backends.backend_pdf import PdfPages
import matplotlib.backends.backend_pdf
from mednet.libs.common.scripts.utils import (
save_json_metadata,
save_json_with_backup,
)
from ..engine.evaluator import (
aggregate_pr,
aggregate_roc,
make_plots,
make_table,
run_binary,
score_plot,
tabulate_results,
)
evaluation_file = output_folder / "evaluation.json"
......@@ -176,40 +174,20 @@ def evaluate(
logger.info(f"Saving evaluation results at `{str(evaluation_file)}`...")
save_json_with_backup(evaluation_file, results)
# dump evaluation results in RST format to screen and file
table_data = {}
for k, v in results.items():
table_data[k] = {
kk: vv for kk, vv in v.items() if kk not in ("curves", "score-histograms")
}
table = tabulate_results(table_data, fmt="rst")
# Produces and records table
table = make_table(results, "rst")
click.echo(table)
table_path = evaluation_file.with_suffix(".rst")
logger.info(
f"Saving evaluation results in table format at `{table_path}`...",
)
with table_path.open("w") as f:
output_table = evaluation_file.with_suffix(".rst")
logger.info(f"Saving tabulated performance summary at `{str(output_table)}`...")
output_table.parent.mkdir(parents=True, exist_ok=True)
with output_table.open("w") as f:
f.write(table)
# Plots pre-calculated curves, if the user asked to do so.
if plot:
figure_path = evaluation_file.with_suffix(".pdf")
logger.info(f"Saving evaluation figures at `{str(figure_path)}`...")
with PdfPages(figure_path) as pdf:
pr_curves = {k: v["curves"]["precision_recall"] for k, v in results.items()}
pr_fig = aggregate_pr(pr_curves)
pdf.savefig(pr_fig)
roc_curves = {k: v["curves"]["roc"] for k, v in results.items()}
roc_fig = aggregate_roc(roc_curves)
pdf.savefig(roc_fig)
# score plots
for k, v in results.items():
score_fig = score_plot(
v["score-histograms"],
f"Score distribution (split: {k})",
v["threshold"],
)
pdf.savefig(score_fig)
with matplotlib.backends.backend_pdf.PdfPages(figure_path) as pdf:
for fig in make_plots(results):
pdf.savefig(fig)
......@@ -739,7 +739,7 @@ def make_table(
# terminal-style table, format, and print to screen. Record the table into
# a file for later usage.
metrics_available = list(typing.get_args(SUPPORTED_METRIC_TYPE))
table_headers = ["Dataset", "threshold"] + metrics_available + ["auroc", "avgprec"]
table_headers = ["subset", "threshold"] + metrics_available + ["auroc", "avgprec"]
table_data = []
for split_name, data in eval_data.items():
......@@ -766,7 +766,7 @@ def make_table(
)
def make_plots(eval_data):
def make_plots(eval_data: dict[str, dict[str, typing.Any]]) -> list:
"""Create plots for all curves in ``eval_data``.
Parameters
......@@ -810,17 +810,17 @@ def make_plots(eval_data):
) as (fig, ax):
for split_name, data in eval_data.items():
ax.plot(
data["curves"]["precision_recall"]["precision"],
data["curves"]["precision_recall"]["recall"],
data["curves"]["precision_recall"]["precision"],
label=f"{split_name} (AP: {data['average_precision_score']:.2f})",
)
if "second_annotator" in data:
precision = data["second_annotator"]["precision"]
recall = data["second_annotator"]["recall"]
precision = data["second_annotator"]["precision"]
ax.plot(
precision,
recall,
precision,
linestyle="none",
marker="*",
markersize=8,
......
......@@ -322,7 +322,7 @@ def test_evaluate_pasa_montgomery(session_tmp_path):
r"^Setting --threshold=.*$": 1,
r"^Computing performance on split .*...$": 3,
r"^Saving evaluation results at .*$": 1,
r"^Saving evaluation results in table format at .*$": 1,
r"^Saving tabulated performance summary at .*$": 1,
r"^Saving evaluation figures at .*$": 1,
}
buf.seek(0)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment