Commit 7f97368f authored by André Anjos

[classify.scripts.saliency] Incorporate evaluation code directly on interpretability/completeness analysis; Remove outdated saliency-evaluation script and engine; Fix test units
parent 08e9f659
Merge request !46: Create common library
Pipeline #89607 failed
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import typing
import matplotlib.axes
import matplotlib.figure
import numpy
import numpy.typing
import tabulate
from ...models.typing import SaliencyMapAlgorithm
def _reconcile_metrics(
completeness: list,
interpretability: list,
) -> list[tuple[str, int, float, float, float]]:
r"""Summarize samples into a new table containing the most important scores.
It returns a list containing a table with completeness and ROAD scores per
sample, for the selected dataset. Only samples for which a completness and
interpretability scores are availble are returned in the reconciled list.
Parameters
----------
    completeness
        A list containing various tables with the sample name and
        completeness (ROAD) scores.
    interpretability
        A list containing various tables with the sample name and
        interpretability (Prop. Energy) scores.
Returns
-------
list[tuple[str, int, float, float, float]]
A list containing a table with the sample name, target label,
completeness score (Average ROAD across different ablation thresholds),
interpretability score (Proportional Energy), and the ROAD-Weighted
Proportional Energy score. The ROAD-Weighted Prop. Energy score is
defined as:
        .. math::

           \text{ROAD-WeightedPropEng} = \max(0, \text{AvgROAD}) \cdot
           \text{ProportionalEnergy}
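
        For instance (hypothetical values), a sample with an average ROAD
        score of 0.42 and a proportional energy of 0.77 would score
        :math:`\max(0, 0.42) \cdot 0.77 = 0.3234`.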
"""
    retval: list[tuple[str, int, float, float, float]] = []
for compl_info, interp_info in zip(completeness, interpretability):
# ensure matching sample name and label
assert compl_info[0] == interp_info[0]
assert compl_info[1] == interp_info[1]
if len(compl_info) == len(interp_info) == 2:
# there is no data.
continue
aopc_combined = compl_info[5]
prop_energy = interp_info[2]
road_weighted_prop_energy = max(0, aopc_combined) * prop_energy
retval.append(
(
compl_info[0],
compl_info[1],
aopc_combined,
prop_energy,
road_weighted_prop_energy,
),
)
return retval
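

# A minimal sketch (hypothetical rows) of what ``_reconcile_metrics`` does:
# completeness rows carry the average ROAD score at index 5, interpretability
# rows carry the proportional energy at index 2.
#
#   completeness = [["sample-1.png", 1, 0.1, 0.2, 0.3, 0.42]]
#   interpretability = [["sample-1.png", 1, 0.77]]
#   _reconcile_metrics(completeness, interpretability)
#   # -> [("sample-1.png", 1, 0.42, 0.77, 0.3234)]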
def _make_histogram(
name: str,
values: numpy.typing.NDArray,
xlim: tuple[float, float] | None = None,
title: None | str = None,
) -> matplotlib.figure.Figure:
"""Build an histogram of values.
Parameters
----------
name
Name of the variable to be histogrammed (will appear in the figure).
values
Values to be histogrammed.
    xlim
        A tuple representing the X-axis minimum and maximum to plot. If not
        set, then use the bin boundaries.
title
A title to set on the histogram.
Returns
-------
matplotlib.figure.Figure
A matplotlib figure containing the histogram.
"""
from matplotlib import pyplot
fig, ax = pyplot.subplots(1)
    ax = typing.cast(matplotlib.axes.Axes, ax)
ax.set_xlabel(name)
ax.set_ylabel("Frequency")
if title is not None:
ax.set_title(title)
else:
ax.set_title(f"{name} Frequency Histogram")
n, bins, _ = ax.hist(values, bins="auto", density=True, alpha=0.7)
if xlim is not None:
ax.spines.bottom.set_bounds(*xlim)
else:
ax.spines.bottom.set_bounds(bins[0], bins[-1])
ax.spines.left.set_bounds(0, n.max())
ax.spines.right.set_visible(False)
ax.spines.top.set_visible(False)
ax.grid(linestyle="--", linewidth=1, color="gray", alpha=0.3)
# draw median and quartiles
quartile = numpy.percentile(values, [25, 50, 75])
ax.axvline(
quartile[0],
color="green",
linestyle="--",
label="Q1",
alpha=0.5,
)
ax.axvline(quartile[1], color="red", label="median", alpha=0.5)
ax.axvline(
quartile[2],
color="green",
linestyle="--",
label="Q3",
alpha=0.5,
)
return fig # type: ignore
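

# Usage sketch (assumed, illustrative names): build and save a histogram for
# 200 synthetic proportional-energy scores.
#
#   scores = numpy.random.default_rng(42).normal(0.5, 0.1, size=200)
#   fig = _make_histogram("Prop.Energy", scores, xlim=(0, 1))
#   fig.savefig("prop-energy-histogram.pdf")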
def summary_table(
summary: dict[SaliencyMapAlgorithm, dict[str, typing.Any]],
fmt: str,
) -> str:
"""Tabulate various summaries into one table.
Parameters
----------
summary
A dictionary mapping saliency algorithm names to the results of
:py:func:`run`.
fmt
One of the formats supported by `python-tabulate
<https://pypi.org/project/tabulate/>`_.
Returns
-------
str
A string containing the tabulated information.
"""
    headers = [
        "Algorithm (Dataset)",
        "AOPC-Combined",
        "Prop. Energy",
        "ROAD-Normalised",
    ]
    table = [
        # one row per (algorithm, dataset) pair; ``v`` maps dataset names to
        # the per-dataset statistics returned by :py:func:`run`
        [
            f"{k} ({dataset})",
            v[dataset]["aopc-combined"]["quartiles"][50],
            v[dataset]["proportional-energy"]["quartiles"][50],
            v[dataset]["road-normalised-proportional-energy-average"],
        ]
        for k, v in summary.items()
        for dataset in v
    ]
return tabulate.tabulate(table, headers, tablefmt=fmt, floatfmt=".3f")
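

# With ``fmt="rst"``, the table above renders roughly like this (values are
# illustrative only):
#
#   ====================  =============  ============  ===============
#   Algorithm (Dataset)   AOPC-Combined  Prop. Energy  ROAD-Normalised
#   ====================  =============  ============  ===============
#   gradcam (validation)          0.123         0.456            0.789
#   ====================  =============  ============  ===============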
def _extract_statistics(
algo: SaliencyMapAlgorithm,
data: list[tuple[str, int, float, float, float]],
name: str,
index: int,
dataset: str,
xlim: tuple[float, float] | None = None,
) -> dict[str, typing.Any]:
"""Extract all meaningful statistics from a reconciled statistics set.
Parameters
----------
algo
The algorithm for saliency map estimation that is being analysed.
data
A list of tuples each containing a sample name, target, and values
produced by completeness and interpretability analysis as returned by
:py:func:`_reconcile_metrics`.
name
The name of the variable being analysed.
index
The index of the tuple contained in ``data`` that should be extracted.
dataset
The name of the dataset being analysed.
xlim
Limits for histogram plotting.
Returns
-------
dict[str, typing.Any]
A dictionary containing the following elements:
        * ``values``: A list of values corresponding to the index on the data
        * ``mean``: The mean of the value list
        * ``stdev``: The standard deviation of the value list
        * ``quartiles``: The 25%, 50% (median), and 75% quartiles of values
        * ``plot``: A histogram of values
        * ``decreasing_scores``: A list of sample names and scores, sorted by
          decreasing value.
"""
val = numpy.array([k[index] for k in data])
return dict(
values=val,
mean=val.mean(),
stdev=val.std(ddof=1), # unbiased estimator
quartiles={
25: numpy.percentile(val, 25), # type: ignore
50: numpy.median(val), # type: ignore
75: numpy.percentile(val, 75), # type: ignore
},
plot=_make_histogram(
name,
val,
xlim=xlim,
title=f"{name} Frequency Histogram ({algo} @ {dataset})",
),
decreasing_scores=[
(k[0], k[index]) for k in sorted(data, key=lambda x: x[index], reverse=True)
],
)
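

# Call sketch (hypothetical reconciled rows; ``index=3`` selects the
# proportional-energy column of each tuple):
#
#   stats = _extract_statistics(
#       algo="gradcam",
#       data=[("a.png", 1, 0.4, 0.7, 0.28), ("b.png", 0, 0.2, 0.5, 0.10)],
#       name="Prop.Energy",
#       index=3,
#       dataset="validation",
#       xlim=(0, 1),
#   )
#   stats["quartiles"][50]         # -> 0.6
#   stats["decreasing_scores"][0]  # -> ("a.png", 0.7)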
def run(
saliency_map_algorithm: SaliencyMapAlgorithm,
completeness: dict[str, list],
interpretability: dict[str, list],
) -> dict[str, typing.Any]:
"""Evaluate multiple saliency map algorithms and produces summarized
results.
Parameters
----------
saliency_map_algorithm
The algorithm for saliency map estimation that is being analysed.
    completeness
        A dictionary mapping dataset names to tables with the sample name and
        completeness (among which Average ROAD) scores.
interpretability
A dictionary mapping dataset names to tables with the sample name and
interpretability (among which Prop. Energy) scores.
Returns
-------
    dict[str, typing.Any]
        A dictionary mapping each dataset name to the most important
        statistical values for the main completeness (AOPC-Combined),
        interpretability (Prop. Energy), and combined (ROAD-Weighted Prop.
        Energy) scores.
"""
retval: dict = {}
for dataset, compl_data in completeness.items():
reconciled = _reconcile_metrics(compl_data, interpretability[dataset])
d = {}
d["aopc-combined"] = _extract_statistics(
algo=saliency_map_algorithm,
data=reconciled,
name="AOPC-Combined",
index=2,
dataset=dataset,
)
d["proportional-energy"] = _extract_statistics(
algo=saliency_map_algorithm,
data=reconciled,
name="Prop.Energy",
index=3,
dataset=dataset,
xlim=(0, 1),
)
d["road-weighted-proportional-energy"] = _extract_statistics(
algo=saliency_map_algorithm,
data=reconciled,
name="ROAD-weighted-Prop.Energy",
index=4,
dataset=dataset,
)
d["road-normalised-proportional-energy-average"] = sum(
retval["road-weighted-proportional-energy"]["val"],
) / sum([max(0, k) for k in retval["aopc-combined"]["val"]])
retval[dataset] = d
return retval
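

# End-to-end sketch (hypothetical JSON-loaded content): both inputs map
# dataset names to per-sample score rows, as produced by the completeness
# and interpretability analyses.
#
#   completeness = {"validation": [["a.png", 1, 0.1, 0.2, 0.3, 0.42]]}
#   interpretability = {"validation": [["a.png", 1, 0.77]]}
#   results = run("gradcam", completeness, interpretability)
#   results["validation"]["aopc-combined"]["mean"]  # average AvgROAD score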
@@ -40,5 +40,4 @@ classify.add_command(saliency)
 _add_command(saliency, ".saliency.generate", "generate")
 _add_command(saliency, ".saliency.completeness", "completeness")
 _add_command(saliency, ".saliency.interpretability", "interpretability")
-_add_command(saliency, ".saliency.evaluate", "evaluate")
 _add_command(saliency, ".saliency.view", "view")
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import pathlib
import typing
import click
from clapper.click import ResourceOption, verbosity_option
from clapper.logging import setup
from ....scripts.click import ConfigCommand
from ...models.typing import SaliencyMapAlgorithm
# avoids X11/graphical desktop requirement when creating plots
__import__("matplotlib").use("agg")
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
@click.command(
entry_point_group="mednet.config",
cls=ConfigCommand,
epilog="""Examples:
1. Tabulate and generate plots for two saliency map algorithms:
.. code:: sh
mednet classify saliency evaluate -vv -e gradcam path/to/gradcam-completeness.json path/to/gradcam-interpretability.json -e gradcam++ path/to/gradcam++-completeness.json path/to/gradcam++-interpretability.json
""",
)
@click.option(
"--entry",
"-e",
required=True,
multiple=True,
help=f"ENTRY is a triplet containing the algorithm name, the path to the "
f"scores issued from the completness analysis (``mednet "
f"saliency completness``) and scores issued from the interpretability "
f"analysis (``mednet saliency interpretability``), both in JSON format. "
f"Paths to score files must exist before the program is called. Valid values "
f"for saliency map algorithms are "
f"{'|'.join(typing.get_args(SaliencyMapAlgorithm))}",
type=(
click.Choice(
typing.get_args(SaliencyMapAlgorithm),
case_sensitive=False,
),
click.Path(
exists=True,
file_okay=True,
dir_okay=False,
path_type=pathlib.Path,
),
click.Path(
exists=True,
file_okay=True,
dir_okay=False,
path_type=pathlib.Path,
),
),
cls=ResourceOption,
)
@click.option(
"--output-folder",
"-o",
help="Directory in which to store the analysis result (created if does not exist)",
required=False,
default="results",
type=click.Path(file_okay=False, dir_okay=True, path_type=pathlib.Path),
cls=ResourceOption,
)
@verbosity_option(logger=logger, expose_value=False)
def evaluate(
entry,
output_folder,
**_, # ignored
) -> None: # numpydoc ignore=PR01
"""Calculate summary statistics for a saliency map algorithm."""
import json
from matplotlib.backends.backend_pdf import PdfPages
from ...engine.saliency.evaluator import run, summary_table
summary = {
algo: run(algo, json.load(complet.open()), json.load(interp.open()))
for algo, complet, interp in entry
}
table = summary_table(summary, "rst")
    click.echo(table)
if output_folder is not None:
output_folder.mkdir(parents=True, exist_ok=True)
table_path = output_folder / "summary.rst"
logger.info(f"Saving summary table at `{table_path}`...")
with table_path.open("w") as f:
f.write(table)
figure_path = output_folder / "plots.pdf"
logger.info(f"Saving figures at `{figure_path}`...")
with PdfPages(figure_path) as pdf:
            # one set of plots per algorithm and dataset
            for algo_results in summary.values():
                for dataset_results in algo_results.values():
                    pdf.savefig(dataset_results["aopc-combined"]["plot"])
                    pdf.savefig(dataset_results["proportional-energy"]["plot"])
                    pdf.savefig(
                        dataset_results["road-weighted-proportional-energy"]["plot"],
                    )
@@ -105,7 +105,7 @@ def view(
     # register metadata
     save_json_metadata(
-        output_file=output_folder / "saliency-view.meta.json",
+        output_file=output_folder / "view.meta.json",
         datamodule=datamodule,
         model=model,
         input_folder=input_folder,
@@ -95,12 +95,6 @@ def test_saliency_view_help():
     _check_help(view)


-def test_saliency_evaluate_help():
-    from mednet.classify.scripts.saliency.evaluate import evaluate
-
-    _check_help(evaluate)


 @pytest.mark.slow
 @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
 def test_train_pasa_montgomery(session_tmp_path):
@@ -266,8 +260,10 @@ def test_saliency_generation_pasa_montgomery(session_tmp_path):
     runner = CliRunner()

     with stdout_logging() as buf:
-        output_folder = session_tmp_path / "classification-standalone"
-        last = _get_checkpoint_from_alias(output_folder, "periodic")
+        saliency_algo = "gradcam"
+        input_folder = session_tmp_path / "classification-standalone"
+        last = _get_checkpoint_from_alias(input_folder, "periodic")
+        output_folder = input_folder / saliency_algo
         assert last.name.endswith("epoch=0" + CHECKPOINT_EXTENSION)
         result = runner.invoke(
             generate,
@@ -275,13 +271,14 @@ def test_saliency_generation_pasa_montgomery(session_tmp_path):
                 "-vv",
                 "pasa",
                 "montgomery",
+                f"--saliency-map-algorithm={saliency_algo}",
                 f"--weight={str(last)}",
                 f"--output-folder={str(output_folder)}",
             ],
         )
         _assert_exit_0(result)

-        assert (output_folder / "saliency-generation.meta.json").exists()
+        assert (output_folder / "generation.meta.json").exists()

         keywords = {
             r"^Writing run metadata at .*$": 1,
@@ -305,30 +302,25 @@ def test_saliency_generation_pasa_montgomery(session_tmp_path):
 @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
 def test_saliency_view_pasa_montgomery(session_tmp_path):
     from mednet.classify.scripts.saliency.view import view
-    from mednet.utils.checkpointer import (
-        CHECKPOINT_EXTENSION,
-        _get_checkpoint_from_alias,
-    )

     runner = CliRunner()

     with stdout_logging() as buf:
-        output_folder = session_tmp_path / "classification-standalone"
-        last = _get_checkpoint_from_alias(output_folder, "periodic")
-        assert last.name.endswith("epoch=0" + CHECKPOINT_EXTENSION)
+        input_folder = session_tmp_path / "classification-standalone" / "gradcam"
+        output_folder = input_folder / "view"
         result = runner.invoke(
             view,
             [
                 "-vv",
                 "pasa",
                 "montgomery",
-                f"--input-folder={str(output_folder)}",
+                f"--input-folder={str(input_folder)}",
                 f"--output-folder={str(output_folder)}",
             ],
         )
         _assert_exit_0(result)

-        assert (output_folder / "saliency-view.meta.json").exists()
+        assert (output_folder / "view.meta.json").exists()

         keywords = {
             r"^Writing run metadata at .*$": 1,