Commit d0a13cb3 authored by André Anjos


[evaluation] Homogenize segmentation/classification code; Use more credible constructs (DRY); Fix test units
parent 1c03b05b
1 merge request: !46 Create common library
Pipeline #89287 passed
@@ -3,13 +3,13 @@
 # SPDX-License-Identifier: GPL-3.0-or-later

 """Defines functionality for the evaluation of predictions."""

-import contextlib
-import itertools
 import logging
 import typing
-from collections.abc import Iterable, Iterator
+from collections.abc import Iterable

 import credible.curves
+import credible.plot
+import matplotlib.axes
 import matplotlib.figure
 import numpy
 import numpy.typing
@@ -119,92 +119,6 @@ def maxf1_threshold(predictions: Iterable[BinaryPrediction]) -> float:
     return maxf1_threshold
-
-
-def score_plot(
-    histograms: dict[str, dict[str, numpy.typing.NDArray]],
-    title: str,
-    threshold: float | None,
-) -> matplotlib.figure.Figure:
-    """Plot the normalized score distributions for all systems.
-
-    Parameters
-    ----------
-    histograms
-        A dictionary containing all histograms that should be inserted into the
-        plot. Each histogram should itself be setup as another dictionary
-        containing the keys ``hist`` and ``bin_edges`` as returned by
-        :py:func:`numpy.histogram`.
-    title
-        Title of the plot.
-    threshold
-        Shows where the threshold is in the figure. If set to ``None``, then
-        does not show the threshold line.
-
-    Returns
-    -------
-    matplotlib.figure.Figure
-        A single (matplotlib) plot containing the score distribution, ready to
-        be saved to disk or displayed.
-    """
-
-    from matplotlib.ticker import MaxNLocator
-
-    fig, ax = plt.subplots(1, 1)
-    assert isinstance(fig, matplotlib.figure.Figure)
-    ax = typing.cast(plt.Axes, ax)  # gets editor to behave
-
-    # Here, we configure the "style" of our plot
-    ax.set_xlim((0, 1))
-    ax.set_title(title)
-    ax.set_xlabel("Score")
-    ax.set_ylabel("Count")
-
-    # Only show ticks on the left and bottom spines
-    ax.spines.right.set_visible(False)
-    ax.spines.top.set_visible(False)
-    ax.get_xaxis().tick_bottom()
-    ax.get_yaxis().tick_left()
-    ax.get_yaxis().set_major_locator(MaxNLocator(integer=True))
-
-    # Setup the grid
-    ax.grid(linestyle="--", linewidth=1, color="gray", alpha=0.2)
-    ax.get_xaxis().grid(False)
-
-    max_hist = 0
-    for name in histograms.keys():
-        hist = histograms[name]["hist"]
-        bin_edges = histograms[name]["bin_edges"]
-        width = 0.7 * (bin_edges[1] - bin_edges[0])
-        center = (bin_edges[:-1] + bin_edges[1:]) / 2
-        ax.bar(center, hist, align="center", width=width, label=name, alpha=0.7)
-        max_hist = max(max_hist, hist.max())
-
-    # Detach axes from the plot
-    ax.spines["left"].set_position(("data", -0.015))
-    ax.spines["bottom"].set_position(("data", -0.015 * max_hist))
-
-    if threshold is not None:
-        # Adds threshold line (dotted red)
-        ax.axvline(
-            threshold,  # type: ignore
-            color="red",
-            lw=2,
-            alpha=0.75,
-            ls="dotted",
-            label="threshold",
-        )
-
-    # Adds a nice legend
-    ax.legend(
-        fancybox=True,
-        framealpha=0.7,
-    )
-
-    # Makes sure the figure occupies most of the possible space
-    fig.tight_layout()
-
-    return fig


 def run_binary(
     name: str,
     predictions: Iterable[BinaryPrediction],
@@ -345,7 +259,7 @@ def run_binary(
     return summary


-def tabulate_results(
+def make_table(
     data: typing.Mapping[str, typing.Mapping[str, typing.Any]],
     fmt: str,
 ) -> str:
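For reference, make_table (the renamed tabulate_results) keeps delegating the actual formatting to the tabulate package. A minimal, standalone sketch of that call, with invented subset names and metric values:

    import tabulate

    headers = ["subset", "f1", "auroc"]
    rows = [["validation", 0.912, 0.957], ["test", 0.897, 0.941]]
    print(tabulate.tabulate(rows, headers=headers, tablefmt="rst", floatfmt=".3f"))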
@@ -368,243 +282,162 @@ def tabulate_results(
         A string containing the tabulated information.
     """

-    example = next(iter(data.values()))
+    # dump evaluation results in RST format to screen and file
+    table_data = {}
+    for k, v in data.items():
+        table_data[k] = {
+            kk: vv for kk, vv in v.items() if kk not in ("curves", "score-histograms")
+        }
+
+    example = next(iter(table_data.values()))
     headers = list(example.keys())
-    table = [[k[h] for h in headers] for k in data.values()]
+    table = [[k[h] for h in headers] for k in table_data.values()]

     # add subset names
     headers = ["subset"] + headers
-    table = [[name] + k for name, k in zip(data.keys(), table)]
+    table = [[name] + k for name, k in zip(table_data.keys(), table)]

     return tabulate.tabulate(table, headers, tablefmt=fmt, floatfmt=".3f")


-def aggregate_roc(
-    data: typing.Mapping[str, typing.Any],
-    title: str = "ROC",
-) -> matplotlib.figure.Figure:
-    """Aggregate ROC curves from multiple splits.
-
-    This function produces a single ROC plot for multiple curves generated per
-    split.
-
-    Parameters
-    ----------
-    data
-        A dictionary mapping split names to ROC curve data produced by
-        :py:func:sklearn.metrics.roc_curve`.
-    title
-        The title of the plot.
-
-    Returns
-    -------
-    matplotlib.figure.Figure
-        A figure, containing the aggregated ROC plot.
-    """
-
-    fig, ax = plt.subplots(1, 1)
-    assert isinstance(fig, matplotlib.figure.Figure)
-
-    # Names and bounds
-    ax.set_xlabel("1 - specificity")
-    ax.set_ylabel("Sensitivity")
-    ax.set_xlim([0.0, 1.0])
-    ax.set_ylim([0.0, 1.0])
-    ax.set_title(title)
-
-    # we should see some of ax 1 ax
-    ax.spines["right"].set_visible(False)
-    ax.spines["top"].set_visible(False)
-    ax.spines["left"].set_position(("data", -0.015))
-    ax.spines["bottom"].set_position(("data", -0.015))
-
-    ax.grid(linestyle="--", linewidth=1, color="gray", alpha=0.2)
-
-    plt.tight_layout()
-
-    lines = ["-", "--", "-.", ":"]
-    colors = [
-        "#1f77b4",
-        "#ff7f0e",
-        "#2ca02c",
-        "#d62728",
-        "#9467bd",
-        "#8c564b",
-        "#e377c2",
-        "#7f7f7f",
-        "#bcbd22",
-        "#17becf",
-    ]
-    colorcycler = itertools.cycle(colors)
-    linecycler = itertools.cycle(lines)
-
-    legend = []
-    for name, elements in data.items():
-        # plots roc curve
-        _auc = sklearn.metrics.auc(elements["fpr"], elements["tpr"])
-        label = f"{name} (AUC={_auc:.2f})"
-        color = next(colorcycler)
-        style = next(linecycler)
-
-        (line,) = ax.plot(
-            elements["fpr"],
-            elements["tpr"],
-            color=color,
-            linestyle=style,
-        )
-        legend.append((line, label))
-
-    if len(legend) > 1:
-        ax.legend(
-            [k[0] for k in legend],
-            [k[1] for k in legend],
-            loc="lower right",
-            fancybox=True,
-            framealpha=0.7,
-        )
-
-    return fig
+def _score_plot(
+    histograms: dict[str, dict[str, numpy.typing.NDArray]],
+    title: str,
+    threshold: float | None,
+) -> matplotlib.figure.Figure:
+    """Plot the normalized score distributions for all systems.
+
+    Parameters
+    ----------
+    histograms
+        A dictionary containing all histograms that should be inserted into the
+        plot. Each histogram should itself be setup as another dictionary
+        containing the keys ``hist`` and ``bin_edges`` as returned by
+        :py:func:`numpy.histogram`.
+    title
+        Title of the plot.
+    threshold
+        Shows where the threshold is in the figure. If set to ``None``, then
+        does not show the threshold line.
+
+    Returns
+    -------
+    matplotlib.figure.Figure
+        A single (matplotlib) plot containing the score distribution, ready to
+        be saved to disk or displayed.
+    """
+
+    from matplotlib.ticker import MaxNLocator
+
+    fig, ax = plt.subplots(1, 1)
+    assert isinstance(fig, matplotlib.figure.Figure)
+    ax = typing.cast(matplotlib.axes.Axes, ax)  # gets editor to behave
+
+    # Here, we configure the "style" of our plot
+    ax.set_xlim((0, 1))
+    ax.set_title(title)
+    ax.set_xlabel("Score")
+    ax.set_ylabel("Count")
+
+    # Only show ticks on the left and bottom spines
+    ax.spines.right.set_visible(False)
+    ax.spines.top.set_visible(False)
+    ax.get_xaxis().tick_bottom()
+    ax.get_yaxis().tick_left()
+    ax.get_yaxis().set_major_locator(MaxNLocator(integer=True))
+
+    # Setup the grid
+    ax.grid(linestyle="--", linewidth=1, color="gray", alpha=0.2)
+    ax.get_xaxis().grid(False)
+
+    max_hist = 0
+    for name in histograms.keys():
+        hist = histograms[name]["hist"]
+        bin_edges = histograms[name]["bin_edges"]
+        width = 0.7 * (bin_edges[1] - bin_edges[0])
+        center = (bin_edges[:-1] + bin_edges[1:]) / 2
+        ax.bar(center, hist, align="center", width=width, label=name, alpha=0.7)
+        max_hist = max(max_hist, hist.max())
+
+    # Detach axes from the plot
+    ax.spines["left"].set_position(("data", -0.015))
+    ax.spines["bottom"].set_position(("data", -0.015 * max_hist))
+
+    if threshold is not None:
+        # Adds threshold line (dotted red)
+        ax.axvline(
+            threshold,  # type: ignore
+            color="red",
+            lw=2,
+            alpha=0.75,
+            ls="dotted",
+            label="threshold",
+        )
+
+    # Adds a nice legend
+    ax.legend(
+        fancybox=True,
+        framealpha=0.7,
+    )
+
+    # Makes sure the figure occupies most of the possible space
+    fig.tight_layout()
+
+    return fig


-@contextlib.contextmanager
-def _precision_recall_canvas() -> (
-    Iterator[tuple[matplotlib.figure.Figure, matplotlib.figure.Axes]]
-):
-    """Generate a canvas to draw precision-recall curves.
-
-    Works like a context manager, yielding a figure and an axes set in which
-    the precision-recall curves should be added to. The figure already
-    contains F1-ISO lines and is preset to a 0-1 square region. Once the
-    context is finished, ``fig.tight_layout()`` is called.
-
-    Yields
-    ------
-    figure
-        The figure that should be finally returned to the user.
-    axes
-        An axis set where to precision-recall plots should be added to.
-    """
-
-    fig, axes1 = plt.subplots(1)
-    assert isinstance(fig, matplotlib.figure.Figure)
-    assert isinstance(axes1, matplotlib.figure.Axes)
-
-    # Names and bounds
-    axes1.set_xlabel("Recall")
-    axes1.set_ylabel("Precision")
-    axes1.set_xlim([0.0, 1.0])
-    axes1.set_ylim([0.0, 1.0])
-
-    axes1.grid(linestyle="--", linewidth=1, color="gray", alpha=0.2)
-
-    axes2 = axes1.twinx()
-
-    # Annotates plot with F1-score iso-lines
-    f_scores = numpy.linspace(0.1, 0.9, num=9)
-    tick_locs = []
-    tick_labels = []
-    for f_score in f_scores:
-        x = numpy.linspace(0.01, 1)
-        y = f_score * x / (2 * x - f_score)
-        plt.plot(x[y >= 0], y[y >= 0], color="green", alpha=0.1)
-        tick_locs.append(y[-1])
-        tick_labels.append(f"{f_score:.1f}")
-    axes2.tick_params(axis="y", which="both", pad=0, right=False, left=False)
-    axes2.set_ylabel("iso-F", color="green", alpha=0.3)
-    axes2.set_ylim([0.0, 1.0])
-    axes2.yaxis.set_label_coords(1.015, 0.97)
-    axes2.set_yticks(tick_locs)  # notice these are invisible
-    for k in axes2.set_yticklabels(tick_labels):
-        k.set_color("green")
-        k.set_alpha(0.3)
-        k.set_size(8)
-
-    # we should see some of axes 1 axes
-    axes1.spines["right"].set_visible(False)
-    axes1.spines["top"].set_visible(False)
-    axes1.spines["left"].set_position(("data", -0.015))
-    axes1.spines["bottom"].set_position(("data", -0.015))
-
-    # we shouldn't see any of axes 2 axes
-    axes2.spines["right"].set_visible(False)
-    axes2.spines["top"].set_visible(False)
-    axes2.spines["left"].set_visible(False)
-    axes2.spines["bottom"].set_visible(False)
-
-    # yield execution, lets user draw precision-recall plots, and the legend
-    # before tighteneing the layout
-    yield fig, axes1
-    plt.tight_layout()
-
-
-def aggregate_pr(
-    data: typing.Mapping[str, typing.Any],
-    title: str = "Precision-Recall Curve",
-) -> matplotlib.figure.Figure:
-    """Aggregate PR curves from multiple splits.
-
-    This function produces a single Precision-Recall plot for multiple curves
-    generated per split. The plot will be annotated with F1-score iso-lines (in
-    which the F1-score maintains the same value).
-
-    Parameters
-    ----------
-    data
-        A dictionary mapping split names to Precision-Recall curve data produced by
-        :py:func:sklearn.metrics.precision_recall_curve`.
-    title
-        The title of the plot.
-
-    Returns
-    -------
-    matplotlib.figure.Figure
-        A figure, containing the aggregated PR plot.
-    """
-
-    lines = ["-", "--", "-.", ":"]
-    colors = [
-        "#1f77b4",
-        "#ff7f0e",
-        "#2ca02c",
-        "#d62728",
-        "#9467bd",
-        "#8c564b",
-        "#e377c2",
-        "#7f7f7f",
-        "#bcbd22",
-        "#17becf",
-    ]
-    colorcycler = itertools.cycle(colors)
-    linecycler = itertools.cycle(lines)
-
-    with _precision_recall_canvas() as (fig, axes):
-        axes.set_title(title)
-
-        legend = []
-        for name, elements in data.items():
-            _ap = credible.curves.average_metric(
-                (elements["precision"], elements["recall"]),
-            )
-            label = f"{name} (AP={_ap:.2f})"
-            color = next(colorcycler)
-            style = next(linecycler)
-
-            (line,) = axes.plot(
-                elements["recall"],
-                elements["precision"],
-                color=color,
-                linestyle=style,
-            )
-            legend.append((line, label))
-
-        if len(legend) > 1:
-            axes.legend(
-                [k[0] for k in legend],
-                [k[1] for k in legend],
-                loc="lower left",
-                fancybox=True,
-                framealpha=0.7,
-            )
-
-    return fig
+def make_plots(results: dict[str, dict[str, typing.Any]]) -> list:
+    """Create plots for all curves and score distributions in ``results``.
+
+    Parameters
+    ----------
+    results
+        Evaluation data as returned by :py:func:`run_binary`.
+
+    Returns
+    -------
+        A list of figures to record to file.
+    """
+
+    retval = []
+
+    with credible.plot.tight_layout(
+        ("False Positive Rate", "True Positive Rate"), "ROC"
+    ) as (fig, ax):
+        for split_name, data in results.items():
+            _auroc = credible.curves.area_under_the_curve(
+                (data["curves"]["roc"]["fpr"], data["curves"]["roc"]["tpr"]),
+            )
+            ax.plot(
+                data["curves"]["roc"]["fpr"],
+                data["curves"]["roc"]["tpr"],
+                label=f"{split_name} (AUC: {_auroc:.2f})",
+            )
+        ax.legend(loc="best", fancybox=True, framealpha=0.7)
+        retval.append(fig)
+
+    with credible.plot.tight_layout_f1iso(
+        ("Recall", "Precision"), "Precison-Recall"
+    ) as (fig, ax):
+        for split_name, data in results.items():
+            _ap = credible.curves.average_metric(
+                (data["precision"], data["recall"]),
+            )
+            ax.plot(
+                data["curves"]["precision_recall"]["recall"],
+                data["curves"]["precision_recall"]["precision"],
+                label=f"{split_name} (AP: {_ap:.2f})",
+            )
+        ax.legend(loc="best", fancybox=True, framealpha=0.7)
+        retval.append(fig)
+
+    # score plots
+    for split_name, data in results.items():
+        score_fig = _score_plot(
+            data["score-histograms"],
+            f"Score distribution (split: {split_name})",
+            data["threshold"],
+        )
+        retval.append(score_fig)
+
+    return retval
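For context, the histograms mapping consumed by _score_plot above follows the numpy.histogram convention described in its docstring. A minimal sketch of how such an input could be assembled; the split name, score values and bin count are illustrative only, and the commented call assumes _score_plot is importable from the evaluator module:

    import numpy

    scores = numpy.random.rand(200)  # hypothetical per-sample probabilities for one split
    hist, bin_edges = numpy.histogram(scores, bins=30, range=(0.0, 1.0))

    # one entry per system/split to overlay in the same figure
    histograms = {"validation": {"hist": hist, "bin_edges": bin_edges}}
    # fig = _score_plot(histograms, "Score distribution (split: validation)", threshold=0.5)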
@@ -114,18 +114,16 @@ def evaluate(
     import json
     import typing

-    from matplotlib.backends.backend_pdf import PdfPages
+    import matplotlib.backends.backend_pdf
     from mednet.libs.common.scripts.utils import (
         save_json_metadata,
         save_json_with_backup,
     )

     from ..engine.evaluator import (
-        aggregate_pr,
-        aggregate_roc,
+        make_plots,
+        make_table,
         run_binary,
-        score_plot,
-        tabulate_results,
     )

     evaluation_file = output_folder / "evaluation.json"
@@ -176,40 +174,20 @@ def evaluate(
     logger.info(f"Saving evaluation results at `{str(evaluation_file)}`...")
     save_json_with_backup(evaluation_file, results)

-    # dump evaluation results in RST format to screen and file
-    table_data = {}
-    for k, v in results.items():
-        table_data[k] = {
-            kk: vv for kk, vv in v.items() if kk not in ("curves", "score-histograms")
-        }
-
-    table = tabulate_results(table_data, fmt="rst")
+    # Produces and records table
+    table = make_table(results, "rst")
     click.echo(table)

-    table_path = evaluation_file.with_suffix(".rst")
-    logger.info(
-        f"Saving evaluation results in table format at `{table_path}`...",
-    )
-    with table_path.open("w") as f:
+    output_table = evaluation_file.with_suffix(".rst")
+    logger.info(f"Saving tabulated performance summary at `{str(output_table)}`...")
+    output_table.parent.mkdir(parents=True, exist_ok=True)
+    with output_table.open("w") as f:
         f.write(table)

+    # Plots pre-calculated curves, if the user asked to do so.
     if plot:
         figure_path = evaluation_file.with_suffix(".pdf")
         logger.info(f"Saving evaluation figures at `{str(figure_path)}`...")
-        with PdfPages(figure_path) as pdf:
-            pr_curves = {k: v["curves"]["precision_recall"] for k, v in results.items()}
-            pr_fig = aggregate_pr(pr_curves)
-            pdf.savefig(pr_fig)
-
-            roc_curves = {k: v["curves"]["roc"] for k, v in results.items()}
-            roc_fig = aggregate_roc(roc_curves)
-            pdf.savefig(roc_fig)
-
-            # score plots
-            for k, v in results.items():
-                score_fig = score_plot(
-                    v["score-histograms"],
-                    f"Score distribution (split: {k})",
-                    v["threshold"],
-                )
-                pdf.savefig(score_fig)
+        with matplotlib.backends.backend_pdf.PdfPages(figure_path) as pdf:
+            for fig in make_plots(results):
+                pdf.savefig(fig)
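The plotting branch above relies on matplotlib's multi-page PDF backend: every figure returned by make_plots becomes one page of the output file. A self-contained sketch of that pattern (file name and figures are made up for illustration):

    import matplotlib.backends.backend_pdf
    import matplotlib.pyplot as plt

    figures = []
    for k in range(3):
        fig, ax = plt.subplots(1, 1)
        ax.plot([0.0, 1.0], [0.0, k / 3.0])
        figures.append(fig)

    with matplotlib.backends.backend_pdf.PdfPages("evaluation.pdf") as pdf:
        for fig in figures:
            pdf.savefig(fig)  # one page per figure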
@@ -739,7 +739,7 @@ def make_table(
     # terminal-style table, format, and print to screen. Record the table into
     # a file for later usage.
     metrics_available = list(typing.get_args(SUPPORTED_METRIC_TYPE))
-    table_headers = ["Dataset", "threshold"] + metrics_available + ["auroc", "avgprec"]
+    table_headers = ["subset", "threshold"] + metrics_available + ["auroc", "avgprec"]

     table_data = []
     for split_name, data in eval_data.items():
@@ -766,7 +766,7 @@ def make_table(
     )


-def make_plots(eval_data):
+def make_plots(eval_data: dict[str, dict[str, typing.Any]]) -> list:
     """Create plots for all curves in ``eval_data``.

     Parameters
@@ -810,17 +810,17 @@ def make_plots(eval_data):
     ) as (fig, ax):
         for split_name, data in eval_data.items():
             ax.plot(
-                data["curves"]["precision_recall"]["precision"],
                 data["curves"]["precision_recall"]["recall"],
+                data["curves"]["precision_recall"]["precision"],
                 label=f"{split_name} (AP: {data['average_precision_score']:.2f})",
             )

             if "second_annotator" in data:
-                precision = data["second_annotator"]["precision"]
                 recall = data["second_annotator"]["recall"]
+                precision = data["second_annotator"]["precision"]
                 ax.plot(
-                    precision,
                     recall,
+                    precision,
                     linestyle="none",
                     marker="*",
                     markersize=8,
......
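The argument swaps in the hunk above matter because Axes.plot(x, y) takes the x series first: for a precision-recall curve the usual convention is recall on the x axis and precision on the y axis. A standalone illustration with fabricated curve values:

    import matplotlib.pyplot as plt
    import numpy

    recall = numpy.linspace(0.0, 1.0, 50)
    precision = 1.0 - 0.5 * recall  # made-up monotonic curve, for illustration only

    fig, ax = plt.subplots(1, 1)
    ax.plot(recall, precision)  # x=recall, y=precision
    ax.set_xlabel("Recall")
    ax.set_ylabel("Precision")
    fig.savefig("pr-example.png")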
@@ -322,7 +322,7 @@ def test_evaluate_pasa_montgomery(session_tmp_path):
         r"^Setting --threshold=.*$": 1,
         r"^Computing performance on split .*...$": 3,
         r"^Saving evaluation results at .*$": 1,
-        r"^Saving evaluation results in table format at .*$": 1,
+        r"^Saving tabulated performance summary at .*$": 1,
         r"^Saving evaluation figures at .*$": 1,
     }
     buf.seek(0)
......
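The test change above only updates the expected log message; the surrounding assertion style, pairing a regular expression with the number of times it must match the captured output, can be sketched independently. The helper below is illustrative and not the project's actual test code:

    import re

    expected = {
        r"^Saving evaluation results at .*$": 1,
        r"^Saving tabulated performance summary at .*$": 1,
    }
    captured = [
        "Saving evaluation results at `/tmp/evaluation.json`...",
        "Saving tabulated performance summary at `/tmp/evaluation.rst`...",
    ]
    for pattern, count in expected.items():
        matches = sum(1 for line in captured if re.match(pattern, line))
        assert matches == count, f"{pattern}: expected {count}, got {matches}"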