diff --git a/src/mednet/libs/classification/engine/evaluator.py b/src/mednet/libs/classification/engine/evaluator.py index b79f2b14d73aa826d24713613a9feca7b059fd3f..d0af84bc1ba57b80d1ff1cc1f114800dc84a6155 100644 --- a/src/mednet/libs/classification/engine/evaluator.py +++ b/src/mednet/libs/classification/engine/evaluator.py @@ -3,13 +3,13 @@ # SPDX-License-Identifier: GPL-3.0-or-later """Defines functionality for the evaluation of predictions.""" -import contextlib -import itertools import logging import typing -from collections.abc import Iterable, Iterator +from collections.abc import Iterable import credible.curves +import credible.plot +import matplotlib.axes import matplotlib.figure import numpy import numpy.typing @@ -119,92 +119,6 @@ def maxf1_threshold(predictions: Iterable[BinaryPrediction]) -> float: return maxf1_threshold -def score_plot( - histograms: dict[str, dict[str, numpy.typing.NDArray]], - title: str, - threshold: float | None, -) -> matplotlib.figure.Figure: - """Plot the normalized score distributions for all systems. - - Parameters - ---------- - histograms - A dictionary containing all histograms that should be inserted into the - plot. Each histogram should itself be setup as another dictionary - containing the keys ``hist`` and ``bin_edges`` as returned by - :py:func:`numpy.histogram`. - title - Title of the plot. - threshold - Shows where the threshold is in the figure. If set to ``None``, then - does not show the threshold line. - - Returns - ------- - matplotlib.figure.Figure - A single (matplotlib) plot containing the score distribution, ready to - be saved to disk or displayed. - """ - - from matplotlib.ticker import MaxNLocator - - fig, ax = plt.subplots(1, 1) - assert isinstance(fig, matplotlib.figure.Figure) - ax = typing.cast(plt.Axes, ax) # gets editor to behave - - # Here, we configure the "style" of our plot - ax.set_xlim((0, 1)) - ax.set_title(title) - ax.set_xlabel("Score") - ax.set_ylabel("Count") - - # Only show ticks on the left and bottom spines - ax.spines.right.set_visible(False) - ax.spines.top.set_visible(False) - ax.get_xaxis().tick_bottom() - ax.get_yaxis().tick_left() - ax.get_yaxis().set_major_locator(MaxNLocator(integer=True)) - - # Setup the grid - ax.grid(linestyle="--", linewidth=1, color="gray", alpha=0.2) - ax.get_xaxis().grid(False) - - max_hist = 0 - for name in histograms.keys(): - hist = histograms[name]["hist"] - bin_edges = histograms[name]["bin_edges"] - width = 0.7 * (bin_edges[1] - bin_edges[0]) - center = (bin_edges[:-1] + bin_edges[1:]) / 2 - ax.bar(center, hist, align="center", width=width, label=name, alpha=0.7) - max_hist = max(max_hist, hist.max()) - - # Detach axes from the plot - ax.spines["left"].set_position(("data", -0.015)) - ax.spines["bottom"].set_position(("data", -0.015 * max_hist)) - - if threshold is not None: - # Adds threshold line (dotted red) - ax.axvline( - threshold, # type: ignore - color="red", - lw=2, - alpha=0.75, - ls="dotted", - label="threshold", - ) - - # Adds a nice legend - ax.legend( - fancybox=True, - framealpha=0.7, - ) - - # Makes sure the figure occupies most of the possible space - fig.tight_layout() - - return fig - - def run_binary( name: str, predictions: Iterable[BinaryPrediction], @@ -345,7 +259,7 @@ def run_binary( return summary -def tabulate_results( +def make_table( data: typing.Mapping[str, typing.Mapping[str, typing.Any]], fmt: str, ) -> str: @@ -368,243 +282,162 @@ def tabulate_results( A string containing the tabulated information. 
""" - example = next(iter(data.values())) + # dump evaluation results in RST format to screen and file + table_data = {} + for k, v in data.items(): + table_data[k] = { + kk: vv for kk, vv in v.items() if kk not in ("curves", "score-histograms") + } + + example = next(iter(table_data.values())) headers = list(example.keys()) - table = [[k[h] for h in headers] for k in data.values()] + table = [[k[h] for h in headers] for k in table_data.values()] # add subset names headers = ["subset"] + headers - table = [[name] + k for name, k in zip(data.keys(), table)] + table = [[name] + k for name, k in zip(table_data.keys(), table)] return tabulate.tabulate(table, headers, tablefmt=fmt, floatfmt=".3f") -def aggregate_roc( - data: typing.Mapping[str, typing.Any], - title: str = "ROC", +def _score_plot( + histograms: dict[str, dict[str, numpy.typing.NDArray]], + title: str, + threshold: float | None, ) -> matplotlib.figure.Figure: - """Aggregate ROC curves from multiple splits. - - This function produces a single ROC plot for multiple curves generated per - split. + """Plot the normalized score distributions for all systems. Parameters ---------- - data - A dictionary mapping split names to ROC curve data produced by - :py:func:sklearn.metrics.roc_curve`. + histograms + A dictionary containing all histograms that should be inserted into the + plot. Each histogram should itself be setup as another dictionary + containing the keys ``hist`` and ``bin_edges`` as returned by + :py:func:`numpy.histogram`. title - The title of the plot. + Title of the plot. + threshold + Shows where the threshold is in the figure. If set to ``None``, then + does not show the threshold line. Returns ------- matplotlib.figure.Figure - A figure, containing the aggregated ROC plot. + A single (matplotlib) plot containing the score distribution, ready to + be saved to disk or displayed. 
""" + from matplotlib.ticker import MaxNLocator + fig, ax = plt.subplots(1, 1) assert isinstance(fig, matplotlib.figure.Figure) + ax = typing.cast(matplotlib.axes.Axes, ax) # gets editor to behave - # Names and bounds - ax.set_xlabel("1 - specificity") - ax.set_ylabel("Sensitivity") - ax.set_xlim([0.0, 1.0]) - ax.set_ylim([0.0, 1.0]) + # Here, we configure the "style" of our plot + ax.set_xlim((0, 1)) ax.set_title(title) + ax.set_xlabel("Score") + ax.set_ylabel("Count") - # we should see some of ax 1 ax - ax.spines["right"].set_visible(False) - ax.spines["top"].set_visible(False) - ax.spines["left"].set_position(("data", -0.015)) - ax.spines["bottom"].set_position(("data", -0.015)) + # Only show ticks on the left and bottom spines + ax.spines.right.set_visible(False) + ax.spines.top.set_visible(False) + ax.get_xaxis().tick_bottom() + ax.get_yaxis().tick_left() + ax.get_yaxis().set_major_locator(MaxNLocator(integer=True)) + # Setup the grid ax.grid(linestyle="--", linewidth=1, color="gray", alpha=0.2) + ax.get_xaxis().grid(False) - plt.tight_layout() - - lines = ["-", "--", "-.", ":"] - colors = [ - "#1f77b4", - "#ff7f0e", - "#2ca02c", - "#d62728", - "#9467bd", - "#8c564b", - "#e377c2", - "#7f7f7f", - "#bcbd22", - "#17becf", - ] - colorcycler = itertools.cycle(colors) - linecycler = itertools.cycle(lines) - - legend = [] - - for name, elements in data.items(): - # plots roc curve - _auc = sklearn.metrics.auc(elements["fpr"], elements["tpr"]) - label = f"{name} (AUC={_auc:.2f})" - color = next(colorcycler) - style = next(linecycler) - - (line,) = ax.plot( - elements["fpr"], - elements["tpr"], - color=color, - linestyle=style, - ) - legend.append((line, label)) - - if len(legend) > 1: - ax.legend( - [k[0] for k in legend], - [k[1] for k in legend], - loc="lower right", - fancybox=True, - framealpha=0.7, - ) + max_hist = 0 + for name in histograms.keys(): + hist = histograms[name]["hist"] + bin_edges = histograms[name]["bin_edges"] + width = 0.7 * (bin_edges[1] - bin_edges[0]) + center = (bin_edges[:-1] + bin_edges[1:]) / 2 + ax.bar(center, hist, align="center", width=width, label=name, alpha=0.7) + max_hist = max(max_hist, hist.max()) - return fig + # Detach axes from the plot + ax.spines["left"].set_position(("data", -0.015)) + ax.spines["bottom"].set_position(("data", -0.015 * max_hist)) + if threshold is not None: + # Adds threshold line (dotted red) + ax.axvline( + threshold, # type: ignore + color="red", + lw=2, + alpha=0.75, + ls="dotted", + label="threshold", + ) -@contextlib.contextmanager -def _precision_recall_canvas() -> ( - Iterator[tuple[matplotlib.figure.Figure, matplotlib.figure.Axes]] -): - """Generate a canvas to draw precision-recall curves. + # Adds a nice legend + ax.legend( + fancybox=True, + framealpha=0.7, + ) - Works like a context manager, yielding a figure and an axes set in which - the precision-recall curves should be added to. The figure already - contains F1-ISO lines and is preset to a 0-1 square region. Once the - context is finished, ``fig.tight_layout()`` is called. + # Makes sure the figure occupies most of the possible space + fig.tight_layout() - Yields - ------ - figure - The figure that should be finally returned to the user. - axes - An axis set where to precision-recall plots should be added to. 
- """ + return fig - fig, axes1 = plt.subplots(1) - assert isinstance(fig, matplotlib.figure.Figure) - assert isinstance(axes1, matplotlib.figure.Axes) - - # Names and bounds - axes1.set_xlabel("Recall") - axes1.set_ylabel("Precision") - axes1.set_xlim([0.0, 1.0]) - axes1.set_ylim([0.0, 1.0]) - - axes1.grid(linestyle="--", linewidth=1, color="gray", alpha=0.2) - axes2 = axes1.twinx() - - # Annotates plot with F1-score iso-lines - f_scores = numpy.linspace(0.1, 0.9, num=9) - tick_locs = [] - tick_labels = [] - for f_score in f_scores: - x = numpy.linspace(0.01, 1) - y = f_score * x / (2 * x - f_score) - plt.plot(x[y >= 0], y[y >= 0], color="green", alpha=0.1) - tick_locs.append(y[-1]) - tick_labels.append(f"{f_score:.1f}") - axes2.tick_params(axis="y", which="both", pad=0, right=False, left=False) - axes2.set_ylabel("iso-F", color="green", alpha=0.3) - axes2.set_ylim([0.0, 1.0]) - axes2.yaxis.set_label_coords(1.015, 0.97) - axes2.set_yticks(tick_locs) # notice these are invisible - for k in axes2.set_yticklabels(tick_labels): - k.set_color("green") - k.set_alpha(0.3) - k.set_size(8) - - # we should see some of axes 1 axes - axes1.spines["right"].set_visible(False) - axes1.spines["top"].set_visible(False) - axes1.spines["left"].set_position(("data", -0.015)) - axes1.spines["bottom"].set_position(("data", -0.015)) - - # we shouldn't see any of axes 2 axes - axes2.spines["right"].set_visible(False) - axes2.spines["top"].set_visible(False) - axes2.spines["left"].set_visible(False) - axes2.spines["bottom"].set_visible(False) - - # yield execution, lets user draw precision-recall plots, and the legend - # before tighteneing the layout - yield fig, axes1 - - plt.tight_layout() - - -def aggregate_pr( - data: typing.Mapping[str, typing.Any], - title: str = "Precision-Recall Curve", -) -> matplotlib.figure.Figure: - """Aggregate PR curves from multiple splits. - This function produces a single Precision-Recall plot for multiple curves - generated per split. The plot will be annotated with F1-score iso-lines (in - which the F1-score maintains the same value). +def make_plots(results: dict[str, dict[str, typing.Any]]) -> list: + """Create plots for all curves and score distributions in ``results``. Parameters ---------- - data - A dictionary mapping split names to Precision-Recall curve data produced by - :py:func:sklearn.metrics.precision_recall_curve`. - title - The title of the plot. + results + Evaluation data as returned by :py:func:`run_binary`. Returns ------- - matplotlib.figure.Figure - A figure, containing the aggregated PR plot. 
+        A list of figures to record to file.
     """
 
-    lines = ["-", "--", "-.", ":"]
-    colors = [
-        "#1f77b4",
-        "#ff7f0e",
-        "#2ca02c",
-        "#d62728",
-        "#9467bd",
-        "#8c564b",
-        "#e377c2",
-        "#7f7f7f",
-        "#bcbd22",
-        "#17becf",
-    ]
-    colorcycler = itertools.cycle(colors)
-    linecycler = itertools.cycle(lines)
-
-    with _precision_recall_canvas() as (fig, axes):
-        axes.set_title(title)
-        legend = []
-
-        for name, elements in data.items():
-            _ap = credible.curves.average_metric(
-                (elements["precision"], elements["recall"]),
+    retval = []
+
+    with credible.plot.tight_layout(
+        ("False Positive Rate", "True Positive Rate"), "ROC"
+    ) as (fig, ax):
+        for split_name, data in results.items():
+            _auroc = credible.curves.area_under_the_curve(
+                (data["curves"]["roc"]["fpr"], data["curves"]["roc"]["tpr"]),
             )
-            label = f"{name} (AP={_ap:.2f})"
-            color = next(colorcycler)
-            style = next(linecycler)
-
-            (line,) = axes.plot(
-                elements["recall"],
-                elements["precision"],
-                color=color,
-                linestyle=style,
+            ax.plot(
+                data["curves"]["roc"]["fpr"],
+                data["curves"]["roc"]["tpr"],
+                label=f"{split_name} (AUC: {_auroc:.2f})",
             )
-            legend.append((line, label))
-
-        if len(legend) > 1:
-            axes.legend(
-                [k[0] for k in legend],
-                [k[1] for k in legend],
-                loc="lower left",
-                fancybox=True,
-                framealpha=0.7,
+        ax.legend(loc="best", fancybox=True, framealpha=0.7)
+    retval.append(fig)
+
+    with credible.plot.tight_layout_f1iso(
+        ("Recall", "Precision"), "Precision-Recall"
+    ) as (fig, ax):
+        for split_name, data in results.items():
+            _ap = credible.curves.average_metric(
+                (data["curves"]["precision_recall"]["precision"],
+                 data["curves"]["precision_recall"]["recall"]),
+            )
+            ax.plot(
+                data["curves"]["precision_recall"]["recall"],
+                data["curves"]["precision_recall"]["precision"],
+                label=f"{split_name} (AP: {_ap:.2f})",
             )
+        ax.legend(loc="best", fancybox=True, framealpha=0.7)
+    retval.append(fig)
+
+    # score-distribution plots (one per split)
+    for split_name, data in results.items():
+        score_fig = _score_plot(
+            data["score-histograms"],
+            f"Score distribution (split: {split_name})",
+            data["threshold"],
+        )
+        retval.append(score_fig)
 
-    return fig
+    return retval
diff --git a/src/mednet/libs/classification/scripts/evaluate.py b/src/mednet/libs/classification/scripts/evaluate.py
index 31c2c60a2246e9e1b84b236757a27d6a9c61f755..45bf0ae6412fd00cfe47ca202a08dabeba2189bc 100644
--- a/src/mednet/libs/classification/scripts/evaluate.py
+++ b/src/mednet/libs/classification/scripts/evaluate.py
@@ -114,18 +114,16 @@ def evaluate(
     import json
     import typing
 
-    from matplotlib.backends.backend_pdf import PdfPages
+    import matplotlib.backends.backend_pdf
     from mednet.libs.common.scripts.utils import (
         save_json_metadata,
         save_json_with_backup,
     )
 
     from ..engine.evaluator import (
-        aggregate_pr,
-        aggregate_roc,
+        make_plots,
+        make_table,
         run_binary,
-        score_plot,
-        tabulate_results,
    )
 
     evaluation_file = output_folder / "evaluation.json"
@@ -176,40 +174,20 @@ def evaluate(
     logger.info(f"Saving evaluation results at `{str(evaluation_file)}`...")
     save_json_with_backup(evaluation_file, results)
 
-    # dump evaluation results in RST format to screen and file
-    table_data = {}
-    for k, v in results.items():
-        table_data[k] = {
-            kk: vv for kk, vv in v.items() if kk not in ("curves", "score-histograms")
-        }
-    table = tabulate_results(table_data, fmt="rst")
+    # Produce and record the tabulated performance summary.
+    table = make_table(results, "rst")
     click.echo(table)
 
-    table_path = evaluation_file.with_suffix(".rst")
-    logger.info(
-        f"Saving evaluation results in table format at `{table_path}`...",
-    )
-    with table_path.open("w") as f:
+    output_table = evaluation_file.with_suffix(".rst")
+    logger.info(f"Saving tabulated performance summary at `{str(output_table)}`...")
+    output_table.parent.mkdir(parents=True, exist_ok=True)
+    with output_table.open("w") as f:
         f.write(table)
 
+    # Plots pre-calculated curves, if the user asked to do so.
     if plot:
         figure_path = evaluation_file.with_suffix(".pdf")
         logger.info(f"Saving evaluation figures at `{str(figure_path)}`...")
-
-        with PdfPages(figure_path) as pdf:
-            pr_curves = {k: v["curves"]["precision_recall"] for k, v in results.items()}
-            pr_fig = aggregate_pr(pr_curves)
-            pdf.savefig(pr_fig)
-
-            roc_curves = {k: v["curves"]["roc"] for k, v in results.items()}
-            roc_fig = aggregate_roc(roc_curves)
-            pdf.savefig(roc_fig)
-
-            # score plots
-            for k, v in results.items():
-                score_fig = score_plot(
-                    v["score-histograms"],
-                    f"Score distribution (split: {k})",
-                    v["threshold"],
-                )
-                pdf.savefig(score_fig)
+        with matplotlib.backends.backend_pdf.PdfPages(figure_path) as pdf:
+            for fig in make_plots(results):
+                pdf.savefig(fig)
diff --git a/src/mednet/libs/segmentation/engine/evaluator.py b/src/mednet/libs/segmentation/engine/evaluator.py
index c467312d5841c358c863dc5052ceea560731ec20..5e6ae0e503c158a8a7793b5eb7359a9efda3f21e 100644
--- a/src/mednet/libs/segmentation/engine/evaluator.py
+++ b/src/mednet/libs/segmentation/engine/evaluator.py
@@ -739,7 +739,7 @@ def make_table(
     # terminal-style table, format, and print to screen. Record the table into
     # a file for later usage.
     metrics_available = list(typing.get_args(SUPPORTED_METRIC_TYPE))
-    table_headers = ["Dataset", "threshold"] + metrics_available + ["auroc", "avgprec"]
+    table_headers = ["subset", "threshold"] + metrics_available + ["auroc", "avgprec"]
 
     table_data = []
     for split_name, data in eval_data.items():
@@ -766,7 +766,7 @@ def make_table(
     )
 
 
-def make_plots(eval_data):
+def make_plots(eval_data: dict[str, dict[str, typing.Any]]) -> list:
    """Create plots for all curves in ``eval_data``.
 
     Parameters
@@ -810,17 +810,17 @@ def make_plots(eval_data):
     ) as (fig, ax):
         for split_name, data in eval_data.items():
             ax.plot(
-                data["curves"]["precision_recall"]["precision"],
                 data["curves"]["precision_recall"]["recall"],
+                data["curves"]["precision_recall"]["precision"],
                 label=f"{split_name} (AP: {data['average_precision_score']:.2f})",
             )
 
             if "second_annotator" in data:
-                precision = data["second_annotator"]["precision"]
                 recall = data["second_annotator"]["recall"]
+                precision = data["second_annotator"]["precision"]
                 ax.plot(
-                    precision,
                     recall,
+                    precision,
                     linestyle="none",
                     marker="*",
                     markersize=8,
diff --git a/tests/classification/test_cli.py b/tests/classification/test_cli.py
index 391fa818db6233f419c2d1cdcf29d87fd5f6040d..865b166c465b6533e8fe63caf4d14d0c572538a9 100644
--- a/tests/classification/test_cli.py
+++ b/tests/classification/test_cli.py
@@ -322,7 +322,7 @@ def test_evaluate_pasa_montgomery(session_tmp_path):
         r"^Setting --threshold=.*$": 1,
         r"^Computing performance on split .*...$": 3,
         r"^Saving evaluation results at .*$": 1,
-        r"^Saving evaluation results in table format at .*$": 1,
+        r"^Saving tabulated performance summary at .*$": 1,
         r"^Saving evaluation figures at .*$": 1,
     }
     buf.seek(0)
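
Usage sketch (not part of the patch): how the refactored classification helpers compose outside the CLI, assuming `results` maps split names to the per-split dictionaries returned by `run_binary`, exactly as the `evaluate` command builds it. The `write_report` helper and the output stem are hypothetical names introduced here for illustration only.

import pathlib
import typing

import matplotlib.backends.backend_pdf

from mednet.libs.classification.engine.evaluator import make_plots, make_table


def write_report(results: dict[str, dict[str, typing.Any]], stem: pathlib.Path) -> None:
    # Hypothetical helper mirroring what the `evaluate` command does with the
    # outputs of make_table() and make_plots().
    table = make_table(results, "rst")  # curves/score-histograms are stripped internally
    print(table)

    stem.parent.mkdir(parents=True, exist_ok=True)
    stem.with_suffix(".rst").write_text(table)

    # One multi-page PDF: ROC, precision-recall, then one score plot per split.
    with matplotlib.backends.backend_pdf.PdfPages(stem.with_suffix(".pdf")) as pdf:
        for fig in make_plots(results):
            pdf.savefig(fig)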