Skip to content
Snippets Groups Projects
Commit 5ce011d2 authored by André Anjos's avatar André Anjos :speech_balloon: Committed by Daniel CARRON
Browse files

[scripts.evaluate_saliencymaps] Add script and functionality

parent 71536f42
No related branches found
No related tags found
1 merge request!12Adds grad-cam support on classifiers
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import typing
import matplotlib.figure
import numpy
import numpy.typing
import tabulate
from ...models.typing import SaliencyMapAlgorithm
def _reconcile_metrics(
completeness: list,
interpretability: list,
) -> list[tuple[str, int, float, float, float]]:
"""Summarizes samples into a new table containing most important scores.
It returns a list containing a table with completeness and road scorse per
sample, for the selected dataset. Only samples for which a completness and
interpretability scores are availble are returned in the reconciled list.
Parameters
----------
completeness
A dictionary containing various tables with the sample name and
completness (ROAD) scores.
interpretability
A dictionary containing various tables with the sample name and
interpretability (Pro. Energy) scores.
Returns
-------
A list containing a table with the sample name, target label,
completeness score (Average ROAD across different ablation thresholds),
interpretability score (Proportional Energy), and the ROAD-Weighted
Proportional Energy score. The ROAD-Weighted Prop. Energy score is
defined as:
.. math::
\\text{ROAD-WeightedPropEng} = \\max(0, \\text{AvgROAD}) \\cdot
\\text{ProportionalEnergy}
"""
retval: list[tuple[str, int, float, float, float]] = []
retval = []
for compl_info, interp_info in zip(completeness, interpretability):
# ensure matching sample name and label
assert compl_info[0] == interp_info[0]
assert compl_info[1] == interp_info[1]
if len(compl_info) == len(interp_info) == 2:
# there is no data.
continue
aopc_combined = compl_info[5]
prop_energy = interp_info[4]
road_weighted_prop_energy = max(0, aopc_combined) * prop_energy
retval.append(
(
compl_info[0],
compl_info[1],
aopc_combined,
prop_energy,
road_weighted_prop_energy,
)
)
return retval
def _make_histogram(
    name: str,
    values: numpy.typing.NDArray,
    xlim: tuple[float, float] | None = None,
    title: None | str = None,
) -> matplotlib.figure.Figure:
    """Builds a histogram of values.

    Parameters
    ----------
    name
        Name of the variable to be histogrammed (will appear in the figure).
    values
        Values to be histogrammed.
    xlim
        A tuple representing the X-axis maximum and minimum to plot.  If not
        set, then use the bin boundaries.
    title
        A title to set on the histogram.

    Returns
    -------
    A matplotlib figure containing the histogram.
    """
    import matplotlib.axes

    from matplotlib import pyplot

    fig, ax = pyplot.subplots(1)
    # bugfix: ``Axes`` lives in :py:mod:`matplotlib.axes`, not
    # :py:mod:`matplotlib.figure`; ``typing.cast`` evaluates its first
    # argument at runtime, so the original relied on a private re-export
    ax = typing.cast(matplotlib.axes.Axes, ax)
    ax.set_xlabel(name)
    ax.set_ylabel("Frequency")

    if title is not None:
        ax.set_title(title)
    else:
        ax.set_title(f"{name} Frequency Histogram")

    n, bins, _ = ax.hist(values, bins="auto", density=True, alpha=0.7)

    # restrict the visible axis spines to the data (or requested) range
    if xlim is not None:
        ax.spines.bottom.set_bounds(*xlim)
    else:
        ax.spines.bottom.set_bounds(bins[0], bins[-1])
    ax.spines.left.set_bounds(0, n.max())
    ax.spines.right.set_visible(False)
    ax.spines.top.set_visible(False)

    ax.grid(linestyle="--", linewidth=1, color="gray", alpha=0.3)

    # draw median and quartiles as vertical reference lines
    quartile = numpy.percentile(values, [25, 50, 75])
    ax.axvline(
        quartile[0], color="green", linestyle="--", label="Q1", alpha=0.5
    )
    ax.axvline(quartile[1], color="red", label="median", alpha=0.5)
    ax.axvline(
        quartile[2], color="green", linestyle="--", label="Q3", alpha=0.5
    )

    return fig  # type: ignore
def summary_table(
    summary: dict[SaliencyMapAlgorithm, dict[str, typing.Any]], fmt: str
) -> str:
    """Tabulates various summaries into one table.

    Parameters
    ----------
    summary
        A dictionary mapping saliency algorithm names to the results of
        :py:func:`run`.
    fmt
        One of the formats supported by `python-tabulate
        <https://pypi.org/project/tabulate/>`_.

    Returns
    -------
    A string containing the tabulated information.
    """
    headers = [
        "Algorithm",
        "AOPC-Combined",
        "Prop. Energy",
        "ROAD-Normalised",
    ]

    # one row per algorithm: the median (50% quartile) of the two main
    # scores, plus the pre-computed normalised average
    rows = []
    for algorithm, results in summary.items():
        rows.append(
            [
                algorithm,
                results["aopc-combined"]["quartiles"][50],
                results["proportional-energy"]["quartiles"][50],
                results["road-normalised-proportional-energy-average"],
            ]
        )

    return tabulate.tabulate(rows, headers, tablefmt=fmt, floatfmt=".3f")
def _extract_statistics(
    algo: SaliencyMapAlgorithm,
    data: list[tuple[str, int, float, float, float]],
    name: str,
    index: int,
    dataset: str,
    xlim: tuple[float, float] | None = None,
) -> dict[str, typing.Any]:
    """Extracts all meaningful statistics from a reconciled statistics set.

    Parameters
    ----------
    algo
        The algorithm for saliency map estimation that is being analysed.
    data
        A list of tuples each containing a sample name, target, and values
        produced by completeness and interpretability analysis as returned by
        :py:func:`_reconcile_metrics`.
    name
        The name of the variable being analysed.
    index
        Which of the indexes on the tuples contained in ``data`` should be
        extracted.
    dataset
        The name of the dataset being analysed.
    xlim
        Limits for histogram plotting.

    Returns
    -------
    A dictionary containing the following elements:

    * ``values``: A list of values corresponding to the index on the data
    * ``mean``: The mean of the value list
    * ``stdev``: The standard deviation of the value list
    * ``quartiles``: The 25%, 50% (median), and 75% quartile of values
    * ``plot``: A histogram of values
    * ``decreasing_scores``: A list of sample names and scores sorted by
      decreasing value.
    """
    val = numpy.array([k[index] for k in data])
    return dict(
        values=val,
        mean=val.mean(),
        stdev=val.std(ddof=1),  # unbiased estimator
        quartiles={
            25: numpy.percentile(val, 25),  # type: ignore
            50: numpy.median(val),  # type: ignore
            75: numpy.percentile(val, 75),  # type: ignore
        },
        plot=_make_histogram(
            name,
            val,
            xlim=xlim,
            title=f"{name} Frequency Histogram ({algo} @ {dataset})",
        ),
        decreasing_scores=[
            (k[0], k[index])
            for k in sorted(data, key=lambda x: x[index], reverse=True)
        ],
    )
def run(
    saliency_map_algorithm: SaliencyMapAlgorithm,
    completeness: dict[str, list],
    interpretability: dict[str, list],
) -> dict[str, typing.Any]:
    """Evaluates a saliency map algorithm and produces summarized results.

    Parameters
    ----------
    saliency_map_algorithm
        The algorithm for saliency map estimation that is being analysed.
    completeness
        A dictionary mapping dataset names to tables with the sample name and
        completeness (among which Average ROAD) scores.
    interpretability
        A dictionary mapping dataset names to tables with the sample name and
        interpretability (among which Prop. Energy) scores.

    Returns
    -------
    A dictionary mapping dataset names to the most important statistical
    values for the main completeness (AOPC-Combined), interpretability
    (Prop. Energy), and a combination of both (ROAD-Weighted Prop. Energy)
    scores.
    """
    retval: dict = {}

    for dataset, compl_data in completeness.items():
        reconciled = _reconcile_metrics(compl_data, interpretability[dataset])

        d = {}
        d["aopc-combined"] = _extract_statistics(
            algo=saliency_map_algorithm,
            data=reconciled,
            name="AOPC-Combined",
            index=2,
            dataset=dataset,
        )
        d["proportional-energy"] = _extract_statistics(
            algo=saliency_map_algorithm,
            data=reconciled,
            name="Prop.Energy",
            index=3,
            dataset=dataset,
            xlim=(0, 1),
        )
        d["road-weighted-proportional-energy"] = _extract_statistics(
            algo=saliency_map_algorithm,
            data=reconciled,
            name="ROAD-weighted-Prop.Energy",
            index=4,
            dataset=dataset,
        )
        # bugfix: the original indexed ``retval`` (keyed by dataset names)
        # and a non-existing ``val`` key, both raising ``KeyError``.  The
        # statistics for the current dataset live in ``d`` under ``values``
        # (see :py:func:`_extract_statistics`).
        d["road-normalised-proportional-energy-average"] = sum(
            d["road-weighted-proportional-energy"]["values"]
        ) / sum(max(0, k) for k in d["aopc-combined"]["values"])

        retval[dataset] = d

    return retval
......@@ -11,7 +11,7 @@ from . import (
config,
database,
evaluate,
evaluatevis,
evaluate_saliencymaps,
experiment,
generate_saliencymaps,
predict,
......@@ -38,7 +38,7 @@ cli.add_command(database.database)
cli.add_command(evaluate.evaluate)
cli.add_command(saliency_completeness.saliency_completeness)
cli.add_command(saliency_interpretability.saliency_interpretability)
cli.add_command(evaluatevis.evaluatevis)
cli.add_command(evaluate_saliencymaps.evaluate_saliencymaps)
cli.add_command(experiment.experiment)
cli.add_command(generate_saliencymaps.generate_saliencymaps)
cli.add_command(predict.predict)
......
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import pathlib
import typing
import click
from clapper.click import ResourceOption, verbosity_option
from clapper.logging import setup
from ..models.typing import SaliencyMapAlgorithm
from .click import ConfigCommand
# avoids X11/graphical desktop requirement when creating plots
__import__("matplotlib").use("agg")
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
@click.command(
    entry_point_group="ptbench.config",
    cls=ConfigCommand,
    epilog="""Examples:

\b
1. Tabulates and generates plots for two saliency map algorithms:

   .. code:: sh

      ptbench evaluate-saliencymaps -vv -e gradcam path/to/gradcam-completeness.json path/to/gradcam-interpretability.json -e gradcam++ path/to/gradcam++-completeness.json path/to/gradcam++-interpretability.json
""",
)
@click.option(
    "--entry",
    "-e",
    required=True,
    multiple=True,
    # typo fixes: the referenced commands are ``saliency-completeness`` and
    # ``saliency-interpretability`` (see CLI registration in this package)
    help=f"ENTRY is a triplet containing the algorithm name, the path to the "
    f"scores issued from the completeness analysis (``ptbench "
    f"saliency-completeness``) and scores issued from the interpretability "
    f"analysis (``ptbench saliency-interpretability``), both in JSON format. "
    f"Paths to score files must exist before the program is called. Valid values "
    f"for saliency map algorithms are "
    f"{'|'.join(typing.get_args(SaliencyMapAlgorithm))}",
    type=(
        click.Choice(
            typing.get_args(SaliencyMapAlgorithm), case_sensitive=False
        ),
        click.Path(
            exists=True,
            file_okay=True,
            dir_okay=False,
            path_type=pathlib.Path,
        ),
        click.Path(
            exists=True,
            file_okay=True,
            dir_okay=False,
            path_type=pathlib.Path,
        ),
    ),
    cls=ResourceOption,
)
@click.option(
    "--output-folder",
    "-o",
    help="Path where to store the analysis result (created if does not exist)",
    required=False,
    default="results",
    type=click.Path(file_okay=False, dir_okay=True, path_type=pathlib.Path),
    cls=ResourceOption,
)
@verbosity_option(logger=logger, expose_value=False)
def evaluate_saliencymaps(
    entry,
    output_folder,
    **_,  # ignored
) -> None:
    """Calculates summary statistics for a saliency map algorithm."""
    import json

    from matplotlib.backends.backend_pdf import PdfPages

    from ..engine.saliency.evaluator import run, summary_table

    # maps algorithm name -> dataset name -> statistics dictionary
    summary = {
        algo: run(algo, json.load(complet.open()), json.load(interp.open()))
        for algo, complet, interp in entry
    }
    table = summary_table(summary, "rst")
    # bugfix: echo the formatted table, not the raw summary dictionary
    click.echo(table)

    if output_folder is not None:
        output_folder.mkdir(parents=True, exist_ok=True)

        table_path = output_folder / "summary.rst"
        logger.info(f"Saving summary table at `{table_path}`...")
        with table_path.open("w") as f:
            f.write(table)

        figure_path = output_folder / "plots.pdf"
        logger.info(f"Saving figures at `{figure_path}`...")
        with PdfPages(figure_path) as pdf:
            # bugfix: ``summary`` is keyed by algorithm, then by dataset
            # (see :py:func:`run`); the original indexed it directly by
            # metric name, which raises ``KeyError``
            for datasets in summary.values():
                for stats in datasets.values():
                    pdf.savefig(stats["aopc-combined"]["plot"])
                    pdf.savefig(stats["proportional-energy"]["plot"])
                    pdf.savefig(
                        stats["road-weighted-proportional-energy"]["plot"]
                    )
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment