From 5ce011d246e536d6ee3afaa7147e00cab64c0244 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Thu, 19 Oct 2023 09:56:46 +0200
Subject: [PATCH] [scripts.evaluate_saliencymaps] Add script and functionality

---
 src/ptbench/engine/saliency/evaluator.py     | 307 +++++++++++++++++++
 src/ptbench/scripts/cli.py                   |   4 +-
 src/ptbench/scripts/evaluate_saliencymaps.py | 115 +++++++
 3 files changed, 424 insertions(+), 2 deletions(-)
 create mode 100644 src/ptbench/engine/saliency/evaluator.py
 create mode 100644 src/ptbench/scripts/evaluate_saliencymaps.py

diff --git a/src/ptbench/engine/saliency/evaluator.py b/src/ptbench/engine/saliency/evaluator.py
new file mode 100644
index 00000000..ddbe8013
--- /dev/null
+++ b/src/ptbench/engine/saliency/evaluator.py
@@ -0,0 +1,307 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import typing
+
+import matplotlib.axes
+import matplotlib.figure
+import numpy
+import numpy.typing
+import tabulate
+
+from ...models.typing import SaliencyMapAlgorithm
+
+
+def _reconcile_metrics(
+    completeness: list,
+    interpretability: list,
+) -> list[tuple[str, int, float, float, float]]:
+    """Summarizes samples into a new table containing the most important
+    scores.
+
+    It returns a list containing a table with completeness and ROAD scores
+    per sample, for the selected dataset.  Only samples for which both
+    completeness and interpretability scores are available are returned in
+    the reconciled list.
+
+
+    Parameters
+    ----------
+    completeness
+        A table (list of rows) in which each row contains the sample name,
+        target label, and completeness (ROAD) scores.
+    interpretability
+        A table (list of rows) in which each row contains the sample name,
+        target label, and interpretability (Prop. Energy) scores.
+
+
+    Returns
+    -------
+    A list containing a table with the sample name, target label,
+    completeness score (average ROAD across different ablation thresholds),
+    interpretability score (Proportional Energy), and the ROAD-Weighted
+    Proportional Energy score.  The ROAD-Weighted Prop. Energy score is
+    defined as:
+
+    .. math::
+
+       \\text{ROAD-WeightedPropEng} = \\max(0, \\text{AvgROAD}) \\cdot
+       \\text{ProportionalEnergy}
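+
+    Examples
+    --------
+    A minimal sketch of the expected row layout; the exact columns are an
+    assumption derived from the indexes used below (completeness rows carry
+    the average ROAD score at index 5, interpretability rows the
+    Proportional Energy score at index 4):
+
+    .. code-block:: python
+
+       completeness = [["img-1.png", 1, 0.1, 0.2, 0.3, 0.5]]
+       interpretability = [["img-1.png", 1, 0.2, 0.4, 0.8]]
+       _reconcile_metrics(completeness, interpretability)
+       # -> [("img-1.png", 1, 0.5, 0.8, 0.4)]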
+ """ + + from matplotlib import pyplot + + fig, ax = pyplot.subplots(1) + ax = typing.cast(matplotlib.figure.Axes, ax) + ax.set_xlabel(name) + ax.set_ylabel("Frequency") + + if title is not None: + ax.set_title(title) + else: + ax.set_title(f"{name} Frequency Histogram") + + n, bins, _ = ax.hist(values, bins="auto", density=True, alpha=0.7) + + if xlim is not None: + ax.spines.bottom.set_bounds(*xlim) + else: + ax.spines.bottom.set_bounds(bins[0], bins[-1]) + + ax.spines.left.set_bounds(0, n.max()) + ax.spines.right.set_visible(False) + ax.spines.top.set_visible(False) + + ax.grid(linestyle="--", linewidth=1, color="gray", alpha=0.3) + + # draw median and quartiles + quartile = numpy.percentile(values, [25, 50, 75]) + ax.axvline( + quartile[0], color="green", linestyle="--", label="Q1", alpha=0.5 + ) + ax.axvline(quartile[1], color="red", label="median", alpha=0.5) + ax.axvline( + quartile[2], color="green", linestyle="--", label="Q3", alpha=0.5 + ) + + return fig # type: ignore + + +def summary_table( + summary: dict[SaliencyMapAlgorithm, dict[str, typing.Any]], fmt: str +) -> str: + """Tabulates various summaries into one table. + + Parameters + ---------- + summary + A dictionary mapping saliency algorithm names to the results of + :py:func:`run`. + fmt + One of the formats supported by `python-tabulate + <https://pypi.org/project/tabulate/>`_. + + + Returns + ------- + A string containing the tabulated information. + """ + + headers = [ + "Algorithm", + "AOPC-Combined", + "Prop. Energy", + "ROAD-Normalised", + ] + table = [ + [ + k, + v["aopc-combined"]["quartiles"][50], + v["proportional-energy"]["quartiles"][50], + v["road-normalised-proportional-energy-average"], + ] + for k, v in summary.items() + ] + return tabulate.tabulate(table, headers, tablefmt=fmt, floatfmt=".3f") + + +def _extract_statistics( + algo: SaliencyMapAlgorithm, + data: list[tuple[str, int, float, float, float]], + name: str, + index: int, + dataset: str, + xlim: tuple[float, float] | None = None, +) -> dict[str, typing.Any]: + """Extracts all meaningful statistics from a reconciled statistics set. + + Parameters + ---------- + algo + The algorithm for saliency map estimation that is being analysed. + data + A list of tuples each containing a sample name, target, and values + produced by completeness and interpretability analysis as returned by + :py:func:`_reconcile_metrics`. + name + The name of the variable being analysed + index + Which of the indexes on the tuples containing in ``data`` that should + be extracted. + dataset + The name of the dataset being analysed + xlim + Limits for histogram plotting + + + Returns + ------- + A dictionary containing the following elements: + + * ``values``: A list of values corresponding to the index on the data + * ``mean``: The mean of the value listdir + * ``stdev``: The standard deviation of the value list + * ``quartiles``: The 25%, 50% (median), and 75% quartile of values + * ``plot``: An histogram of values + * ``decreasing_scores``: A list of sample names and labels in + decreasing value. 
+ """ + + val = numpy.array([k[index] for k in data]) + return dict( + values=val, + mean=val.mean(), + stdev=val.std(ddof=1), # unbiased estimator + quartiles={ + 25: numpy.percentile(val, 25), # type: ignore + 50: numpy.median(val), # type: ignore + 75: numpy.percentile(val, 75), # type: ignore + }, + plot=_make_histogram( + name, + val, + xlim=xlim, + title=f"{name} Frequency Histogram ({algo} @ {dataset})", + ), + decreasing_scores=[ + (k[0], k[index]) + for k in sorted(data, key=lambda x: x[index], reverse=True) + ], + ) + + +def run( + saliency_map_algorithm: SaliencyMapAlgorithm, + completeness: dict[str, list], + interpretability: dict[str, list], +) -> dict[str, typing.Any]: + """Evaluates multiple saliency map algorithms and produces summarized + results. + + Parameters + ---------- + saliency_map_algorithm + The algorithm for saliency map estimation that is being analysed. + completeness + A dictionary mapping dataset names to tables with the sample name and + completness (among which Average ROAD) scores. + interpretability + A dictionary mapping dataset names to tables with the sample name and + interpretability (among which Prop. Energy) scores. + + + Returns + ------- + A dictionary with most important statistical values for the main + completeness (AOPC-Combined), interpretability (Prop. Energy), and a + combination of both (ROAD-Weighted Prop. Energy) scores. + """ + + retval: dict = {} + + for dataset, compl_data in completeness.items(): + reconciled = _reconcile_metrics(compl_data, interpretability[dataset]) + d = {} + + d["aopc-combined"] = _extract_statistics( + algo=saliency_map_algorithm, + data=reconciled, + name="AOPC-Combined", + index=2, + dataset=dataset, + ) + d["proportional-energy"] = _extract_statistics( + algo=saliency_map_algorithm, + data=reconciled, + name="Prop.Energy", + index=3, + dataset=dataset, + xlim=(0, 1), + ) + d["road-weighted-proportional-energy"] = _extract_statistics( + algo=saliency_map_algorithm, + data=reconciled, + name="ROAD-weighted-Prop.Energy", + index=4, + dataset=dataset, + ) + + d["road-normalised-proportional-energy-average"] = sum( + retval["road-weighted-proportional-energy"]["val"] + ) / sum([max(0, k) for k in retval["aopc-combined"]["val"]]) + + retval[dataset] = d + + return retval diff --git a/src/ptbench/scripts/cli.py b/src/ptbench/scripts/cli.py index ff3611be..0311bef5 100644 --- a/src/ptbench/scripts/cli.py +++ b/src/ptbench/scripts/cli.py @@ -11,7 +11,7 @@ from . 
+    """
+
+    retval: dict = {}
+
+    for dataset, compl_data in completeness.items():
+        reconciled = _reconcile_metrics(compl_data, interpretability[dataset])
+        d = {}
+
+        d["aopc-combined"] = _extract_statistics(
+            algo=saliency_map_algorithm,
+            data=reconciled,
+            name="AOPC-Combined",
+            index=2,
+            dataset=dataset,
+        )
+        d["proportional-energy"] = _extract_statistics(
+            algo=saliency_map_algorithm,
+            data=reconciled,
+            name="Prop.Energy",
+            index=3,
+            dataset=dataset,
+            xlim=(0, 1),
+        )
+        d["road-weighted-proportional-energy"] = _extract_statistics(
+            algo=saliency_map_algorithm,
+            data=reconciled,
+            name="ROAD-weighted-Prop.Energy",
+            index=4,
+            dataset=dataset,
+        )
+
+        d["road-normalised-proportional-energy-average"] = sum(
+            d["road-weighted-proportional-energy"]["values"]
+        ) / sum([max(0, k) for k in d["aopc-combined"]["values"]])
+
+        retval[dataset] = d
+
+    return retval
diff --git a/src/ptbench/scripts/cli.py b/src/ptbench/scripts/cli.py
index ff3611be..0311bef5 100644
--- a/src/ptbench/scripts/cli.py
+++ b/src/ptbench/scripts/cli.py
@@ -11,7 +11,7 @@ from . import (
     config,
     database,
     evaluate,
-    evaluatevis,
+    evaluate_saliencymaps,
     experiment,
     generate_saliencymaps,
     predict,
@@ -38,7 +38,7 @@ cli.add_command(database.database)
 cli.add_command(evaluate.evaluate)
 cli.add_command(saliency_completeness.saliency_completeness)
 cli.add_command(saliency_interpretability.saliency_interpretability)
-cli.add_command(evaluatevis.evaluatevis)
+cli.add_command(evaluate_saliencymaps.evaluate_saliencymaps)
 cli.add_command(experiment.experiment)
 cli.add_command(generate_saliencymaps.generate_saliencymaps)
 cli.add_command(predict.predict)
diff --git a/src/ptbench/scripts/evaluate_saliencymaps.py b/src/ptbench/scripts/evaluate_saliencymaps.py
new file mode 100644
index 00000000..5b96b37a
--- /dev/null
+++ b/src/ptbench/scripts/evaluate_saliencymaps.py
@@ -0,0 +1,115 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import pathlib
+import typing
+
+import click
+
+from clapper.click import ResourceOption, verbosity_option
+from clapper.logging import setup
+
+from ..models.typing import SaliencyMapAlgorithm
+from .click import ConfigCommand
+
+# avoids X11/graphical desktop requirement when creating plots
+__import__("matplotlib").use("agg")
+
+logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+
+
+@click.command(
+    entry_point_group="ptbench.config",
+    cls=ConfigCommand,
+    epilog="""Examples:
+
+\b
+  1. Tabulates and generates plots for two saliency map algorithms:
+
+     .. code:: sh
+
+        ptbench evaluate-saliencymaps -vv -e gradcam path/to/gradcam-completeness.json path/to/gradcam-interpretability.json -e gradcam++ path/to/gradcam++-completeness.json path/to/gradcam++-interpretability.json
+""",
+)
+@click.option(
+    "--entry",
+    "-e",
+    required=True,
+    multiple=True,
+    help=f"ENTRY is a triplet containing the algorithm name, the path to the "
+    f"scores issued from the completeness analysis (``ptbench "
+    f"saliency-completeness``) and scores issued from the interpretability "
+    f"analysis (``ptbench saliency-interpretability``), both in JSON format. "
+    f"Paths to score files must exist before the program is called.  Valid "
+    f"values for saliency map algorithms are "
+    f"{'|'.join(typing.get_args(SaliencyMapAlgorithm))}",
+    type=(
+        click.Choice(
+            typing.get_args(SaliencyMapAlgorithm), case_sensitive=False
+        ),
+        click.Path(
+            exists=True,
+            file_okay=True,
+            dir_okay=False,
+            path_type=pathlib.Path,
+        ),
+        click.Path(
+            exists=True,
+            file_okay=True,
+            dir_okay=False,
+            path_type=pathlib.Path,
+        ),
+    ),
+    cls=ResourceOption,
+)
+@click.option(
+    "--output-folder",
+    "-o",
+    help="Path where to store the analysis result (created if it does not exist)",
+    required=False,
+    default="results",
+    type=click.Path(file_okay=False, dir_okay=True, path_type=pathlib.Path),
+    cls=ResourceOption,
+)
+@verbosity_option(logger=logger, expose_value=False)
+def evaluate_saliencymaps(
+    entry,
+    output_folder,
+    **_,  # ignored
+) -> None:
+    """Calculates summary statistics for saliency map algorithms."""
+    import json
+
+    from matplotlib.backends.backend_pdf import PdfPages
+
+    from ..engine.saliency.evaluator import run, summary_table
+
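+    # NOTE: each input JSON file is assumed to map dataset names to
+    # per-sample score tables, matching the layout consumed by ``run``
+    # (see ``_reconcile_metrics`` for the indexing convention).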
+    summary = {
+        algo: run(algo, json.load(complet.open()), json.load(interp.open()))
+        for algo, complet, interp in entry
+    }
+    table = summary_table(summary, "rst")
+    click.echo(table)
+
+    if output_folder is not None:
+        output_folder.mkdir(parents=True, exist_ok=True)
+
+        table_path = output_folder / "summary.rst"
+
+        logger.info(f"Saving summary table at `{table_path}`...")
+        with table_path.open("w") as f:
+            f.write(table)
+
+        figure_path = output_folder / "plots.pdf"
+        logger.info(f"Saving figures at `{figure_path}`...")
+
+        with PdfPages(figure_path) as pdf:
+            for datasets in summary.values():
+                for stats in datasets.values():
+                    pdf.savefig(stats["aopc-combined"]["plot"])
+                    pdf.savefig(stats["proportional-energy"]["plot"])
+                    pdf.savefig(
+                        stats["road-weighted-proportional-energy"]["plot"]
+                    )
--
GitLab