From 5ce011d246e536d6ee3afaa7147e00cab64c0244 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Thu, 19 Oct 2023 09:56:46 +0200
Subject: [PATCH] [scripts.evaluate_saliencymaps] Add script and functionality

---
 src/ptbench/engine/saliency/evaluator.py     | 307 +++++++++++++++++++
 src/ptbench/scripts/cli.py                   |   4 +-
 src/ptbench/scripts/evaluate_saliencymaps.py | 115 +++++++
 3 files changed, 424 insertions(+), 2 deletions(-)
 create mode 100644 src/ptbench/engine/saliency/evaluator.py
 create mode 100644 src/ptbench/scripts/evaluate_saliencymaps.py

diff --git a/src/ptbench/engine/saliency/evaluator.py b/src/ptbench/engine/saliency/evaluator.py
new file mode 100644
index 00000000..ddbe8013
--- /dev/null
+++ b/src/ptbench/engine/saliency/evaluator.py
@@ -0,0 +1,307 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import typing
+
+import matplotlib.axes
+import matplotlib.figure
+import numpy
+import numpy.typing
+import tabulate
+
+from ...models.typing import SaliencyMapAlgorithm
+
+
+def _reconcile_metrics(
+    completeness: list,
+    interpretability: list,
+) -> list[tuple[str, int, float, float, float]]:
+    """Summarizes samples into a new table containing most important scores.
+
+    It returns a list containing a table with completeness and road scorse per
+    sample, for the selected dataset.  Only samples for which a completness and
+    interpretability scores are availble are returned in the reconciled list.
+
+
+    Parameters
+    ----------
+    completeness
+        A list of tables, one per sample, containing the sample name, target
+        label, and completeness (ROAD) scores.
+    interpretability
+        A list of tables, one per sample, containing the sample name, target
+        label, and interpretability (Prop. Energy) scores.
+
+
+    Returns
+    -------
+        A list containing a table with the sample name, target label,
+        completeness score (Average ROAD across different ablation thresholds),
+        interpretability score (Proportional Energy), and the ROAD-Weighted
+        Proportional Energy score.  The ROAD-Weighted Prop. Energy score is
+        defined as:
+
+        .. math::
+
+           \\text{ROAD-WeightedPropEng} = \\max(0, \\text{AvgROAD}) \\cdot
+           \\text{ProportionalEnergy}
+    """
+    retval: list[tuple[str, int, float, float, float]] = []
+
+    for compl_info, interp_info in zip(completeness, interpretability):
+        # ensure matching sample name and label
+        assert compl_info[0] == interp_info[0]
+        assert compl_info[1] == interp_info[1]
+
+        if len(compl_info) == len(interp_info) == 2:
+            # only the sample name and label are present - no scores were
+            # recorded for this sample
+            continue
+
+        aopc_combined = compl_info[5]
+        prop_energy = interp_info[4]
+        road_weighted_prop_energy = max(0, aopc_combined) * prop_energy
+
+        retval.append(
+            (
+                compl_info[0],
+                compl_info[1],
+                aopc_combined,
+                prop_energy,
+                road_weighted_prop_energy,
+            )
+        )
+
+    return retval
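+
+# Example flow through ``_reconcile_metrics`` (hypothetical rows and scores):
+#
+#   completeness row:     ["s1.png", 1, 0.1, 0.2, 0.3, 0.4]
+#   interpretability row: ["s1.png", 1, 0.5, 0.6, 0.3]
+#   reconciled output:    ("s1.png", 1, 0.4, 0.3, 0.12)  # max(0, 0.4) * 0.3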
+
+
+def _make_histogram(
+    name: str,
+    values: numpy.typing.NDArray,
+    xlim: tuple[float, float] | None = None,
+    title: None | str = None,
+) -> matplotlib.figure.Figure:
+    """Builds an histogram of values.
+
+    Parameters
+    ----------
+    name
+        Name of the variable to be histogrammed (will appear in the figure).
+    values
+        Values to be histogrammed.
+    xlim
+        A tuple representing the X-axis minimum and maximum to plot.  If not
+        set, then use the bin boundaries.
+    title
+        A title to set on the histogram.
+
+
+    Returns
+    -------
+        A matplotlib figure containing the histogram.
+    """
+
+    from matplotlib import pyplot
+
+    fig, ax = pyplot.subplots(1)
+    ax = typing.cast(matplotlib.axes.Axes, ax)
+    ax.set_xlabel(name)
+    ax.set_ylabel("Frequency")
+
+    if title is not None:
+        ax.set_title(title)
+    else:
+        ax.set_title(f"{name} Frequency Histogram")
+
+    n, bins, _ = ax.hist(values, bins="auto", density=True, alpha=0.7)
+
+    if xlim is not None:
+        ax.spines.bottom.set_bounds(*xlim)
+    else:
+        ax.spines.bottom.set_bounds(bins[0], bins[-1])
+
+    ax.spines.left.set_bounds(0, n.max())
+    ax.spines.right.set_visible(False)
+    ax.spines.top.set_visible(False)
+
+    ax.grid(linestyle="--", linewidth=1, color="gray", alpha=0.3)
+
+    # draw median and quartiles
+    quartile = numpy.percentile(values, [25, 50, 75])
+    ax.axvline(
+        quartile[0], color="green", linestyle="--", label="Q1", alpha=0.5
+    )
+    ax.axvline(quartile[1], color="red", label="median", alpha=0.5)
+    ax.axvline(
+        quartile[2], color="green", linestyle="--", label="Q3", alpha=0.5
+    )
+
+    return fig  # type: ignore
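+
+# A minimal usage sketch for ``_make_histogram`` (hypothetical name and
+# values):
+#
+#   fig = _make_histogram("Prop.Energy", numpy.random.rand(100), xlim=(0, 1))
+#   fig.savefig("histogram.pdf")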
+
+
+def summary_table(
+    summary: dict[SaliencyMapAlgorithm, dict[str, typing.Any]], fmt: str
+) -> str:
+    """Tabulates various summaries into one table.
+
+    Parameters
+    ----------
+    summary
+        A dictionary mapping saliency algorithm names to the results of
+        :py:func:`run`.
+    fmt
+        One of the formats supported by `python-tabulate
+        <https://pypi.org/project/tabulate/>`_.
+
+
+    Returns
+    -------
+        A string containing the tabulated information.
+    """
+
+    headers = [
+        "Algorithm",
+        "Dataset",
+        "AOPC-Combined",
+        "Prop. Energy",
+        "ROAD-Normalised",
+    ]
+    # AOPC-Combined and Prop. Energy are summarized by their median (50%
+    # quartile); the ROAD-normalised average is already a per-dataset scalar
+    table = [
+        [
+            algo,
+            dataset,
+            stats["aopc-combined"]["quartiles"][50],
+            stats["proportional-energy"]["quartiles"][50],
+            stats["road-normalised-proportional-energy-average"],
+        ]
+        for algo, datasets in summary.items()
+        for dataset, stats in datasets.items()
+    ]
+    return tabulate.tabulate(table, headers, tablefmt=fmt, floatfmt=".3f")
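+
+# ``summary_table`` renders a reStructuredText grid, e.g. (illustrative
+# numbers only):
+#
+#   =========  ==========  ===============  ==============  ===============
+#   Algorithm  Dataset       AOPC-Combined    Prop. Energy  ROAD-Normalised
+#   =========  ==========  ===============  ==============  ===============
+#   gradcam    validation            0.123           0.456            0.789
+#   =========  ==========  ===============  ==============  ===============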
+
+
+def _extract_statistics(
+    algo: SaliencyMapAlgorithm,
+    data: list[tuple[str, int, float, float, float]],
+    name: str,
+    index: int,
+    dataset: str,
+    xlim: tuple[float, float] | None = None,
+) -> dict[str, typing.Any]:
+    """Extracts all meaningful statistics from a reconciled statistics set.
+
+    Parameters
+    ----------
+    algo
+        The algorithm for saliency map estimation that is being analysed.
+    data
+        A list of tuples each containing a sample name, target, and values
+        produced by completeness and interpretability analysis as returned by
+        :py:func:`_reconcile_metrics`.
+    name
+        The name of the variable being analysed.
+    index
+        Which of the indexes in the tuples contained in ``data`` should be
+        extracted.
+    dataset
+        The name of the dataset being analysed.
+    xlim
+        Limits for histogram plotting
+
+
+    Returns
+    -------
+        A dictionary containing the following elements:
+
+        * ``values``: A list of values corresponding to the index on the data
+        * ``mean``: The mean of the value list
+        * ``stdev``: The standard deviation of the value list
+        * ``quartiles``: The 25%, 50% (median), and 75% quartiles of values
+        * ``plot``: A histogram of values
+        * ``decreasing_scores``: A list of (sample name, score) pairs, sorted
+          by decreasing score value.
+    """
+
+    val = numpy.array([k[index] for k in data])
+    return dict(
+        values=val,
+        mean=val.mean(),
+        stdev=val.std(ddof=1),  # unbiased estimator
+        quartiles={
+            25: numpy.percentile(val, 25),  # type: ignore
+            50: numpy.median(val),  # type: ignore
+            75: numpy.percentile(val, 75),  # type: ignore
+        },
+        plot=_make_histogram(
+            name,
+            val,
+            xlim=xlim,
+            title=f"{name} Frequency Histogram ({algo} @ {dataset})",
+        ),
+        decreasing_scores=[
+            (k[0], k[index])
+            for k in sorted(data, key=lambda x: x[index], reverse=True)
+        ],
+    )
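+
+# Usage sketch (hypothetical data): extract statistics for the AOPC-Combined
+# score, stored at index 2 of each reconciled row:
+#
+#   stats = _extract_statistics(
+#       algo="gradcam",
+#       data=reconciled,
+#       name="AOPC-Combined",
+#       index=2,
+#       dataset="validation",
+#   )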
+
+
+def run(
+    saliency_map_algorithm: SaliencyMapAlgorithm,
+    completeness: dict[str, list],
+    interpretability: dict[str, list],
+) -> dict[str, typing.Any]:
+    """Evaluates multiple saliency map algorithms and produces summarized
+    results.
+
+    Parameters
+    ----------
+    saliency_map_algorithm
+        The algorithm for saliency map estimation that is being analysed.
+    completeness
+        A dictionary mapping dataset names to tables with the sample name and
+        completeness (among which Average ROAD) scores.
+    interpretability
+        A dictionary mapping dataset names to tables with the sample name and
+        interpretability (among which Prop. Energy) scores.
+
+
+    Returns
+    -------
+        A dictionary mapping dataset names to the most important statistical
+        values for the main completeness (AOPC-Combined), interpretability
+        (Prop. Energy), and combined (ROAD-Weighted Prop. Energy) scores.
+    """
+
+    retval: dict = {}
+
+    for dataset, compl_data in completeness.items():
+        reconciled = _reconcile_metrics(compl_data, interpretability[dataset])
+        d = {}
+
+        d["aopc-combined"] = _extract_statistics(
+            algo=saliency_map_algorithm,
+            data=reconciled,
+            name="AOPC-Combined",
+            index=2,
+            dataset=dataset,
+        )
+        d["proportional-energy"] = _extract_statistics(
+            algo=saliency_map_algorithm,
+            data=reconciled,
+            name="Prop.Energy",
+            index=3,
+            dataset=dataset,
+            xlim=(0, 1),
+        )
+        d["road-weighted-proportional-energy"] = _extract_statistics(
+            algo=saliency_map_algorithm,
+            data=reconciled,
+            name="ROAD-weighted-Prop.Energy",
+            index=4,
+            dataset=dataset,
+        )
+
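+        # dataset-level aggregate: the sum of ROAD-weighted Prop. Energy
+        # scores, normalised by the sum of (clamped) AvgROAD scores, yields a
+        # single ROAD-normalised Prop. Energy average for the whole dataset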
+        d["road-normalised-proportional-energy-average"] = sum(
+            d["road-weighted-proportional-energy"]["values"]
+        ) / sum([max(0, k) for k in d["aopc-combined"]["values"]])
+
+        retval[dataset] = d
+
+    return retval
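+
+
+# Usage sketch (hypothetical paths): evaluate a single algorithm from score
+# files produced by ``ptbench saliency-completeness`` and ``ptbench
+# saliency-interpretability``:
+#
+#   import json
+#   results = run(
+#       "gradcam",
+#       json.load(open("gradcam-completeness.json")),
+#       json.load(open("gradcam-interpretability.json")),
+#   )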
diff --git a/src/ptbench/scripts/cli.py b/src/ptbench/scripts/cli.py
index ff3611be..0311bef5 100644
--- a/src/ptbench/scripts/cli.py
+++ b/src/ptbench/scripts/cli.py
@@ -11,7 +11,7 @@ from . import (
     config,
     database,
     evaluate,
-    evaluatevis,
+    evaluate_saliencymaps,
     experiment,
     generate_saliencymaps,
     predict,
@@ -38,7 +38,7 @@ cli.add_command(database.database)
 cli.add_command(evaluate.evaluate)
 cli.add_command(saliency_completeness.saliency_completeness)
 cli.add_command(saliency_interpretability.saliency_interpretability)
-cli.add_command(evaluatevis.evaluatevis)
+cli.add_command(evaluate_saliencymaps.evaluate_saliencymaps)
 cli.add_command(experiment.experiment)
 cli.add_command(generate_saliencymaps.generate_saliencymaps)
 cli.add_command(predict.predict)
diff --git a/src/ptbench/scripts/evaluate_saliencymaps.py b/src/ptbench/scripts/evaluate_saliencymaps.py
new file mode 100644
index 00000000..5b96b37a
--- /dev/null
+++ b/src/ptbench/scripts/evaluate_saliencymaps.py
@@ -0,0 +1,115 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import pathlib
+import typing
+
+import click
+
+from clapper.click import ResourceOption, verbosity_option
+from clapper.logging import setup
+
+from ..models.typing import SaliencyMapAlgorithm
+from .click import ConfigCommand
+
+# avoids X11/graphical desktop requirement when creating plots
+__import__("matplotlib").use("agg")
+
+logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+
+
+@click.command(
+    entry_point_group="ptbench.config",
+    cls=ConfigCommand,
+    epilog="""Examples:
+
+\b
+    1. Tabulates and generates plots for two saliency map algorithms:
+
+       .. code:: sh
+
+          ptbench evaluate-saliencymaps -vv -e gradcam path/to/gradcam-completeness.json path/to/gradcam-interpretability.json -e gradcam++ path/to/gradcam++-completeness.json path/to/gradcam++-interpretability.json
+""",
+)
+@click.option(
+    "--entry",
+    "-e",
+    required=True,
+    multiple=True,
+    help="ENTRY is a triplet containing the algorithm name, the path to the "
+    "scores issued from the completeness analysis (``ptbench "
+    "saliency-completeness``) and the scores issued from the interpretability "
+    "analysis (``ptbench saliency-interpretability``), both in JSON format. "
+    "Paths to score files must exist before the program is called. Valid "
+    "values for saliency map algorithms are "
+    f"{'|'.join(typing.get_args(SaliencyMapAlgorithm))}",
+    type=(
+        click.Choice(
+            typing.get_args(SaliencyMapAlgorithm), case_sensitive=False
+        ),
+        click.Path(
+            exists=True,
+            file_okay=True,
+            dir_okay=False,
+            path_type=pathlib.Path,
+        ),
+        click.Path(
+            exists=True,
+            file_okay=True,
+            dir_okay=False,
+            path_type=pathlib.Path,
+        ),
+    ),
+    cls=ResourceOption,
+)
+@click.option(
+    "--output-folder",
+    "-o",
+    help="Path where to store the analysis results (created if it does not exist)",
+    required=False,
+    default="results",
+    type=click.Path(file_okay=False, dir_okay=True, path_type=pathlib.Path),
+    cls=ResourceOption,
+)
+@verbosity_option(logger=logger, expose_value=False)
+def evaluate_saliencymaps(
+    entry,
+    output_folder,
+    **_,  # ignored
+) -> None:
+    """Calculates summary statistics for a saliency map algorithm."""
+    import json
+
+    from matplotlib.backends.backend_pdf import PdfPages
+
+    from ..engine.saliency.evaluator import run, summary_table
+
+    summary = {
+        algo: run(algo, json.load(complet.open()), json.load(interp.open()))
+        for algo, complet, interp in entry
+    }
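+    # ``summary`` maps each algorithm name to the output of ``run()``, i.e. a
+    # dictionary from dataset name to the statistics computed on it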
+    table = summary_table(summary, "rst")
+    click.echo(table)
+
+    if output_folder is not None:
+        output_folder.mkdir(parents=True, exist_ok=True)
+
+        table_path = output_folder / "summary.rst"
+
+        logger.info(f"Saving summary table at `{table_path}`...")
+        with table_path.open("w") as f:
+            f.write(table)
+
+        figure_path = output_folder / "plots.pdf"
+        logger.info(f"Saving figures at `{figure_path}`...")
+
+        with PdfPages(figure_path) as pdf:
+            for algo, datasets in summary.items():
+                for dataset, stats in datasets.items():
+                    pdf.savefig(stats["aopc-combined"]["plot"])
+                    pdf.savefig(stats["proportional-energy"]["plot"])
+                    pdf.savefig(
+                        stats["road-weighted-proportional-energy"]["plot"]
+                    )
-- 
GitLab