Commit 2cbd042c authored by André Anjos

Merge branch 'issue-64' into 'main'

Refactor evaluation

Closes #64

See merge request biosignal/software/mednet!26
parents a6de3201 48f2fd28
@@ -2,5 +2,4 @@
# Project: extras
# Version: stable
# The remainder of this file is compressed using zlib.
\ No newline at end of file
@@ -4,3 +4,4 @@
# The remainder of this file is compressed using zlib.
torchvision.transforms py:module 1 https://pytorch.org/vision/stable/transforms.html -
optimizer_step py:method 1 api/lightning.pytorch.core.LightningModule.html#$ -
json.encoder.JSONEncoder py:class 1 https://docs.python.org/3/library/json.html#json.JSONEncoder -
@@ -94,6 +94,7 @@ montgomery-f9 = "mednet.config.data.montgomery.fold_9"
# shenzhen dataset (and cross-validation folds)
shenzhen = "mednet.config.data.shenzhen.default"
shenzhen-alltest = "mednet.config.data.shenzhen.alltest"
shenzhen-f0 = "mednet.config.data.shenzhen.fold_0"
shenzhen-f1 = "mednet.config.data.shenzhen.fold_1"
shenzhen-f2 = "mednet.config.data.shenzhen.fold_2"
This diff is collapsed.
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Default Shenzhen TB database split.
Database reference: [MONTGOMERY-SHENZHEN-2014]_
* Test samples: 100% of the database
See :py:class:`mednet.config.data.shenzhen.datamodule.DataModule` for
technical details.
"""
from mednet.config.data.shenzhen.datamodule import DataModule
datamodule = DataModule("alltest.json")
@@ -5,6 +5,7 @@
import contextlib
import itertools
import json
import logging
import typing
@@ -113,24 +114,25 @@ def maxf1_threshold(predictions: Iterable[BinaryPrediction]) -> float:
return maxf1_threshold
def _score_plot(
labels: numpy.typing.NDArray,
scores: numpy.typing.NDArray,
def score_plot(
histograms: dict[str, dict[str, numpy.typing.NDArray]],
title: str,
threshold: float,
threshold: float | None,
) -> matplotlib.figure.Figure:
"""Plot the normalized score distributions for all systems.
Parameters
----------
labels
True labels (negatives and positives) for each entry in ``scores``.
scores
Likelihoods provided by the classification model, for each sample.
histograms
A dictionary containing all histograms that should be inserted into the
plot. Each histogram should itself be set up as another dictionary
containing the keys ``hist`` and ``bin_edges`` as returned by
:py:func:`numpy.histogram`.
title
Title of the plot.
threshold
Shows where the threshold is in the figure.
Shows where the threshold is in the figure. If set to ``None``, the
threshold line is not drawn.
Returns
-------
@@ -138,43 +140,55 @@ def _score_plot(
A single (matplotlib) plot containing the score distribution, ready to
be saved to disk or displayed.
"""
from matplotlib.ticker import MaxNLocator
fig, ax = plt.subplots(1, 1)
assert isinstance(fig, matplotlib.figure.Figure)
ax = typing.cast(plt.Axes, ax) # gets editor to behave
# Here, we configure the "style" of our plot
ax.set_xlim([0, 1])
ax.set_xlim((0, 1))
ax.set_title(title)
ax.set_xlabel("Score")
ax.set_ylabel("Normalized count")
ax.grid(linestyle="--", linewidth=1, color="gray", alpha=0.2)
ax.set_ylabel("Count")
# Only show ticks on the left and bottom spines
ax.spines.right.set_visible(False)
ax.spines.top.set_visible(False)
ax.get_xaxis().tick_bottom()
ax.get_yaxis().tick_left()
ax.get_yaxis().set_major_locator(MaxNLocator(integer=True))
positives = scores[labels > 0.5]
negatives = scores[labels < 0.5]
ax.hist(positives, bins="auto", label="positives", density=True, alpha=0.7)
ax.hist(negatives, bins="auto", label="negatives", density=True, alpha=0.7)
# Adds threshold line (dotted red)
ax.axvline(
threshold, # type: ignore
color="red",
lw=2,
alpha=0.75,
ls="dotted",
label="threshold",
)
# Setup the grid
ax.grid(linestyle="--", linewidth=1, color="gray", alpha=0.2)
ax.get_xaxis().grid(False)
max_hist = 0
for name in histograms.keys():
hist = histograms[name]["hist"]
bin_edges = histograms[name]["bin_edges"]
width = 0.7 * (bin_edges[1] - bin_edges[0])
center = (bin_edges[:-1] + bin_edges[1:]) / 2
ax.bar(center, hist, align="center", width=width, label=name, alpha=0.7)
max_hist = max(max_hist, hist.max())
# Detach axes from the plot
ax.spines["left"].set_position(("data", -0.015))
ax.spines["bottom"].set_position(("data", -0.015 * max_hist))
if threshold is not None:
# Adds threshold line (dotted red)
ax.axvline(
threshold, # type: ignore
color="red",
lw=2,
alpha=0.75,
ls="dotted",
label="threshold",
)
# Adds a nice legend
ax.legend(
title="Max F1-scores",
fancybox=True,
framealpha=0.7,
)
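For illustration, here is a minimal sketch of how the refactored ``score_plot`` can be driven; the scores and labels arrays below are hypothetical, and each histogram follows the ``hist``/``bin_edges`` convention of :py:func:`numpy.histogram` described in the docstring:

import numpy

from mednet.engine.evaluator import score_plot

# hypothetical scores and binary labels for a single split
scores = numpy.array([0.1, 0.4, 0.8, 0.9])
labels = numpy.array([0, 1, 0, 1])

# one histogram per class, stored as {"hist": ..., "bin_edges": ...}
histograms = {
    name: dict(
        zip(
            ("hist", "bin_edges"),
            numpy.histogram(scores[labels == value], bins=10, range=(0, 1)),
        )
    )
    for name, value in (("positives", 1), ("negatives", 0))
}

fig = score_plot(histograms, "Score distribution (split: test)", threshold=0.5)
fig.savefig("scores.pdf")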
@@ -188,12 +202,9 @@ def _score_plot(
def run_binary(
name: str,
predictions: Iterable[BinaryPrediction],
binning: str | int,
threshold_a_priori: float | None = None,
) -> tuple[
dict[str, typing.Any],
dict[str, matplotlib.figure.Figure],
dict[str, typing.Any],
]:
) -> dict[str, typing.Any]:
"""Run inference and calculates measures for binary classification.
Parameters
@@ -202,6 +213,10 @@ def run_binary(
The name of the subset to load.
predictions
A list of predictions to consider for measurement.
binning
The binning algorithm to use for computing the bin widths and
distribution for histograms. Choose from algorithms supported by
:py:func:`numpy.histogram`.
threshold_a_priori
A threshold to use, evaluated *a priori*, if single values must be reported.
If this value is not provided, an *a posteriori* threshold is calculated
@@ -209,17 +224,13 @@ def run_binary(
Returns
-------
tuple[
dict[str, typing.Any],
dict[str, matplotlib.figure.Figure],
dict[str, typing.Any]]
dict[str, typing.Any]
A dictionary containing the following entries:
* summary: A dictionary containing the performance summary on the
specified threshold.
* figures: A dictionary of generated standalone figures.
* curves: A dictionary containing curves that can potentially be combined
with other prediction lists to make aggregate plots.
specified threshold, general performance curves (under the key
``curves``), and score histograms (under the key
``score-histograms``).
"""
y_scores = numpy.array([k[2] for k in predictions]) # likelihoods
@@ -253,40 +264,66 @@ def run_binary(
y_labels, y_predictions, pos_label=pos_label
),
average_precision_score=sklearn.metrics.average_precision_score(
y_labels, y_predictions, pos_label=pos_label
y_labels, y_scores, pos_label=pos_label
),
specificity=sklearn.metrics.recall_score(
y_labels, y_predictions, pos_label=neg_label
),
auc_score=sklearn.metrics.roc_auc_score(
y_labels,
y_predictions,
y_scores,
),
accuracy=sklearn.metrics.accuracy_score(y_labels, y_predictions),
)
# figures: score distributions
figures = dict(
scores=_score_plot(
y_labels,
y_scores,
f"Score distribution (split: {name})",
use_threshold,
# curves: ROC and precision recall
summary["curves"] = dict(
roc=dict(
zip(
("fpr", "tpr", "thresholds"),
sklearn.metrics.roc_curve(
y_labels, y_scores, pos_label=pos_label
),
)
),
precision_recall=dict(
zip(
("precision", "recall", "thresholds"),
sklearn.metrics.precision_recall_curve(
y_labels, y_scores, pos_label=pos_label
),
),
),
)
# curves: ROC and precision recall
curves = dict(
roc=sklearn.metrics.roc_curve(y_labels, y_scores, pos_label=pos_label),
precision_recall=sklearn.metrics.precision_recall_curve(
y_labels, y_scores, pos_label=pos_label
# score histograms
# what works: <integer>, doane*, scott, stone, rice*, sturges*, sqrt
# what does not work: auto, fd
summary["score-histograms"] = dict(
positives=dict(
zip(
("hist", "bin_edges"),
numpy.histogram(
y_scores[y_labels == pos_label], bins=binning, range=(0, 1)
),
)
),
negatives=dict(
zip(
("hist", "bin_edges"),
numpy.histogram(
y_scores[y_labels == neg_label],
bins=binning,
range=(0, 1),
),
)
),
)
return summary, figures, curves
return summary
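A short sketch of how the new single-dictionary return value of ``run_binary`` can be consumed; the prediction tuples are hypothetical, and the key names follow the code above:

from mednet.engine.evaluator import run_binary

# hypothetical (name, target, score) prediction tuples
predictions = [("s0", 0, 0.1), ("s1", 0, 0.7), ("s2", 1, 0.4), ("s3", 1, 0.9)]

summary = run_binary("test", predictions, binning=10, threshold_a_priori=0.5)

summary["f1_score"]                               # scalar figures of merit
summary["curves"]["roc"]["fpr"]                   # ROC curve arrays
summary["curves"]["precision_recall"]["recall"]   # precision-recall arrays
summary["score-histograms"]["positives"]["hist"]  # per-class score histograms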
def aggregate_summaries(
def tabulate_results(
data: typing.Mapping[str, typing.Mapping[str, typing.Any]], fmt: str
) -> str:
"""Tabulate summaries from multiple splits.
@@ -379,14 +416,16 @@ def aggregate_roc(
legend = []
for name, (fpr, tpr, _) in data.items():
for name, elements in data.items():
# plots roc curve
_auc = sklearn.metrics.auc(fpr, tpr)
_auc = sklearn.metrics.auc(elements["fpr"], elements["tpr"])
label = f"{name} (AUC={_auc:.2f})"
color = next(colorcycler)
style = next(linecycler)
(line,) = ax.plot(fpr, tpr, color=color, linestyle=style)
(line,) = ax.plot(
elements["fpr"], elements["tpr"], color=color, linestyle=style
)
legend.append((line, label))
if len(legend) > 1:
@@ -516,13 +555,20 @@ def aggregate_pr(
axes.set_title(title)
legend = []
for name, (prec, recall, _) in data.items():
_ap = credible.curves.average_metric([prec, recall])
for name, elements in data.items():
_ap = credible.curves.average_metric(
(elements["precision"], elements["recall"])
)
label = f"{name} (AP={_ap:.2f})"
color = next(colorcycler)
style = next(linecycler)
(line,) = axes.plot(recall, prec, color=color, linestyle=style)
(line,) = axes.plot(
elements["recall"],
elements["precision"],
color=color,
linestyle=style,
)
legend.append((line, label))
if len(legend) > 1:
@@ -535,3 +581,31 @@ def aggregate_pr(
)
return fig
class NumpyJSONEncoder(json.JSONEncoder):
"""Extends the standard JSON encoder to support Numpy arrays."""
def default(self, o: typing.Any) -> typing.Any:
"""If input object is a ndarray it will be converted into a list.
Parameters
----------
o
Input object to be JSON serialized.
Returns
-------
A serializable representation of object ``o``.
"""
if isinstance(o, numpy.ndarray):
try:
retval = o.tolist()
except TypeError:
pass
else:
return retval
# Let the base class default method raise the TypeError
return super().default(o)
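For example, the encoder lets the evaluation results, which now carry numpy arrays for curves and histograms, be serialized without any manual conversion:

import json

import numpy

data = {"roc": {"fpr": numpy.array([0.0, 0.5, 1.0])}, "auc": 0.75}

# numpy arrays are transparently converted into plain lists
encoded = json.dumps(data, indent=2, cls=NumpyJSONEncoder)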
@@ -70,16 +70,41 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
the highest F1-score on that set) and then applied to the subsequent
sets. This value is not used for multi-class classification tasks.""",
default=0.5,
show_default=False,
show_default=True,
required=True,
type=click.STRING,
cls=ResourceOption,
)
@click.option(
"--binning",
"-b",
help="""The binning algorithm to use for computing the bin widths and
distribution for histograms. Choose from algorithms supported by
:py:func:`numpy.histogram`, or a simple integer indicating the number of
bins to have in the interval ``[0, 1]``.""",
default="50",
show_default=True,
required=True,
type=click.STRING,
cls=ResourceOption,
)
@click.option(
"--plot/--no-plot",
"-P",
help="""If set, then also produces figures containing the plots of
performance curves and score histograms.""",
required=True,
show_default=True,
default=True,
cls=ResourceOption,
)
@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
def evaluate(
predictions: pathlib.Path,
output: pathlib.Path,
threshold: str | float,
binning: str,
plot: bool,
**_, # ignored
) -> None: # numpydoc ignore=PR01
"""Evaluate predictions (from a model) on a classification task."""
@@ -87,15 +112,15 @@ def evaluate(
import json
import typing
import matplotlib.figure
from matplotlib.backends.backend_pdf import PdfPages
from ..engine.evaluator import (
NumpyJSONEncoder,
aggregate_pr,
aggregate_roc,
aggregate_summaries,
run_binary,
score_plot,
tabulate_results,
)
from .utils import execution_metadata, save_json_with_backup
@@ -126,29 +151,30 @@ def evaluate(
or can not be converted to a float. Check your input."""
)
results: dict[
str,
tuple[
dict[str, typing.Any],
dict[str, matplotlib.figure.Figure],
dict[str, typing.Any],
],
] = dict()
results: dict[str, dict[str, typing.Any]] = dict()
for k, v in predict_data.items():
logger.info(f"Analyzing split `{k}`...")
results[k] = run_binary(
name=k,
predictions=v,
binning=int(binning) if binning.isnumeric() else binning,
threshold_a_priori=use_threshold,
)
data = {k: v[0] for k, v in results.items()}
# records full result analysis to a JSON file
logger.info(f"Saving evaluation results at `{output}`...")
with output.open("w") as f:
json.dump(data, f, indent=2)
json.dump(results, f, indent=2, cls=NumpyJSONEncoder)
# dump evaluation results in RST format to screen and file
table = aggregate_summaries(data, fmt="rst")
table_data = {}
for k, v in results.items():
table_data[k] = {
kk: vv
for kk, vv in v.items()
if kk not in ("curves", "score-histograms")
}
table = tabulate_results(table_data, fmt="rst")
click.echo(table)
table_path = output.with_suffix(".rst")
@@ -162,20 +188,23 @@ def evaluate(
figure_path = output.with_suffix(".pdf")
logger.info(f"Saving evaluation figures at `{figure_path}`...")
with PdfPages(figure_path) as pdf:
pr_curves = {k: v[2]["precision_recall"] for k, v in results.items()}
pr_fig = aggregate_pr(pr_curves)
pdf.savefig(pr_fig)
roc_curves = {k: v[2]["roc"] for k, v in results.items()}
roc_fig = aggregate_roc(roc_curves)
pdf.savefig(roc_fig)
# order ready-to-save figures by type instead of split
figures = {k: v[1] for k, v in results.items()}
keys = next(iter(figures.values())).keys()
figures_by_type = {k: [v[k] for v in figures.values()] for k in keys}
for group_figures in figures_by_type.values():
for f in group_figures:
pdf.savefig(f)
if plot:
with PdfPages(figure_path) as pdf:
pr_curves = {
k: v["curves"]["precision_recall"] for k, v in results.items()
}
pr_fig = aggregate_pr(pr_curves)
pdf.savefig(pr_fig)
roc_curves = {k: v["curves"]["roc"] for k, v in results.items()}
roc_fig = aggregate_roc(roc_curves)
pdf.savefig(roc_fig)
# score plots
for k, v in results.items():
score_fig = score_plot(
v["score-histograms"],
f"Score distribution (split: {k})",
v["threshold"],
)
pdf.savefig(score_fig)
@@ -26,3 +26,109 @@ def test_centered_maxf1():
assert maxf1 == 1.0
assert threshold == 0.4
def test_run_binary_1():
from mednet.engine.evaluator import run_binary
from mednet.models.typing import BinaryPrediction
predictions: list[BinaryPrediction] = [
# (name, target, predicted-value)
("s0", 0, 0.1),
("s2", 0, 0.8),
("s3", 1, 0.9),
("s3", 1, 0.4),
]
results = run_binary(
"test", predictions, binning=10, threshold_a_priori=0.5
)
assert results["num_samples"] == 4
assert numpy.isclose(results["threshold"], 0.5)
assert not results["threshold_a_posteriori"]
assert numpy.isclose(results["precision"], 1 / 2) # tp / (tp + fp)
assert numpy.isclose(results["recall"], 1 / 2) # tp / (tp + fn)
assert numpy.isclose(
results["f1_score"], 2 * (1 / 2 * 1 / 2) / (1 / 2 + 1 / 2)
) # 2 * (prec. * recall) / (prec. + recall)
assert numpy.isclose(
results["accuracy"], (1 + 1) / (1 + 1 + 1 + 1)
) # (tp + tn) / (tp + fn + tn + fp)
assert numpy.isclose(results["specificity"], 1 / 2) # tn / (tn + fp)
# threshold table:
# threshold | TNR | 1-TNR | TPR
# ----------+-------+-------+---------
# < 0.1 | 0 | 1 | 1
# 0.1 | 0.5 | 0.5 | 1
# 0.4 | 0.5 | 0.5 | 0.5
# 0.8 | 1 | 0 | 0.5
# 0.9 | 1 | 0 | 0
# > 0.9 | 1 | 0 | 0
assert numpy.isclose(results["auc_score"], 0.75)
# threshold table:
# threshold | Prec. | Recall
# ----------+---------+----------
# < 0.1 | 0.5 | 1
# 0.1 | 2/3 | 1
# 0.4 | 0.5 | 0.5
# 0.8 | 1 | 0.5
# 0.9 | 0 | 0
# > 0.9 | 0 | 0
assert numpy.isclose(results["average_precision_score"], 0.8333333)
def test_run_binary_2():
from mednet.engine.evaluator import run_binary
from mednet.models.typing import BinaryPrediction
predictions: list[BinaryPrediction] = [
# (name, target, predicted-value)
("s0", 0, 0.1),
("s2", 0, 0.8),
("s3", 1, 0.9),
("s3", 1, 0.4),
]
# a change in the threshold should not affect auc and average precision scores
results = run_binary(
"test", predictions, binning=10, threshold_a_priori=0.3
)
assert results["num_samples"] == 4
assert numpy.isclose(results["threshold"], 0.3)
assert not results["threshold_a_posteriori"]
assert numpy.isclose(results["precision"], 2 / 3) # tp / (tp + fp)
assert numpy.isclose(results["recall"], 2 / 2) # tp / (tp + fn)
assert numpy.isclose(
results["f1_score"], 2 * (2 / 3 * 2 / 2) / (2 / 3 + 2 / 2)
) # 2 * (prec. * recall) / (prec. + recall)
assert numpy.isclose(
results["accuracy"], (2 + 1) / (2 + 0 + 1 + 1)
) # (tp + tn) / (tp + fn + tn + fp)
assert numpy.isclose(results["specificity"], 1 / (1 + 1)) # tn / (tn + fp)
# threshold table:
# threshold | TNR | 1-TNR | TPR
# ----------+-------+-------+---------
# < 0.1 | 0 | 1 | 1
# 0.1 | 0.5 | 0.5 | 1
# 0.4 | 0.5 | 0.5 | 0.5
# 0.8 | 1 | 0 | 0.5
# 0.9 | 1 | 0 | 0
# > 0.9 | 1 | 0 | 0
assert numpy.isclose(results["auc_score"], 0.75)
# threshold table:
# threshold | Prec. | Recall
# ----------+---------+----------
# < 0.1 | 0.5 | 1
# 0.1 | 2/3 | 1
# 0.4 | 0.5 | 0.5
# 0.8 | 1 | 0.5
# 0.9 | 0 | 0
# > 0.9 | 0 | 0
assert numpy.isclose(results["average_precision_score"], 0.8333333)