Skip to content
Snippets Groups Projects
Commit 9486272d authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[classify.engine.saliency.interpretability] Implement filter for operating on...

[classify.engine.saliency.interpretability] Implement filter for operating on specific datasets only; Implement basic analysis for interpretability
parent 3bb42fd2
No related branches found
No related tags found
1 merge request!46Create common library
......@@ -58,6 +58,7 @@ Functions that operate on data.
mednet.classify.engine.saliency.generator
mednet.classify.engine.saliency.interpretability
mednet.classify.engine.saliency.viewer
mednet.classify.engine.saliency.utils
mednet.segment.engine.dumper
mednet.segment.engine.evaluator
mednet.segment.engine.predictor
......
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import typing
import matplotlib.figure
import numpy
import numpy.typing
import tabulate
def extract_statistics(
    data: list[tuple[str, int, float, float, float]],
    index: int,
) -> dict[str, typing.Any]:
    """Extract all meaningful statistics from a reconciled statistics set.

    Parameters
    ----------
    data
        A list of tuples each containing a sample name, target, and values
        produced by completeness and interpretability analysis as returned by
        :py:func:`_reconcile_metrics`.
    index
        The index of the tuple contained in ``data`` that should be extracted.

    Returns
    -------
    A dictionary containing the following elements:

    * ``values``: A list of values corresponding to the index on the data
    * ``mean``: The mean of the value list
    * ``stdev``: The standard deviation of the value list
    * ``quartiles``: The 25%, 50% (median), and 75% quartile of values
    * ``decreasing_scores``: A list of (sample name, value) tuples, sorted
      by decreasing value.
    """
    val = numpy.array([k[index] for k in data])
    return dict(
        values=val,
        mean=val.mean(),
        stdev=val.std(ddof=1),  # ddof=1: unbiased (sample) estimator
        quartiles={
            25: numpy.percentile(val, 25),  # type: ignore
            50: numpy.median(val),  # type: ignore
            75: numpy.percentile(val, 75),  # type: ignore
        },
        decreasing_scores=[
            (k[0], k[index]) for k in sorted(data, key=lambda x: x[index], reverse=True)
        ],
    )
def make_table(
    results: dict[str, list[typing.Any]],
    indexes: dict[int, str],
    format_: str,
) -> str:
    """Summarize interpretability results obtained by running :py:func:`run`.

    Parameters
    ----------
    results
        The results to be summarized.
    indexes
        A dictionary where keys are indexes in each sample of ``results``, and
        values are a (possibly abbreviated) name to be used in table headers.
    format_
        The table format.

    Returns
    -------
    A table, formatted following ``format_`` and containing the
    various quartile informations for each split and metric.
    """
    # One column group (mean/std/quartiles) per named metric.
    headers = ["subset", "samples"]
    for metric_name in indexes.values():
        headers.extend(
            f"{metric_name}[{stat}]" for stat in ("mean", "std", "25%", "50%", "75%")
        )

    table_rows: list[list[typing.Any]] = []
    for split_name, entries in results.items():
        # Length-2 entries carry no metric values -- presumably placeholders
        # for samples that could not be analysed; skip them.
        usable = [entry for entry in entries if len(entry) != 2]
        current_row: list[typing.Any] = [split_name, len(usable)]
        for position in indexes:
            summary = extract_statistics(usable, index=position)
            current_row.extend(
                (
                    summary["mean"],
                    summary["stdev"],
                    summary["quartiles"][25],
                    summary["quartiles"][50],
                    summary["quartiles"][75],
                )
            )
        table_rows.append(current_row)

    return tabulate.tabulate(table_rows, headers, tablefmt=format_, floatfmt=".3f")
def make_histogram(
    name: str,
    values: numpy.typing.NDArray,
    xlim: tuple[float, float] | None = None,
    title: None | str = None,
) -> matplotlib.figure.Figure:
    """Build an histogram of values.

    Parameters
    ----------
    name
        Name of the variable to be histogrammed (will appear in the figure).
    values
        Values to be histogrammed.
    xlim
        A tuple representing the X-axis maximum and minimum to plot. If not
        set, then use the bin boundaries.
    title
        A title to set on the histogram.

    Returns
    -------
    A matplotlib figure containing the histogram.
    """
    import matplotlib.axes
    from matplotlib import pyplot

    fig, ax = pyplot.subplots(1)
    # ``Axes`` is publicly defined in :py:mod:`matplotlib.axes`, not
    # :py:mod:`matplotlib.figure`.  ``typing.cast`` evaluates its first
    # argument at runtime, so the wrong module path could raise an
    # ``AttributeError`` on matplotlib versions that do not re-export it.
    ax = typing.cast(matplotlib.axes.Axes, ax)
    ax.set_xlabel(name)
    ax.set_ylabel("Frequency")
    if title is not None:
        ax.set_title(title)
    else:
        ax.set_title(f"{name} Frequency Histogram")
    n, bins, _ = ax.hist(values, bins="auto", density=True, alpha=0.7)
    # Restrict visible spines to the plotted data range (or the explicit
    # ``xlim``) for a cleaner, tufte-style look.
    if xlim is not None:
        ax.spines.bottom.set_bounds(*xlim)
    else:
        ax.spines.bottom.set_bounds(bins[0], bins[-1])
    ax.spines.left.set_bounds(0, n.max())
    ax.spines.right.set_visible(False)
    ax.spines.top.set_visible(False)
    ax.grid(linestyle="--", linewidth=1, color="gray", alpha=0.3)
    # draw median and quartiles as vertical reference lines
    quartile = numpy.percentile(values, [25, 50, 75])
    ax.axvline(
        quartile[0],
        color="green",
        linestyle="--",
        label="Q1",
        alpha=0.5,
    )
    ax.axvline(quartile[1], color="red", label="median", alpha=0.5)
    ax.axvline(
        quartile[2],
        color="green",
        linestyle="--",
        label="Q3",
        alpha=0.5,
    )
    return fig  # type: ignore
def make_plots(
    results: dict[str, list[typing.Any]],
    indexes: dict[int, str],
    xlim: tuple[float, float] | None = None,
) -> list[matplotlib.figure.Figure]:
    """Plot histograms for a particular variable, across all datasets.

    Parameters
    ----------
    results
        The results to be plotted.
    indexes
        A dictionary where keys are indexes in each sample of ``results``, and
        values are a (possibly abbreviated) name to be used in figure titles
        and axes.
    xlim
        Limits for histogram plotting.

    Returns
    -------
    Matplotlib figures containing histograms for each dataset within
    ``results`` and named variables in ``indexes``.
    """
    figures: list[matplotlib.figure.Figure] = []
    for split_name, entries in results.items():
        # Length-2 entries carry no metric values -- presumably placeholders
        # for samples that could not be analysed; skip them.
        usable = [entry for entry in entries if len(entry) != 2]
        for position, metric_name in indexes.items():
            metric_values = numpy.array([entry[position] for entry in usable])
            figures.append(
                make_histogram(
                    metric_name,
                    metric_values,
                    xlim=xlim,
                    title=f"{metric_name} Frequency Histogram (@ {split_name})",
                )
            )
    return figures
......@@ -51,7 +51,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
"--input-folder",
"-i",
help="""Path from where to load saliency maps. You can generate saliency
maps with ``mednet saliency generate``.""",
maps with ``mednet classify saliency generate``.""",
required=True,
type=click.Path(
exists=True,
......@@ -89,6 +89,23 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
show_default=True,
cls=ResourceOption,
)
@click.option(
"--only-dataset",
"-S",
help="""If set, will only run the command for the named dataset on the
provided datamodule, skipping any other dataset.""",
cls=ResourceOption,
)
@click.option(
"--plot/--no-plot",
"-P",
help="""If set, then also produces figures containing the plots of
performance curves and score histograms.""",
required=True,
show_default=True,
default=True,
cls=ResourceOption,
)
@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
def interpretability(
model,
......@@ -96,6 +113,8 @@ def interpretability(
input_folder,
target_label,
output_folder,
only_dataset,
plot: bool,
**_,
) -> None: # numpydoc ignore=PR01
"""Evaluate saliency map agreement with annotations (human
......@@ -118,17 +137,21 @@ def interpretability(
with annotations (based on [SCORECAM-2020]_). It estimates how much
activation lies within the ground truth boxes compared to the total sum
of the activations.
* Average Saliency Focus: estimates how much of the ground truth bounding
boxes area is covered by the activations. It is similar to the
proportional energy measure in the sense that it does not need explicit
thresholding.
"""
import matplotlib.backends.backend_pdf
from ....scripts.utils import save_json_metadata, save_json_with_backup
from ...engine.saliency.interpretability import run
from ...engine.saliency.utils import make_plots, make_table
datamodule.model_transforms = list(model.model_transforms)
datamodule.batch_size = 1
datamodule.model_transforms = model.transforms
datamodule.prepare_data()
datamodule.setup(stage="predict")
......@@ -142,9 +165,34 @@ def interpretability(
output_json=output_json,
input_folder=input_folder,
target_label=target_label,
only_dataset=only_dataset,
plot=plot,
)
results = run(input_folder, target_label, datamodule)
results = run(
input_folder=input_folder,
target_label=target_label,
datamodule=datamodule,
only_dataset=only_dataset,
)
logger.info(f"Saving output file to `{str(output_json)}`...")
save_json_with_backup(output_json, results)
table = make_table(results=results, indexes={2: "PE", 3: "ASF"}, format_="rst")
output_table = output_json.with_suffix(".rst")
logger.info(f"Saving output summary table to `{str(output_table)}`...")
with output_table.open("w") as f:
f.write(table)
click.echo(table)
# Plots histograms, if the user asked to do so.
if plot:
figure_path = output_json.with_suffix(".pdf")
logger.info(f"Saving plots to `{str(figure_path)}`...")
with matplotlib.backends.backend_pdf.PdfPages(figure_path) as pdf:
for fig in make_plots(
results=results,
indexes={2: "Proportional Energy", 3: "Average Saliency Focus"},
):
pdf.savefig(fig)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment