Commit a35a5a9a authored by André Anjos

[scripts.evaluate] Complete refactor

parent 48990aea
1 merge request: !6 Making use of LightningDataModule and simplification of data loading
Showing 569 additions and 1765 deletions
@@ -3,367 +3,516 @@
# SPDX-License-Identifier: GPL-3.0-or-later
"""Defines functionality for the evaluation of predictions."""
import contextlib
import itertools
import logging
import os
import re
import typing
from collections.abc import Iterable
from typing import Optional
from collections.abc import Iterable, Iterator
import matplotlib.pyplot as plt
import matplotlib.figure
import numpy
import pandas as pd
import torch
import numpy.typing
import sklearn.metrics
import tabulate
from sklearn import metrics
from matplotlib import pyplot as plt
from ..utils.measure import base_measures, get_centered_maxf1
from ..models.typing import BinaryPrediction
logger = logging.getLogger(__name__)
def eer_threshold(neg: Iterable[float], pos: Iterable[float]) -> float:
"""Evaluates the EER threshold from negative and positive scores.
def eer_threshold(predictions: Iterable[BinaryPrediction]) -> float:
"""Calculates the (approximate) threshold leading to the equal error rate.
Parameters
----------
predictions
An iterable of multiple
:py:data:`ptbench.models.typing.BinaryPrediction`'s.
neg :
Negative scores
pos :
Positive scores
Returns:
Returns
-------
The EER threshold value.
"""
from scipy.interpolate import interp1d
from scipy.optimize import brentq
y_predictions = pd.concat((neg, pos))
y_true = numpy.concatenate((numpy.zeros_like(neg), numpy.ones_like(pos)))
y_scores = [k[2] for k in predictions]
y_labels = [k[1] for k in predictions]
fpr, tpr, thresholds = metrics.roc_curve(y_true, y_predictions, pos_label=1)
fpr, tpr, thresholds = sklearn.metrics.roc_curve(y_labels, y_scores)
eer = brentq(lambda x: 1.0 - x - interp1d(fpr, tpr)(x), 0.0, 1.0)
return float(interp1d(fpr, thresholds)(eer))
def posneg(
pred, gt, threshold
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Calculates true and false positives and negatives.
def _get_centered_maxf1(
f1_scores: numpy.typing.NDArray, thresholds: numpy.typing.NDArray
):
"""Return the centered max F1 score threshold when multiple thresholds give
the same max F1 score.
Parameters
----------
f1_scores
1D array of f1 scores
thresholds
1D array of thresholds
pred :
Pixel-wise predictions.
Returns
-------
A tuple with the maximum F1-score and the "centered" threshold.
"""
maxf1 = f1_scores.max()
maxf1_indices = numpy.where(f1_scores == maxf1)[0]
gt :
Ground-truth (annotations).
# If multiple thresholds give the same max F1 score
if len(maxf1_indices) > 1:
mean_maxf1_index = int(round(numpy.mean(maxf1_indices)))
else:
mean_maxf1_index = maxf1_indices[0]
return maxf1, thresholds[mean_maxf1_index]
def maxf1_threshold(predictions: Iterable[BinaryPrediction]) -> float:
"""Calculates the threshold leading to the maximum F1-score on a precision-
recall curve.
Parameters
----------
predictions
An iterable of multiple
:py:data:`ptbench.models.typing.BinaryPrediction`'s.
threshold :
A particular threshold in which to calculate the performance
measures.
Returns
-------
The threshold value leading to the maximum F1-score on the provided set
of predictions.
"""
y_scores = [k[2] for k in predictions]
y_labels = [k[1] for k in predictions]
tp_tensor:
The true positive values.
precision, recall, thresholds = sklearn.metrics.precision_recall_curve(
y_labels, y_scores
)
fp_tensor:
The false positive values.
numerator = 2 * recall * precision
denom = recall + precision
f1_scores = numpy.divide(
numerator, denom, out=numpy.zeros_like(denom), where=(denom != 0)
)
tn_tensor:
The true negative values.
_, maxf1_threshold = _get_centered_maxf1(f1_scores, thresholds)
return maxf1_threshold
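# Usage sketch (hypothetical data; not part of the original module), assuming
# this module is importable as ``ptbench.engine.evaluator`` as in the CLI code
# further below.  Predictions are (name, label, score) tuples:
example_predictions = [
    ("s0", 0, 0.10),
    ("s1", 0, 0.40),
    ("s2", 1, 0.35),
    ("s3", 1, 0.80),
]
print(f"max-F1 threshold: {maxf1_threshold(example_predictions):.3f}")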
fn_tensor:
The false negative values.
"""
# threshold
binary_pred = torch.gt(pred, threshold)
def _score_plot(
labels: numpy.typing.NDArray,
scores: numpy.typing.NDArray,
title: str,
threshold: float,
) -> matplotlib.figure.Figure:
"""Plots the normalized score distributions for all systems.
# equals and not-equals
equals = torch.eq(binary_pred, gt).type(torch.uint8)
notequals = torch.ne(binary_pred, gt).type(torch.uint8)
Parameters
----------
labels
True labels (negatives and positives) for each entry in ``scores``
scores
Likelihoods provided by the classification model, for each sample
title
Title of the plot.
threshold
The threshold value, drawn as a vertical line on the figure.
# true positives
tp_tensor = (gt * binary_pred).type(torch.uint8)
# false positives
fp_tensor = torch.eq((binary_pred + tp_tensor), 1).type(torch.uint8)
Returns
-------
A single (matplotlib) plot containing the score distribution, ready to
be saved to disk or displayed.
"""
fig, ax = plt.subplots(1, 1)
assert isinstance(fig, matplotlib.figure.Figure)
ax = typing.cast(plt.Axes, ax) # gets editor to behave
# Here, we configure the "style" of our plot
ax.set_xlim([0, 1])
ax.set_title(title)
ax.set_xlabel("Score")
ax.set_ylabel("Normalized count")
ax.grid(linestyle="--", linewidth=1, color="gray", alpha=0.2)
# Only show ticks on the left and bottom spines
ax.spines.right.set_visible(False)
ax.spines.top.set_visible(False)
ax.get_xaxis().tick_bottom()
ax.get_yaxis().tick_left()
positives = scores[labels > 0.5]
negatives = scores[labels < 0.5]
ax.hist(positives, bins="auto", label="positives", density=True, alpha=0.7)
ax.hist(negatives, bins="auto", label="negatives", density=True, alpha=0.7)
# Adds threshold line (dotted red)
ax.axvline(
threshold, # type: ignore
color="red",
lw=2,
alpha=0.75,
ls="dotted",
label="threshold",
)
# true negatives
tn_tensor = (equals - tp_tensor).type(torch.uint8)
# Adds a nice legend
ax.legend(
title="Max F1-scores",
fancybox=True,
framealpha=0.7,
)
# false negatives
fn_tensor = notequals - fp_tensor.type(torch.uint8)
# Makes sure the figure occupies most of the possible space
fig.tight_layout()
return tp_tensor, fp_tensor, tn_tensor, fn_tensor
return fig
def sample_measures_for_threshold(
pred: torch.Tensor, gt: torch.Tensor, threshold: float
) -> tuple[float, float, float, float, float]:
"""Calculates measures on one single sample, for a specific threshold.
def run_binary(
name: str,
predictions: Iterable[BinaryPrediction],
threshold_a_priori: float | None = None,
) -> tuple[
dict[str, typing.Any],
dict[str, matplotlib.figure.Figure],
dict[str, typing.Any],
]:
"""Runs inference and calculates measures for binary classification.
Parameters
----------
name
The name of the split being analyzed (used in the summary and plot title).
predictions
A list of predictions to consider for measurement
threshold_a_priori
A threshold to use, evaluated *a priori*, if single-valued measures must
be reported. If this value is not provided, an *a posteriori* threshold
is calculated from the input scores. Note this is a biased estimator.
pred :
Pixel-wise predictions.
gt :
Ground-truth (annotations).
threshold :
A particular threshold in which to calculate the performance
measures.
Returns
-------
A tuple containing the following entries:
precision : float
P, AKA positive predictive value (PPV). It corresponds arithmetically
to ``tp/(tp+fp)``. In the case ``tp+fp == 0``, this function returns
zero for precision.
recall : float
R, AKA sensitivity, hit rate, or true positive rate (TPR). It
corresponds arithmetically to ``tp/(tp+fn)``. In the special case
where ``tp+fn == 0``, this function returns zero for recall.
specificity : float
S, AKA selectivity or true negative rate (TNR). It
corresponds arithmetically to ``tn/(tn+fp)``. In the special case
where ``tn+fp == 0``, this function returns zero for specificity.
accuracy : float
A, see `Accuracy
<https://en.wikipedia.org/wiki/Evaluation_of_binary_classifiers>`_. is
the proportion of correct predictions (both true positives and true
negatives) among the total number of pixels examined. It corresponds
arithmetically to ``(tp+tn)/(tp+tn+fp+fn)``. This measure includes
both true-negatives and positives in the numerator, which makes it
sensitive to data or regions without annotations.
jaccard : float
J, see `Jaccard Index or Similarity
<https://en.wikipedia.org/wiki/Jaccard_index>`_. It corresponds
arithmetically to ``tp/(tp+fp+fn)``. In the special case where
``tn+fp+fn == 0``, this function returns zero for the Jaccard index.
The Jaccard index depends on a TP-only numerator, similarly to the F1
score. For regions where there are no annotations, the Jaccard index
will always be zero, irrespective of the model output. Accuracy may be
a better proxy if one needs to consider the true absence of
annotations in a region as part of the measure.
f1_score : float
F1, see `F1-score <https://en.wikipedia.org/wiki/F1_score>`_. It
corresponds arithmetically to ``2*P*R/(P+R)`` or ``2*tp/(2*tp+fp+fn)``.
In the special case where ``P+R == (2*tp+fp+fn) == 0``, this function
returns zero for the F1-score. The F1 or Dice score depends on a
TP-only numerator, similarly to the Jaccard index. For regions where
there are no annotations, the F1-score will always be zero,
irrespective of the model output. Accuracy may be a better proxy if
one needs to consider the true absence of annotations in a region as
part of the measure.
* summary: A dictionary containing the performance summary on the
specified threshold
* figures: A dictionary of generated standalone figures
* curves: A dictionary containing curves that can potentially be combined
with other prediction lists to make aggregate plots.
"""
tp_tensor, fp_tensor, tn_tensor, fn_tensor = posneg(pred, gt, threshold)
y_scores = numpy.array([k[2] for k in predictions]) # likelihoods
y_labels = numpy.array([k[1] for k in predictions]) # integers
neg_label = y_labels.min()
pos_label = y_labels.max()
use_threshold = threshold_a_priori
if use_threshold is None:
use_threshold = maxf1_threshold(predictions)
logger.warning(
f"User did not pass an *a priori* threshold for the evaluation "
f"of split `{name}`. Using threshold a posteriori (biased) with value "
f"`{use_threshold:.4f}`"
)
# calc measures from scalars
tp_count = torch.sum(tp_tensor).item()
fp_count = torch.sum(fp_tensor).item()
tn_count = torch.sum(tn_tensor).item()
fn_count = torch.sum(fn_tensor).item()
return base_measures(tp_count, fp_count, tn_count, fn_count)
y_predictions = numpy.where(y_scores >= use_threshold, pos_label, neg_label)
# point measures on threshold
summary = dict(
split=name,
threshold=use_threshold,
threshold_a_posteriori=(threshold_a_priori is None),
precision=sklearn.metrics.precision_score(
y_labels, y_predictions, pos_label=pos_label
),
recall=sklearn.metrics.recall_score(
y_labels, y_predictions, pos_label=pos_label
),
specificity=sklearn.metrics.recall_score(
y_labels, y_predictions, pos_label=neg_label
),
accuracy=sklearn.metrics.accuracy_score(y_labels, y_predictions),
f1_score=sklearn.metrics.f1_score(
y_labels, y_predictions, pos_label=pos_label
),
)
def run(
name: str,
predictions_folder: str,
f1_thresh: Optional[float] = None,
eer_thresh: Optional[float] = None,
steps: Optional[int] = 1000,
):
"""Runs inference and calculates measures.
# figures: score distributions
figures = dict(
scores=_score_plot(
y_labels,
y_scores,
f"Score distribution (split: {name})",
use_threshold,
),
)
Parameters
---------
# curves: ROC and precision recall
curves = dict(
roc=sklearn.metrics.roc_curve(y_labels, y_scores, pos_label=pos_label),
precision_recall=sklearn.metrics.precision_recall_curve(
y_labels, y_scores, pos_label=pos_label
),
)
name:
The name of subset to load.
return summary, figures, curves
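# Usage sketch (hypothetical data; not part of the original module):
# run_binary() returns a summary dict, standalone figures and raw ROC/PR
# curves for one split.
example_split = [("s0", 0, 0.2), ("s1", 0, 0.6), ("s2", 1, 0.4), ("s3", 1, 0.9)]
summary, figures, curves = run_binary(
    "test", example_split, threshold_a_priori=0.5
)
print(summary["f1_score"])  # 0.5 for this toy split
figures["scores"].savefig("scores-test.pdf")  # hypothetical output path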
predictions_folder:
Folder where predictions for the dataset images have been previously
stored.
f1_thresh:
This number should come from
the training set or a separate validation set. Using a test set value
may bias your analysis. This number is also used to print the a priori
F1-score on the evaluated set.
def aggregate_summaries(
data: typing.Sequence[typing.Mapping[str, typing.Any]], fmt: str
) -> str:
"""Tabulates summaries from multiple splits.
eer_thresh:
This number should come from
the training set or a separate validation set. Using a test set value
may bias your analysis. This number is used to print the a priori
EER.
This function can properly :py:mod:`tabulate` the various summaries
produced for all the splits in a prediction database.
steps:
Number of threshold steps to consider when evaluating thresholds.
Parameters
----------
data
An iterable over all summary data collected
fmt
One of the formats supported by :py:mod:`tabulate`.
Returns
-------
A string containing the tabulated information
"""
headers = list(data[0].keys())
table = [[k[h] for h in headers] for k in data]
return tabulate.tabulate(table, headers, tablefmt=fmt, floatfmt=".3f")
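# Usage sketch (hypothetical values): tabulate two split summaries side by side
# in reStructuredText format.
example_rows = [
    {"split": "validation", "threshold": 0.37, "f1_score": 0.91},
    {"split": "test", "threshold": 0.37, "f1_score": 0.88},
]
print(aggregate_summaries(example_rows, fmt="rst"))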
pred_data:
The loaded predictions for the specified subset.
def aggregate_roc(
data: typing.Mapping[str, typing.Any],
title: str = "ROC",
) -> matplotlib.figure.Figure:
"""Aggregates ROC curves from multiple splits.
fig_scores:
Figure of the histogram distributions of true-positive/true-negative scores.
This function produces a single ROC plot for multiple curves generated per
split.
maxf1_threshold:
Threshold to achieve the highest possible F1-score for this dataset.
post_eer_threshold:
Threshold achieving Equal Error Rate for this dataset.
Parameters
----------
data
A dictionary mapping split names to ROC curve data produced by
:py:func:`sklearn.metrics.roc_curve`.
Returns
-------
A figure, containing the aggregated ROC plot.
"""
predictions_path = os.path.join(predictions_folder, f"{name}.csv")
if not os.path.exists(predictions_path):
predictions_path = predictions_folder
# Load predictions
pred_data = pd.read_csv(predictions_path)
pred = torch.Tensor(
[
eval(re.sub(" +", " ", x.replace("\n", "")).replace(" ", ","))
for x in pred_data["likelihood"].values
]
).double()
gt = torch.Tensor(
[
eval(re.sub(" +", " ", x.replace("\n", "")).replace(" ", ","))
for x in pred_data["ground_truth"].values
]
).double()
if pred.shape[1] == 1 and gt.shape[1] == 1:
pred = torch.flatten(pred)
gt = torch.flatten(gt)
pred_data["likelihood"] = pred
pred_data["ground_truth"] = gt
# Multiclass f1 score computation
if pred.ndim > 1:
auc = metrics.roc_auc_score(gt, pred)
logger.info("Evaluating multiclass classification")
logger.info(f"AUC: {auc}")
logger.info("F1 and EER are not implemented for multiclass")
return None, None
# Generate measures for each threshold
step_size = 1.0 / steps
data = [
(index, threshold) + sample_measures_for_threshold(pred, gt, threshold)
for index, threshold in enumerate(numpy.arange(0.0, 1.0, step_size))
fig, ax = plt.subplots(1, 1)
assert isinstance(fig, matplotlib.figure.Figure)
# Names and bounds
ax.set_xlabel("1 - specificity")
ax.set_ylabel("Sensitivity")
ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.0])
ax.set_title(title)
# we should only see some of the axis spines
ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
ax.spines["left"].set_position(("data", -0.015))
ax.spines["bottom"].set_position(("data", -0.015))
ax.grid(linestyle="--", linewidth=1, color="gray", alpha=0.2)
plt.tight_layout()
lines = ["-", "--", "-.", ":"]
colors = [
"#1f77b4",
"#ff7f0e",
"#2ca02c",
"#d62728",
"#9467bd",
"#8c564b",
"#e377c2",
"#7f7f7f",
"#bcbd22",
"#17becf",
]
colorcycler = itertools.cycle(colors)
linecycler = itertools.cycle(lines)
legend = []
for name, (fpr, tpr, _) in data.items():
# plots roc curve
_auc = sklearn.metrics.auc(fpr, tpr)
label = f"{name} (AUC={_auc:.2f})"
color = next(colorcycler)
style = next(linecycler)
(line,) = ax.plot(fpr, tpr, color=color, linestyle=style)
legend.append((line, label))
if len(legend) > 1:
ax.legend(
[k[0] for k in legend],
[k[1] for k in legend],
loc="lower right",
fancybox=True,
framealpha=0.7,
)
data_df = pd.DataFrame(
data,
columns=(
"index",
"threshold",
"precision",
"recall",
"specificity",
"accuracy",
"jaccard",
"f1_score",
),
)
data_df = data_df.set_index("index")
"""# Save evaluation csv
if output_folder is not None:
fullpath = os.path.join(output_folder, f"{name}.csv")
logger.info(f"Saving {fullpath}...")
os.makedirs(os.path.dirname(fullpath), exist_ok=True)
data_df.to_csv(fullpath)"""
# Find max F1 score
f1_scores = numpy.asarray(data_df["f1_score"])
thresholds = numpy.asarray(data_df["threshold"])
maxf1, maxf1_threshold = get_centered_maxf1(f1_scores, thresholds)
logger.info(
f"Maximum F1-score of {maxf1:.5f}, achieved at "
f"threshold {maxf1_threshold:.3f} (chosen *a posteriori*)"
)
return fig
# Find EER
neg_gt = pred_data.loc[pred_data.loc[:, "ground_truth"] == 0, :]
pos_gt = pred_data.loc[pred_data.loc[:, "ground_truth"] == 1, :]
post_eer_threshold = eer_threshold(
neg_gt["likelihood"], pos_gt["likelihood"]
)
logger.info(
f"Equal error rate achieved at "
f"threshold {post_eer_threshold:.3f} (chosen *a posteriori*)"
)
@contextlib.contextmanager
def _precision_recall_canvas() -> (
Iterator[tuple[matplotlib.figure.Figure, matplotlib.figure.Axes]]
):
"""Generates a canvas to draw precision-recall curves.
# Generate scores fig
fig_score, axes = plt.subplots(1)
fig_score.tight_layout(pad=3.0)
Works like a context manager, yielding a figure and an axes set to which
the precision-recall curves should be added. The figure already
contains F1-ISO lines and is preset to a 0-1 square region. Once the
context is finished, ``fig.tight_layout()`` is called.
# Names and bounds
axes.set_xlabel("Score")
axes.set_ylabel("Normalized counts")
axes.set_xlim(0.0, 1.0)
neg_weights = numpy.ones_like(neg_gt["likelihood"]) / len(
pred_data["likelihood"]
)
pos_weights = numpy.ones_like(pos_gt["likelihood"]) / len(
pred_data["likelihood"]
)
Yields
------
figure
The figure that should be finally returned to the user
axes
An axis set to which the precision-recall plots should be added
"""
axes.hist(
[neg_gt["likelihood"], pos_gt["likelihood"]],
weights=[neg_weights, pos_weights],
bins=100,
color=["tab:blue", "tab:orange"],
label=["Negatives", "Positives"],
)
axes.legend(prop={"size": 10}, loc="upper center")
axes.set_title(f"Score table for {name} subset")
fig, axes1 = plt.subplots(1)
assert isinstance(fig, matplotlib.figure.Figure)
assert isinstance(axes1, matplotlib.figure.Axes)
# Names and bounds
axes1.set_xlabel("Recall")
axes1.set_ylabel("Precision")
axes1.set_xlim([0.0, 1.0])
axes1.set_ylim([0.0, 1.0])
axes1.grid(linestyle="--", linewidth=1, color="gray", alpha=0.2)
axes2 = axes1.twinx()
# Annotates plot with F1-score iso-lines
f_scores = numpy.linspace(0.1, 0.9, num=9)
tick_locs = []
tick_labels = []
for f_score in f_scores:
x = numpy.linspace(0.01, 1)
y = f_score * x / (2 * x - f_score)
plt.plot(x[y >= 0], y[y >= 0], color="green", alpha=0.1)
tick_locs.append(y[-1])
tick_labels.append("%.1f" % f_score)
axes2.tick_params(axis="y", which="both", pad=0, right=False, left=False)
axes2.set_ylabel("iso-F", color="green", alpha=0.3)
axes2.set_ylim([0.0, 1.0])
axes2.yaxis.set_label_coords(1.015, 0.97)
axes2.set_yticks(tick_locs) # notice these are invisible
for k in axes2.set_yticklabels(tick_labels):
k.set_color("green")
k.set_alpha(0.3)
k.set_size(8)
# we should see some of axes1's spines
axes.spines["right"].set_visible(False)
axes.spines["top"].set_visible(False)
axes.spines["left"].set_position(("data", -0.015))
"""if f1_thresh is not None and eer_thresh is not None:
# get the closest possible threshold we have
index = int(round(steps * f1_thresh))
f1_a_priori = data_df["f1_score"][index]
actual_threshold = data_df["threshold"][index]
logger.info(
f"F1-score of {f1_a_priori:.5f}, at threshold "
f"{actual_threshold:.3f} (chosen *a priori*)"
)
axes1.spines["right"].set_visible(False)
axes1.spines["top"].set_visible(False)
axes1.spines["left"].set_position(("data", -0.015))
axes1.spines["bottom"].set_position(("data", -0.015))
# we shouldn't see any of axes2's spines
axes2.spines["right"].set_visible(False)
axes2.spines["top"].set_visible(False)
axes2.spines["left"].set_visible(False)
axes2.spines["bottom"].set_visible(False)
# yield execution, letting the user draw precision-recall plots and the legend
# before tightening the layout
yield fig, axes1
plt.tight_layout()
# Print the a priori EER threshold
logger.info(f"Equal error rate (chosen *a priori*) {eer_thresh:.3f}")"""
def aggregate_pr(
data: typing.Mapping[str, typing.Any],
title: str = "Precision-Recall Curve",
) -> matplotlib.figure.Figure:
"""Aggregates PR curves from multiple splits.
return pred_data, fig_score, maxf1_threshold, post_eer_threshold
This function produces a single Precision-Recall plot for multiple curves
generated per split. The plot will be annotated with F1-score iso-lines
(lines along which the F1-score remains constant).
Parameters
----------
data
A dictionary mapping split names to precision-recall curve data produced by
:py:func:`sklearn.metrics.precision_recall_curve`.
Returns
-------
A figure, containing the aggregated PR plot.
"""
lines = ["-", "--", "-.", ":"]
colors = [
"#1f77b4",
"#ff7f0e",
"#2ca02c",
"#d62728",
"#9467bd",
"#8c564b",
"#e377c2",
"#7f7f7f",
"#bcbd22",
"#17becf",
]
colorcycler = itertools.cycle(colors)
linecycler = itertools.cycle(lines)
with _precision_recall_canvas() as (fig, axes):
axes.set_title(title)
legend = []
for name, (prec, recall, _) in data.items():
_auc = sklearn.metrics.auc(recall, prec)
label = f"{name} (AUC={_auc:.2f})"
color = next(colorcycler)
style = next(linecycler)
(line,) = axes.plot(recall, prec, color=color, linestyle=style)
legend.append((line, label))
if len(legend) > 1:
axes.legend(
[k[0] for k in legend],
[k[1] for k in legend],
loc="lower left",
fancybox=True,
framealpha=0.7,
)
return fig
@@ -7,7 +7,12 @@ import logging
import lightning.pytorch
import torch.utils.data
from ..models.typing import Prediction
from ..models.typing import (
BinaryPrediction,
BinaryPredictionSplit,
MultiClassPrediction,
MultiClassPredictionSplit,
)
from .device import DeviceManager
logger = logging.getLogger(__name__)
@@ -18,9 +23,12 @@ def run(
datamodule: lightning.pytorch.LightningDataModule,
device_manager: DeviceManager,
) -> (
list[Prediction]
| list[list[Prediction]]
| dict[str, list[Prediction]]
list[BinaryPrediction]
| list[MultiClassPrediction]
| list[list[BinaryPrediction]]
| list[list[MultiClassPrediction]]
| BinaryPredictionSplit
| MultiClassPredictionSplit
| None
):
"""Runs inference on input data, outputs csv files with predictions.
@@ -205,7 +205,7 @@ class Alexnet(pl.LightningModule):
return self._validation_loss(outputs, labels.float())
def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
def predict_step(self, batch, batch_idx, dataloader_idx=0):
outputs = self(batch[0])
probabilities = torch.sigmoid(outputs)
return separate((probabilities, batch[1]))
@@ -199,7 +199,7 @@ class Densenet(pl.LightningModule):
return self._validation_loss(outputs, labels.float())
def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
def predict_step(self, batch, batch_idx, dataloader_idx=0):
outputs = self(batch[0])
probabilities = torch.sigmoid(outputs)
return separate((probabilities, batch[1]))
@@ -107,7 +107,7 @@ class LogisticRegression(pl.LightningModule):
else:
return {f"extra_validation_loss_{dataloader_idx}": validation_loss}
def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
def predict_step(self, batch, batch_idx, dataloader_idx=0):
outputs = self(batch[0])
probabilities = torch.sigmoid(outputs)
return separate((probabilities, batch[1]))
@@ -112,7 +112,7 @@ class MultiLayerPerceptron(pl.LightningModule):
else:
return {f"extra_validation_loss_{dataloader_idx}": validation_loss}
def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
def predict_step(self, batch, batch_idx, dataloader_idx=0):
outputs = self(batch[0])
probabilities = torch.sigmoid(outputs)
return separate((probabilities, batch[1]))
@@ -265,7 +265,7 @@ class Pasa(pl.LightningModule):
return self._validation_loss(outputs, labels.float())
def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
def predict_step(self, batch, batch_idx, dataloader_idx=0):
outputs = self(batch[0])
probabilities = torch.sigmoid(outputs)
return separate((probabilities, batch[1]))
@@ -8,10 +8,12 @@ import typing
import torch
from ..data.typing import Sample
from .typing import Prediction
from .typing import BinaryPrediction, MultiClassPrediction
def _as_predictions(samples: typing.Iterable[Sample]) -> list[Prediction]:
def _as_predictions(
samples: typing.Iterable[Sample],
) -> list[BinaryPrediction | MultiClassPrediction]:
"""Takes a list of separated batch predictions and transform into a list of
formal predictions.
@@ -28,7 +30,7 @@ def _as_predictions(samples: typing.Iterable[Sample]) -> list[Prediction]:
return [(v[1]["name"], v[1]["label"].item(), v[0].item()) for v in samples]
def separate(batch: Sample) -> list[Prediction]:
def separate(batch: Sample) -> list[BinaryPrediction | MultiClassPrediction]:
"""Separates a collated batch reconstituting its samples.
This function implements the inverse of
@@ -5,17 +5,23 @@
import typing
Checkpoint: typing.TypeAlias = typing.Mapping[str, typing.Any]
Checkpoint: typing.TypeAlias = typing.MutableMapping[str, typing.Any]
"""Definition of a lightning checkpoint."""
BinaryPrediction: typing.TypeAlias = tuple[str, int, float]
"""Prediction: the sample name, the target, and the predicted value."""
Prediction: typing.TypeAlias = tuple[
str, int | typing.Sequence[int], float | typing.Sequence[float]
MultiClassPrediction: typing.TypeAlias = tuple[
str, typing.Sequence[int], typing.Sequence[float]
]
"""Prediction: the sample name, the target, and the predicted value."""
BinaryPredictionSplit: typing.TypeAlias = typing.Mapping[
str, typing.Sequence[BinaryPrediction]
]
"""A series of predictions for different database splits."""
PredictionSplit: typing.TypeAlias = typing.Mapping[
str, typing.Sequence[Prediction]
MultiClassPredictionSplit: typing.TypeAlias = typing.Mapping[
str, typing.Sequence[MultiClassPrediction]
]
"""A series of predictions for different database splits."""
@@ -2,253 +2,168 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later
import os
from collections import defaultdict
import pathlib
import click
from clapper.click import ConfigCommand, ResourceOption, verbosity_option
from clapper.click import ResourceOption, verbosity_option
from clapper.logging import setup
from matplotlib.backends.backend_pdf import PdfPages
from ..data.datamodule import ConcatDataModule
from ..data.typing import DataLoader
from ..utils.plot import precision_recall_f1iso, roc_curve
from ..utils.table import performance_table
from .click import ConfigCommand
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
def _validate_threshold(
threshold: int | float | str, dataloader_dict: dict[str, DataLoader]
):
"""Validates the user threshold selection.
Parameters
----------
threshold:
This number is used to define positives and negatives from
probability maps, and report F1-scores (a priori). It
should either come from the training set or a separate validation set
to avoid biasing the analysis. Optionally, if you provide a multi-set
dataset as input, this may also be the name of an existing set from
which the threshold will be estimated (highest F1-score) and then
applied to the subsequent sets. This number is also used to print
the test set F1-score a priori performance
dataloader_dict:
Dictionary of set_name: dataloader, where set_name is the name of a dataset split
and dataloader is the torch dataloader for that split.
Returns
-------
The parsed threshold.
"""
if threshold is None:
return 0.5
try:
# we try to convert it to float first
threshold = float(threshold)
if threshold < 0.0 or threshold > 1.0:
raise ValueError("Float thresholds must be within range [0.0, 1.0]")
except ValueError:
# it is a bit of text - assert dataset with name is available
if not isinstance(dataloader_dict, dict):
raise ValueError(
"Threshold should be a floating-point number "
"if your provide only a single dataset for evaluation"
)
if threshold not in dataloader_dict:
raise ValueError(
f"Text thresholds should match dataset names, "
f"but {threshold} is not available among the datasets provided ("
f"({', '.join(dataloader_dict.keys())})"
)
return threshold
@click.command(
entry_point_group="ptbench.config",
cls=ConfigCommand,
epilog="""Examples:
\b
1. Runs evaluation on an existing dataset configuration:
1. Runs evaluation on an existing prediction output:
.. code:: sh
ptbench evaluate -vv --predictions=path/to/predictions.json --output-folder=path/to/results
2. Runs evaluation on an existing prediction output, tune threshold a priori on the `validation` set:
.. code:: sh
.. code:: sh
ptbench evaluate -vv montgomery --predictions-folder=path/to/predictions --output-folder=path/to/results
ptbench evaluate -vv --predictions=path/to/predictions.json --output-folder=path/to/results --threshold=validation
""",
)
@click.option(
"--output-folder",
"-o",
help="Path where to store the analysis result (created if does not exist)",
required=True,
default="results",
type=click.Path(),
cls=ResourceOption,
)
@click.option(
"--predictions-folder",
"--predictions",
"-p",
help="Path where predictions are currently stored",
required=True,
type=click.Path(exists=True, file_okay=False, dir_okay=True),
type=click.Path(
file_okay=True, dir_okay=False, writable=True, path_type=pathlib.Path
),
cls=ResourceOption,
)
@click.option(
"--datamodule",
"-d",
help="A lighting data module containing the training and validation sets.",
"--output-folder",
"-o",
help="Path where to store the analysis result (created if does not exist)",
required=True,
default="results",
type=click.Path(file_okay=False, dir_okay=True, path_type=pathlib.Path),
cls=ResourceOption,
)
@click.option(
"--threshold",
"-t",
help="This number is used to define positives and negatives from "
"probability maps, and report F1-scores (a priori). It "
"should either come from the training set or a separate validation set "
"to avoid biasing the analysis. Optionally, if you provide a multi-set "
"dataset as input, this may also be the name of an existing set from "
"which the threshold will be estimated (highest F1-score) and then "
"applied to the subsequent sets. This number is also used to print "
"the test set F1-score a priori performance",
default=None,
help="""This value is used to define positives and negatives from
probability outputs in predictions, and report performance measures on
**binary** classification tasks. It should either come from the training
set or a separate validation set to avoid biasing the analysis.
Optionally, if you provide a multi-split set of predictions as input, this
may also be the name of an existing split (e.g. ``validation``) from which
the threshold will be estimated (by calculating the threshold leading to
the highest F1-score on that set) and then applied to the subsequent
sets. This value is not used for multi-class classification tasks.""",
default=0.5,
show_default=False,
required=True,
cls=ResourceOption,
)
@click.option(
"--steps",
"-S",
help="This number is used to define the number of threshold steps to "
"consider when evaluating the highest possible F1-score on test data.",
default=1000,
show_default=True,
required=True,
type=click.STRING,
cls=ResourceOption,
)
@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
def evaluate(
output_folder: str,
predictions_folder: str,
datamodule: ConcatDataModule,
threshold: int | float | str,
steps: int,
**_,
predictions: pathlib.Path,
output_folder: pathlib.Path,
threshold: str | float,
**_, # ignored
) -> None:
"""Evaluates a CNN on a tuberculosis prediction task.
Note: batch size of 1 is required on the predictions.
"""
"""Evaluates predictions (from a model) on a binary classification task."""
from ..engine.evaluator import run
import json
import typing
datamodule.set_chunk_size(1, 1)
datamodule.model_transforms = []
import matplotlib.figure
datamodule.prepare_data()
datamodule.setup(stage="predict")
from matplotlib.backends.backend_pdf import PdfPages
dataloader = datamodule.predict_dataloader()
from ..engine.evaluator import (
aggregate_pr,
aggregate_roc,
aggregate_summaries,
run_binary,
)
threshold = _validate_threshold(threshold, dataloader)
with predictions.open("r") as f:
predict_data = json.load(f)
if isinstance(threshold, str):
if threshold in predict_data:
# it is the name of a split
# first run evaluation for reference dataset
logger.info(f"Evaluating threshold on '{threshold}' set")
_, _, f1_threshold, eer_threshold = run(
name=threshold,
predictions_folder=predictions_folder,
steps=steps,
)
from ..engine.evaluator import maxf1_threshold
if (f1_threshold is not None) and (eer_threshold is not None):
logger.info(f"Set --f1_threshold={f1_threshold:.5f}")
logger.info(f"Set --eer_threshold={eer_threshold:.5f}")
use_threshold = maxf1_threshold(predict_data[threshold])
logger.info(f"Setting --threshold={use_threshold:.5f}")
elif isinstance(threshold, float):
f1_threshold = threshold
eer_threshold = f1_threshold
else:
raise ValueError("Threshold value is neither a str or a float")
results_dict = { # type: ignore
"pred_data": defaultdict(dict),
"fig_score": defaultdict(dict),
"maxf1_threshold": defaultdict(dict),
"post_eer_threshold": defaultdict(dict),
}
for k in dataloader.keys():
if k.startswith("_"):
logger.info(f"Skipping dataset '{k}' (not to be evaluated)")
continue
logger.info(f"Analyzing '{k}' set...")
pred_data, fig_score, maxf1_threshold, post_eer_threshold = run(
k,
predictions_folder,
f1_thresh=f1_threshold,
eer_thresh=eer_threshold,
steps=steps,
# we try to convert it to float and complain if that is not possible
try:
use_threshold = float(threshold)
except ValueError:
raise click.BadParameter(
f"""The value of --threshold=`{threshold}` does not match one
of the database split names ({', '.join(predict_data.keys())})
or can be converted to float. Check your input."""
)
results: dict[
str,
tuple[
dict[str, typing.Any],
dict[str, matplotlib.figure.Figure],
dict[str, typing.Any],
],
] = dict()
for k, v in predict_data.items():
logger.info(f"Analyzing split `{k}`...")
results[k] = run_binary(
name=k,
predictions=v,
threshold_a_priori=use_threshold,
)
results_dict["pred_data"][k] = pred_data
results_dict["fig_score"][k] = fig_score
results_dict["maxf1_threshold"][k] = maxf1_threshold
results_dict["post_eer_threshold"][k] = post_eer_threshold
rows = [v[0] for v in results.values()]
table = aggregate_summaries(rows, fmt="rst")
click.echo(table)
if output_folder is not None:
output_scores = os.path.join(output_folder, "scores.pdf")
if output_scores is not None:
output_scores = os.path.realpath(output_scores)
logger.info(f"Creating and saving scores at {output_scores}...")
os.makedirs(os.path.dirname(output_scores), exist_ok=True)
score_pdf = PdfPages(output_scores)
for fig in results_dict["fig_score"].values():
score_pdf.savefig(fig)
score_pdf.close()
data = {}
for subset_name in dataloader.keys():
data[subset_name] = {
"df": results_dict["pred_data"][subset_name],
"threshold": results_dict["post_eer_threshold"][threshold]
if isinstance(threshold, str)
else eer_threshold,
"threshold_type": f"posteriori [{threshold}]"
if isinstance(threshold, str)
else "priori",
output_folder.mkdir(parents=True, exist_ok=True)
table_path = output_folder / "summary.rst"
logger.info(f"Saving measures at `{table_path}`...")
with table_path.open("w") as f:
f.write(table)
figure_path = output_folder / "plots.pdf"
logger.info(f"Saving figures at `{figure_path}`...")
with PdfPages(figure_path) as pdf:
pr_curves = {
k: v[2]["precision_recall"] for k, v in results.items()
}
pr_fig = aggregate_pr(pr_curves)
pdf.savefig(pr_fig)
roc_curves = {k: v[2]["roc"] for k, v in results.items()}
roc_fig = aggregate_roc(roc_curves)
pdf.savefig(roc_fig)
# order ready-to-save figures by type instead of split
figures = {k: v[1] for k, v in results.items()}
keys = next(iter(figures.values())).keys()
figures_by_type = {
k: [v[k] for v in figures.values()] for k in keys
}
output_figure = os.path.join(output_folder, "plots.pdf")
if output_figure is not None:
output_figure = os.path.realpath(output_figure)
logger.info(f"Creating and saving plots at {output_figure}...")
os.makedirs(os.path.dirname(output_figure), exist_ok=True)
pdf = PdfPages(output_figure)
pdf.savefig(precision_recall_f1iso(data))
pdf.savefig(roc_curve(data))
pdf.close()
output_table = os.path.join(output_folder, "table.txt")
logger.info("Tabulating performance summary...")
table = performance_table(data, "rst")
click.echo(table)
if output_table is not None:
output_table = os.path.realpath(output_table)
logger.info(f"Saving table at {output_table}...")
os.makedirs(os.path.dirname(output_table), exist_ok=True)
with open(output_table, "w") as f:
f.write(table)
for group_figures in figures_by_type.values():
for f in group_figures:
pdf.savefig(f)
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import logging
import tempfile
import urllib.request
from tqdm import tqdm
logger = logging.getLogger(__name__)
def download_to_tempfile(url, progress=False):
"""Downloads a file to a temporary named file and returns it.
Parameters
----------
url : str
The URL pointing to the file to download
progress : :py:class:`bool`, Optional
If a progress bar should be displayed for downloading the URL.
Returns
-------
f : :py:func:`tempfile.NamedTemporaryFile`
A named temporary file that contains the downloaded URL
"""
file_size = 0
response = urllib.request.urlopen(url)
meta = response.info()
if hasattr(meta, "getheaders"):
content_length = meta.getheaders("Content-Length")
else:
content_length = meta.get_all("Content-Length")
if content_length is not None and len(content_length) > 0:
file_size = int(content_length[0])
progress &= bool(file_size)
f = tempfile.NamedTemporaryFile()
with tqdm(total=file_size, disable=not progress) as pbar:
while True:
buffer = response.read(8192)
if len(buffer) == 0:
break
f.write(buffer)
pbar.update(len(buffer))
f.flush()
f.seek(0)
return f
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
# SPDX-FileContributor: Kazuto Nakashima <k nakashima@irvs.ait.kyushu-u.ac.jp>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import torch
from torch.nn import functional as F
class BaseWrapper:
def __init__(self, model):
super().__init__()
self.device = next(model.parameters()).device
self.model_with_norm = model
self.model = model.model
self.handlers = [] # a set of hook function handlers
def _encode_one_hot(self, ids):
one_hot = torch.zeros_like(self.logits).to(self.device)
one_hot.scatter_(1, ids, 1.0)
return one_hot
def forward(self, image):
self.image_shape = image.shape[2:]
self.logits = self.model_with_norm(image)
self.probs = torch.sigmoid(self.logits)
return self.probs.sort(dim=1, descending=True) # ordered results
def backward(self, ids):
"""Class-specific backpropagation."""
one_hot = self._encode_one_hot(ids)
self.model_with_norm.zero_grad()
self.logits.backward(gradient=one_hot, retain_graph=True)
def generate(self):
raise NotImplementedError
def remove_hook(self):
"""Remove all the forward/backward hook functions."""
for handle in self.handlers:
handle.remove()
class GradCAM(BaseWrapper):
"""
"Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization"
https://arxiv.org/pdf/1610.02391.pdf
Look at Figure 2 on page 4
"""
def __init__(self, model, candidate_layers=None):
super().__init__(model)
self.fmap_pool = {}
self.grad_pool = {}
self.candidate_layers = candidate_layers # list
def save_fmaps(key):
def forward_hook(module, input, output):
self.fmap_pool[key] = output.detach()
return forward_hook
def save_grads(key):
def backward_hook(module, grad_in, grad_out):
self.grad_pool[key] = grad_out[0].detach()
return backward_hook
# If no candidate layers are specified, hooks are registered on all layers.
for name, module in self.model.named_modules():
if self.candidate_layers is None or name in self.candidate_layers:
self.handlers.append(
module.register_forward_hook(save_fmaps(name))
)
self.handlers.append(
module.register_backward_hook(save_grads(name))
)
def _find(self, pool, target_layer):
if target_layer in pool.keys():
return pool[target_layer]
else:
raise ValueError(f"Invalid layer name: {target_layer}")
def generate(self, target_layer):
fmaps = self._find(self.fmap_pool, target_layer)
grads = self._find(self.grad_pool, target_layer)
weights = F.adaptive_avg_pool2d(grads, 1)
gcam = torch.mul(fmaps, weights).sum(dim=1, keepdim=True)
gcam = F.relu(gcam)
gcam = F.interpolate(
gcam, self.image_shape, mode="bilinear", align_corners=False
)
B, C, H, W = gcam.shape
gcam = gcam.view(B, -1)
gcam -= gcam.min(dim=1, keepdim=True)[0]
gcam /= gcam.max(dim=1, keepdim=True)[0]
gcam = gcam.view(B, C, H, W)
return gcam
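# Usage sketch (hypothetical stand-in model; not part of the original module).
# The wrapped object only needs to expose a ``.model`` attribute, as the models
# of this package do; here a tiny randomly-initialized network is used instead,
# so the resulting map is meaningless and serves only to show the call sequence.
import torch.nn as nn

class _TinyWrapped(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(1, 4, 3, padding=1),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(4, 2),
        )

    def forward(self, x):
        return self.model(x)

wrapper = GradCAM(_TinyWrapped(), candidate_layers=["0"])
probs, ids = wrapper.forward(torch.rand(1, 1, 32, 32))
wrapper.backward(ids[:, [0]])  # backpropagate for the top-ranked class
cam = wrapper.generate(target_layer="0")  # (1, 1, 32, 32) localization map
wrapper.remove_hook()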
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import os
from typing import Union
import torch
from PIL.Image import Image
from torchvision import transforms
def save_image(img: Union[torch.Tensor, Image], filepath: str) -> None:
"""Saves a PIL image or a tensor as an image at the specified destination.
Parameters
----------
img:
A torch.Tensor or PIL.Image to save
filepath:
The file in which to save the image. The format is inferred from the file extension, or defaults to png if not specified.
"""
if isinstance(img, torch.Tensor):
img = transforms.ToPILImage()(img)
root, ext = os.path.splitext(filepath)
if len(ext) == 0:
filepath = filepath + ".png"
img.save(filepath)
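# Usage sketch (hypothetical output path): tensors are converted to PIL images
# before saving; a missing extension defaults to ".png".
save_image(torch.rand(3, 64, 64), "example_output")  # writes example_output.png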
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
from collections import deque
import numpy
import scipy.special
import torch
class SmoothedValue:
"""Track a series of values and provide access to smoothed values over a
window or the global series average."""
def __init__(self, window_size=20):
self.deque = deque(maxlen=window_size)
def update(self, value):
self.deque.append(value)
@property
def median(self):
d = torch.tensor(list(self.deque))
return d.median().item()
@property
def avg(self):
d = torch.tensor(list(self.deque))
return d.mean().item()
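# Usage sketch (hypothetical values): track a running loss and query its
# smoothed statistics over the last ``window_size`` updates.
tracker = SmoothedValue(window_size=3)
for loss in (0.9, 0.7, 0.8, 0.6):
    tracker.update(loss)
print(tracker.median, tracker.avg)  # computed over the last 3 values only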
def tricky_division(n, d):
"""Divides n by d.
Returns 0.0 in case of a division by zero
"""
return n / (d + (d == 0))
def base_measures(tp, fp, tn, fn):
"""Calculates measures from true/false positive and negative counts.
This function can return standard machine learning measures from true and
false positive counts of positives and negatives. For a thorough look into
these and alternate names for the returned values, please check Wikipedia's
entry on `Precision and Recall
<https://en.wikipedia.org/wiki/Precision_and_recall>`_.
Parameters
----------
tp : int
True positive count, AKA "hit"
fp : int
False positive count, AKA "false alarm", or "Type I error"
tn : int
True negative count, AKA "correct rejection"
fn : int
False Negative count, AKA "miss", or "Type II error"
Returns
-------
precision : float
P, AKA positive predictive value (PPV). It corresponds arithmetically
to ``tp/(tp+fp)``. In the case ``tp+fp == 0``, this function returns
zero for precision.
recall : float
R, AKA sensitivity, hit rate, or true positive rate (TPR). It
corresponds arithmetically to ``tp/(tp+fn)``. In the special case
where ``tp+fn == 0``, this function returns zero for recall.
specificity : float
S, AKA selectivity or true negative rate (TNR). It
corresponds arithmetically to ``tn/(tn+fp)``. In the special case
where ``tn+fp == 0``, this function returns zero for specificity.
accuracy : float
A, see `Accuracy
<https://en.wikipedia.org/wiki/Evaluation_of_binary_classifiers>`_. is
the proportion of correct predictions (both true positives and true
negatives) among the total number of pixels examined. It corresponds
arithmetically to ``(tp+tn)/(tp+tn+fp+fn)``. This measure includes
both true-negatives and positives in the numerator, which makes it
sensitive to data or regions without annotations.
jaccard : float
J, see `Jaccard Index or Similarity
<https://en.wikipedia.org/wiki/Jaccard_index>`_. It corresponds
arithmetically to ``tp/(tp+fp+fn)``. In the special case where
``tn+fp+fn == 0``, this function returns zero for the Jaccard index.
The Jaccard index depends on a TP-only numerator, similarly to the F1
score. For regions where there are no annotations, the Jaccard index
will always be zero, irrespective of the model output. Accuracy may be
a better proxy if one needs to consider the true absence of
annotations in a region as part of the measure.
f1_score : float
F1, see `F1-score <https://en.wikipedia.org/wiki/F1_score>`_. It
corresponds arithmetically to ``2*P*R/(P+R)`` or ``2*tp/(2*tp+fp+fn)``.
In the special case where ``P+R == (2*tp+fp+fn) == 0``, this function
returns zero for the F1-score. The F1 or Dice score depends on a
TP-only numerator, similarly to the Jaccard index. For regions where
there are no annotations, the F1-score will always be zero,
irrespective of the model output. Accuracy may be a better proxy if
one needs to consider the true absence of annotations in a region as
part of the measure.
"""
return (
tricky_division(tp, tp + fp), # precision
tricky_division(tp, tp + fn), # recall
tricky_division(tn, fp + tn), # specificity
tricky_division(tp + tn, tp + fp + fn + tn), # accuracy
tricky_division(tp, tp + fp + fn), # jaccard index
tricky_division(2 * tp, (2 * tp) + fp + fn), # f1-score
)
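# Usage sketch (hypothetical counts): all six measures from a confusion matrix
# with 90 true positives, 10 false positives, 80 true negatives and 20 false
# negatives.
precision, recall, specificity, accuracy, jaccard, f1 = base_measures(
    tp=90, fp=10, tn=80, fn=20
)
print(f"precision={precision:.3f} recall={recall:.3f} f1={f1:.3f}")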
def beta_credible_region(successes, failures, lambda_, coverage):
"""Returns the mode, upper and lower bounds of the equal-tailed credible
region of a probability estimate following Bernoulli trials.
This implementation is based on [GOUTTE-2005]_. It assumes :math:`k`
successes and :math:`l` failures (:math:`n = k+l` total trials) are issued
from a series of Bernoulli trials (likelihood is binomial). The posterior
is derived using Bayes' theorem with a beta prior. As there is no
reason to favour high vs. low precision, we use a symmetric Beta prior
(:math:`\\alpha=\\beta`):
.. math::
P(p|k,n) &= \\frac{P(k,n|p)P(p)}{P(k,n)} \\\\
P(p|k,n) &= \\frac{\\frac{n!}{k!(n-k)!}p^{k}(1-p)^{n-k}P(p)}{P(k,n)} \\\\
P(p|k,n) &= \\frac{1}{B(k+\\alpha, n-k+\\beta)}p^{k+\\alpha-1}(1-p)^{n-k+\\beta-1} \\\\
P(p|k,n) &= \\frac{1}{B(k+\\alpha, n-k+\\alpha)}p^{k+\\alpha-1}(1-p)^{n-k+\\alpha-1}
The mode for this posterior (also the maximum a posteriori) is:
.. math::
\\text{mode}(p) = \\frac{k+\\lambda-1}{n+2\\lambda-2}
Concretely, the prior may be flat (all rates are equally likely,
:math:`\\lambda=1`) or we may use Jeffrey's prior
(:math:`\\lambda=0.5`), which is invariant under re-parameterisation.
Jeffrey's prior indicates that rates close to zero or one are more likely.
The mode above works if :math:`k+{\\alpha},n-k+{\\alpha} > 1`, which is
usually the case for a reasonably well-tuned system, with more than a few
samples for analysis. In the limit of the system performance, :math:`k`
may be 0, which will make the mode become zero.
For our purposes, it may be more suitable to represent :math:`n = k + l`,
with :math:`k`, the number of successes and :math:`l`, the number of
failures in the binomial experiment, and find this more suitable
representation:
.. math::
P(p|k,l) &= \\frac{1}{B(k+\\alpha, l+\\alpha)}p^{k+\\alpha-1}(1-p)^{l+\\alpha-1} \\\\
\\text{mode}(p) &= \\frac{k+\\lambda-1}{k+l+2\\lambda-2}
This can be mapped to most rates calculated in the context of binary
classification this way:
* Precision or Positive-Predictive Value (PPV): p = TP/(TP+FP), so k=TP, l=FP
* Recall, Sensitivity, or True Positive Rate: r = TP/(TP+FN), so k=TP, l=FN
* Specificity or True Negative Rate: s = TN/(TN+FP), so k=TN, l=FP
* F1-score: f1 = 2TP/(2TP+FP+FN), so k=2TP, l=FP+FN
* Accuracy: acc = TP+TN/(TP+TN+FP+FN), so k=TP+TN, l=FP+FN
* Jaccard: j = TP/(TP+FP+FN), so k=TP, l=FP+FN
Contrary to frequentist approaches, in which one can only
say that if the test were repeated an infinite number of times,
and one constructed a confidence interval each time, then X%
of the confidence intervals would contain the true rate, here
we can say that given our observed data, there is a X% probability
that the true value of :math:`k/n` falls within the provided
interval.
.. note::
For a disambiguation with Confidence Interval, read
https://en.wikipedia.org/wiki/Credible_interval.
Parameters
==========
successes : int
Number of successes observed on the experiment
failures : int
Number of failures observed on the experiment
lambda_ : :py:class:`float`, Optional
The parameterisation of the Beta prior to consider. Use
:math:`\\lambda=1` for a flat prior. Use :math:`\\lambda=0.5` for
Jeffrey's prior (the default).
coverage : :py:class:`float`, Optional
A floating-point number between 0 and 1.0 indicating the
coverage you're expecting. A value of 0.95 will ensure 95%
of the area under the probability density of the posterior
is covered by the returned equal-tailed interval.
Returns
=======
mean : float
The mean of the posterior distribution
mode : float
The mode of the posterior distribution
lower, upper: float
The lower and upper bounds of the credible region
"""
# we return the equally-tailed range
right = (1.0 - coverage) / 2 # half-width in each side
lower = scipy.special.betaincinv(
successes + lambda_, failures + lambda_, right
)
upper = scipy.special.betaincinv(
successes + lambda_, failures + lambda_, 1.0 - right
)
# evaluate mean and mode (https://en.wikipedia.org/wiki/Beta_distribution)
alpha = successes + lambda_
beta = failures + lambda_
E = alpha / (alpha + beta)
# the mode of a beta distribution is a bit tricky
if alpha > 1 and beta > 1:
mode = (alpha - 1) / (alpha + beta - 2)
elif alpha == 1 and beta == 1:
# In the case of precision, if the threshold is close to 1.0, both TP
# and FP can be zero, which may cause this condition to be reached, if
# the prior is exactly 1 (flat prior). This is a weird situation,
# because effectively we are trying to compute the posterior when the
# total number of experiments is zero. So, only the prior counts - but
# the prior is flat, so we should just pick a value. We choose the
# middle of the range.
mode = 0.0 # any value would do, we just pick this one
elif alpha <= 1 and beta > 1:
mode = 0.0
elif alpha > 1 and beta <= 1:
mode = 1.0
else: # elif alpha < 1 and beta < 1:
# in the case of precision, if the threshold is close to 1.0, both TP
# and FP can be zero, which may cause this condition to be reached, if
# the prior is smaller than 1. This is a weird situation, because
# effectively we are trying to compute the posterior when the total
# number of experiments is zero. So, only the prior counts - but the
# prior is bimodal, so we should just pick a value. We choose the
# left of the range.
mode = 0.0 # could also be 1.0 as the prior is bimodal
return E, mode, lower, upper
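# Usage sketch (hypothetical counts): 95% equal-tailed credible region for a
# precision estimate with 90 true positives and 10 false positives, using
# Jeffrey's prior (lambda_ = 0.5).
mean, mode, lower, upper = beta_credible_region(
    successes=90, failures=10, lambda_=0.5, coverage=0.95
)
print(f"precision: mode={mode:.3f}, 95% CR=[{lower:.3f}, {upper:.3f}]")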
def bayesian_measures(tp, fp, tn, fn, lambda_, coverage):
r"""Calculates mean and mode from true/false positive and negative counts
with credible regions.
This function can return bayesian estimates of standard machine learning
measures from true and false positive counts of positives and negatives.
For a thorough look into these and alternate names for the returned values,
please check Wikipedia's entry on `Precision and Recall
<https://en.wikipedia.org/wiki/Precision_and_recall>`_. See
:py:func:`beta_credible_region` for details on the calculation of returned
values.
Parameters
----------
tp : int
True positive count, AKA "hit"
fp : int
False positive count, AKA "false alarm", or "Type I error"
tn : int
True negative count, AKA "correct rejection"
fn : int
False Negative count, AKA "miss", or "Type II error"
lambda_ : float
The parameterisation of the Beta prior to consider. Use
:math:`\lambda=1` for a flat prior. Use :math:`\lambda=0.5` for
Jeffrey's prior.
coverage : float
A floating-point number between 0 and 1.0 indicating the
coverage you're expecting. A value of 0.95 will ensure 95%
of the area under the probability density of the posterior
is covered by the returned equal-tailed interval.
Returns
-------
precision : (float, float, float, float)
P, AKA positive predictive value (PPV), mean, mode and credible
intervals (95% CI). It corresponds arithmetically
to ``tp/(tp+fp)``.
recall : (float, float, float, float)
R, AKA sensitivity, hit rate, or true positive rate (TPR), mean, mode
and credible intervals (95% CI). It corresponds arithmetically to
``tp/(tp+fn)``.
specificity : (float, float, float, float)
S, AKA selectivity or true negative rate (TNR), mean, mode and credible
intervals (95% CI). It corresponds arithmetically to ``tn/(tn+fp)``.
accuracy : (float, float, float, float)
A, mean, mode and credible intervals (95% CI). See `Accuracy
<https://en.wikipedia.org/wiki/Evaluation_of_binary_classifiers>`_. is
the proportion of correct predictions (both true positives and true
negatives) among the total number of pixels examined. It corresponds
arithmetically to ``(tp+tn)/(tp+tn+fp+fn)``. This measure includes
both true-negatives and positives in the numerator, which makes it
sensitive to data or regions without annotations.
jaccard : (float, float, float, float)
J, mean, mode and credible intervals (95% CI). See `Jaccard Index or
Similarity <https://en.wikipedia.org/wiki/Jaccard_index>`_. It
corresponds arithmetically to ``tp/(tp+fp+fn)``. The Jaccard index
depends on a TP-only numerator, similarly to the F1 score. For regions
where there are no annotations, the Jaccard index will always be zero,
irrespective of the model output. Accuracy may be a better proxy if
one needs to consider the true absence of annotations in a region as
part of the measure.
f1_score : (float, float, float, float)
F1, mean, mode and credible intervals (95% CI). See `F1-score
<https://en.wikipedia.org/wiki/F1_score>`_. It corresponds
arithmetically to ``2*P*R/(P+R)`` or ``2*tp/(2*tp+fp+fn)``. The F1 or
Dice score depends on a TP-only numerator, similarly to the Jaccard
index. For regions where there are no annotations, the F1-score will
always be zero, irrespective of the model output. Accuracy may be a
better proxy if one needs to consider the true absence of annotations
in a region as part of the measure.
"""
return (
beta_credible_region(tp, fp, lambda_, coverage), # precision
beta_credible_region(tp, fn, lambda_, coverage), # recall
beta_credible_region(tn, fp, lambda_, coverage), # specificity
beta_credible_region(tp + tn, fp + fn, lambda_, coverage), # accuracy
beta_credible_region(tp, fp + fn, lambda_, coverage), # jaccard index
beta_credible_region(2 * tp, fp + fn, lambda_, coverage), # f1-score
)
def get_centered_maxf1(f1_scores, thresholds):
"""Return the centered max F1 score threshold when multiple threshold give
the same max F1 score.
Parameters
----------
f1_scores : numpy.ndarray
1D array of f1 scores
thresholds : numpy.ndarray
1D array of thresholds
Returns
-------
max F1 score: float
threshold: float
"""
maxf1 = f1_scores.max()
maxf1_indices = numpy.where(f1_scores == maxf1)[0]
# If multiple thresholds give the same max F1 score
if len(maxf1_indices) > 1:
mean_maxf1_index = int(round(numpy.mean(maxf1_indices)))
else:
mean_maxf1_index = maxf1_indices[0]
return maxf1, thresholds[mean_maxf1_index]
# SPDX-FileCopyrightText: Copyright Facebook, Inc. and its affiliates. All Rights Reserved.
#
# SPDX-License-Identifier: GPL-3.0-or-later
# Original code from: https://github.com/facebookresearch/maskrcnn-benchmark
import logging
from collections import OrderedDict
logger = logging.getLogger(__name__)
import torch
def align_and_update_state_dicts(model_state_dict, loaded_state_dict):
"""
Strategy: suppose that the models that we will create will have prefixes appended
to each of its keys, for example due to an extra level of nesting that the original
pre-trained weights from ImageNet won't contain. For example, model.state_dict()
might return backbone[0].body.res2.conv1.weight, while the pre-trained model contains
res2.conv1.weight. We thus want to match both parameters together.
For that, we look for each model weight, look among all loaded keys if there is one
that is a suffix of the current weight name, and use it if that's the case.
If multiple matches exist, take the one with the longest matching
name. For example, for the same model as before, the pretrained
weight file can contain both res2.conv1.weight, as well as conv1.weight. In this case,
we want to match backbone[0].body.conv1.weight to conv1.weight, and
backbone[0].body.res2.conv1.weight to res2.conv1.weight.
"""
current_keys = sorted(list(model_state_dict.keys()))
loaded_keys = sorted(list(loaded_state_dict.keys()))
# get a matrix of string matches, where each (i, j) entry corresponds to the size of the
# loaded_key string, if it matches
match_matrix = [
len(j) if i.endswith(j) else 0
for i in current_keys
for j in loaded_keys
]
match_matrix = torch.as_tensor(match_matrix).view(
len(current_keys), len(loaded_keys)
)
max_match_size, idxs = match_matrix.max(1)
# remove indices that correspond to no-match
idxs[max_match_size == 0] = -1
# used for logging
max_size = max([len(key) for key in current_keys]) if current_keys else 1
max_size_loaded = (
max([len(key) for key in loaded_keys]) if loaded_keys else 1
)
log_str_template = "{: <{}} loaded from {: <{}} of shape {}"
for idx_new, idx_old in enumerate(idxs.tolist()):
if idx_old == -1:
continue
key = current_keys[idx_new]
key_old = loaded_keys[idx_old]
model_state_dict[key] = loaded_state_dict[key_old]
logger.debug(
log_str_template.format(
key,
max_size,
key_old,
max_size_loaded,
tuple(loaded_state_dict[key_old].shape),
)
)
def strip_prefix_if_present(state_dict, prefix):
keys = sorted(state_dict.keys())
if not all(key.startswith(prefix) for key in keys):
return state_dict
stripped_state_dict = OrderedDict()
# strip only the leading prefix, not other occurrences within the key
for key, value in state_dict.items():
stripped_state_dict[key[len(prefix) :]] = value
return stripped_state_dict
def load_state_dict(model, loaded_state_dict):
model_state_dict = model.state_dict()
# if the state_dict comes from a model that was wrapped in a
# DataParallel or DistributedDataParallel during serialization,
# remove the "module" prefix before performing the matching
loaded_state_dict = strip_prefix_if_present(
loaded_state_dict, prefix="module."
)
align_and_update_state_dicts(model_state_dict, loaded_state_dict)
# use strict loading
model.load_state_dict(model_state_dict)
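# Illustrative usage sketch (not part of the original module; the checkpoint
# path is hypothetical and assumed to store a plain state dict).  The helper
# strips an eventual "module." prefix left over from (Distributed)DataParallel,
# aligns keys by suffix and then loads the result strictly into the model:
#
#     checkpoint = torch.load("weights.pth", map_location="cpu")
#     load_state_dict(model, checkpoint)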
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
# Adapted from:
# https://github.com/pytorch/pytorch/blob/master/torch/hub.py
# https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/utils/checkpoint.py
import hashlib
import os
import re
import shutil
import sys
import tempfile
from urllib.parse import urlparse
from urllib.request import urlopen
from tqdm import tqdm
modelurls = {
"vgg11": "https://download.pytorch.org/models/vgg11-bbd30ac9.pth",
"vgg13": "https://download.pytorch.org/models/vgg13-c768596a.pth",
"vgg16": "https://download.pytorch.org/models/vgg16-397923af.pth",
"vgg19": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
"vgg11_bn": "https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
"vgg13_bn": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
"vgg16_bn": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
"vgg19_bn": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth",
"resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth",
"resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
"resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth",
"resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
"resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
}
"""URLs of pre-trained models (backbones)"""
def download_url_to_file(url, dst, hash_prefix, progress):
file_size = None
u = urlopen(url)
meta = u.info()
if hasattr(meta, "getheaders"):
content_length = meta.getheaders("Content-Length")
else:
content_length = meta.get_all("Content-Length")
if content_length is not None and len(content_length) > 0:
file_size = int(content_length[0])
f = tempfile.NamedTemporaryFile(delete=False)
try:
if hash_prefix is not None:
sha256 = hashlib.sha256()
with tqdm(total=file_size, disable=not progress) as pbar:
while True:
buffer = u.read(8192)
if len(buffer) == 0:
break
f.write(buffer)
if hash_prefix is not None:
sha256.update(buffer)
pbar.update(len(buffer))
f.close()
if hash_prefix is not None:
digest = sha256.hexdigest()
if digest[: len(hash_prefix)] != hash_prefix:
raise RuntimeError(
'invalid hash value (expected "{}", got "{}")'.format(
hash_prefix, digest
)
)
shutil.move(f.name, dst)
finally:
f.close()
if os.path.exists(f.name):
os.remove(f.name)
HASH_REGEX = re.compile(r"-([a-f0-9]*)\.")
def cache_url(url, model_dir=None, progress=True):
r"""Loads the Torch serialized object at the given URL.
If the object is already present in `model_dir`, it's deserialized and
returned. The filename part of the URL should follow the naming convention
``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
digits of the SHA256 hash of the contents of the file. The hash is used to
ensure unique names and to verify the contents of the file.
The default value of `model_dir` is ``$TORCH_HOME/models`` where
``$TORCH_HOME`` defaults to ``~/.torch``. The default directory can be
overridden with the ``$TORCH_MODEL_ZOO`` environment variable.
Args:
url (string): URL of the object to download
model_dir (string, optional): directory in which to save the object
progress (bool, optional): whether or not to display a progress bar to stderr
"""
if model_dir is None:
torch_home = os.path.expanduser(os.getenv("TORCH_HOME", "~/.torch"))
model_dir = os.getenv(
"TORCH_MODEL_ZOO", os.path.join(torch_home, "models")
)
if not os.path.exists(model_dir):
os.makedirs(model_dir)
parts = urlparse(url)
filename = os.path.basename(parts.path)
cached_file = os.path.join(model_dir, filename)
if not os.path.exists(cached_file):
sys.stderr.write(f'Downloading: "{url}" to {cached_file}\n')
hash_prefix = HASH_REGEX.search(filename)
if hash_prefix is not None:
hash_prefix = hash_prefix.group(1)
download_url_to_file(url, cached_file, hash_prefix, progress=progress)
return cached_file
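# Illustrative usage sketch (not part of the original module): downloads one of
# the backbones listed in ``modelurls`` (or reuses a previously cached copy) and
# returns the local path; the SHA256 prefix embedded in the filename is used to
# verify the download:
#
#     path = cache_url(modelurls["resnet50"])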
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import contextlib
from itertools import cycle
import matplotlib
import numpy
from sklearn.metrics import auc
from sklearn.metrics import precision_recall_curve as pr_curve
from sklearn.metrics import roc_curve as r_curve
matplotlib.use("agg")
import logging
import matplotlib.pyplot as plt
logger = logging.getLogger(__name__)
@contextlib.contextmanager
def _precision_recall_canvas(title=None):
"""Generates a canvas to draw precision-recall curves.
Works like a context manager, yielding a figure and an axes set in which
the precision-recall curves should be added to. The figure already
contains F1-ISO lines and is preset to a 0-1 square region. Once the
context is finished, ``fig.tight_layout()`` is called.
Parameters
----------
title : :py:class:`str`, Optional
Optional title to add to this plot
Yields
------
figure : matplotlib.figure.Figure
The figure that should be finally returned to the user
axes : matplotlib.axes.Axes
An axes set to which the precision-recall curves should be added
"""
fig, axes1 = plt.subplots(1)
# Names and bounds
axes1.set_xlabel("Recall")
axes1.set_ylabel("Precision")
axes1.set_xlim([0.0, 1.0])
axes1.set_ylim([0.0, 1.0])
if title is not None:
axes1.set_title(title)
axes1.grid(linestyle="--", linewidth=1, color="gray", alpha=0.2)
axes2 = axes1.twinx()
# Annotates plot with F1-score iso-lines
f_scores = numpy.linspace(0.1, 0.9, num=9)
tick_locs = []
tick_labels = []
for f_score in f_scores:
x = numpy.linspace(0.01, 1)
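# each iso-F1 line plots precision as a function of recall for a fixed
# F1 value: solving F1 = 2*P*R/(P+R) for P with R = x yields
# P = F1*x/(2*x - F1), which is what the next line computes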
y = f_score * x / (2 * x - f_score)
(l,) = plt.plot(x[y >= 0], y[y >= 0], color="green", alpha=0.1)
tick_locs.append(y[-1])
tick_labels.append("%.1f" % f_score)
axes2.tick_params(axis="y", which="both", pad=0, right=False, left=False)
axes2.set_ylabel("iso-F", color="green", alpha=0.3)
axes2.set_ylim([0.0, 1.0])
axes2.yaxis.set_label_coords(1.015, 0.97)
axes2.set_yticks(tick_locs) # notice these are invisible
for k in axes2.set_yticklabels(tick_labels):
k.set_color("green")
k.set_alpha(0.3)
k.set_size(8)
# keep the relevant spines of the primary axes visible
axes1.spines["right"].set_visible(False)
axes1.spines["top"].set_visible(False)
axes1.spines["left"].set_position(("data", -0.015))
axes1.spines["bottom"].set_position(("data", -0.015))
# hide all spines of the secondary (iso-F) axes
axes2.spines["right"].set_visible(False)
axes2.spines["top"].set_visible(False)
axes2.spines["left"].set_visible(False)
axes2.spines["bottom"].set_visible(False)
# yield execution, lets user draw precision-recall plots, and the legend
# before tightening the layout
yield fig, axes1
plt.tight_layout()
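# Illustrative usage sketch (not part of the original module; ``recall`` and
# ``precision`` stand for pre-computed 1D arrays).  The context manager yields
# the figure and the primary axes, so callers only draw their curves (and an
# optional legend) before the layout is tightened on exit:
#
#     with _precision_recall_canvas(title="validation") as (fig, axes):
#         axes.plot(recall, precision, color="blue", linestyle="-")
#     fig.savefig("pr.pdf")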
def precision_recall_f1iso(data):
"""Creates a precision-recall plot.
This function creates and returns a Matplotlib figure with a
precision-recall plot. The plot will be annotated with F1-score
iso-lines (in which the F1-score maintains the same value).
Parameters
----------
data : dict
A dictionary in which keys are strings defining plot labels and values
are dictionaries with two entries:
* ``df``: :py:class:`pandas.DataFrame`
A dataframe that is produced by our predictor engine containing
the following columns: ``filename``, ``likelihood``,
``ground_truth``.
* ``threshold``: :py:class:`list`
A threshold for each set. Not used here.
Returns
-------
figure : matplotlib.figure.Figure
A matplotlib figure you can save or display (uses an ``agg`` backend)
"""
lines = ["-", "--", "-.", ":"]
colors = [
"#1f77b4",
"#ff7f0e",
"#2ca02c",
"#d62728",
"#9467bd",
"#8c564b",
"#e377c2",
"#7f7f7f",
"#bcbd22",
"#17becf",
]
colorcycler = cycle(colors)
linecycler = cycle(lines)
with _precision_recall_canvas(title=None) as (fig, axes):
legend = []
for name, value in data.items():
df = value["df"]
# plots Recall/Precision curve
prec, recall, _ = pr_curve(df["ground_truth"], df["likelihood"])
_auc = auc(recall, prec)
label = f"{name} (AUC={_auc:.2f})"
color = next(colorcycler)
style = next(linecycler)
(line,) = axes.plot(recall, prec, color=color, linestyle=style)
legend.append((line, label))
# only add a legend if at least one curve was actually plotted
if legend:
axes.legend(
[k[0] for k in legend],
[k[1] for k in legend],
loc="lower left",
fancybox=True,
framealpha=0.7,
)
return fig
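# Illustrative usage sketch (not part of the original module; filenames, scores
# and labels are made up).  The dataframe follows the column contract described
# in the docstring above:
#
#     import pandas
#
#     df = pandas.DataFrame(
#         {
#             "filename": ["a.png", "b.png", "c.png", "d.png"],
#             "likelihood": [0.1, 0.9, 0.4, 0.8],
#             "ground_truth": [0, 1, 0, 1],
#         }
#     )
#     fig = precision_recall_f1iso({"validation": {"df": df, "threshold": [0.5]}})
#     fig.savefig("precision-recall.pdf")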
def roc_curve(data, title=None):
"""Creates a ROC plot.
This function creates and returns a Matplotlib figure with a
ROC plot.
Parameters
----------
data : dict
A dictionary in which keys are strings defining plot labels and values
are dictionaries with two entries:
* ``df``: :py:class:`pandas.DataFrame`
A dataframe that is produced by our predictor engine containing
the following columns: ``filename``, ``likelihood``,
``ground_truth``.
* ``threshold``: :py:class:`list`
A threshold for each set. Not used here.
Returns
-------
figure : matplotlib.figure.Figure
A matplotlib figure you can save or display (uses an ``agg`` backend)
"""
fig, axes = plt.subplots(1)
# Names and bounds
axes.set_xlabel("1 - specificity")
axes.set_ylabel("Sensitivity")
axes.set_xlim([0.0, 1.0])
axes.set_ylim([0.0, 1.0])
# keep the relevant spines of the axes visible
axes.spines["right"].set_visible(False)
axes.spines["top"].set_visible(False)
axes.spines["left"].set_position(("data", -0.015))
axes.spines["bottom"].set_position(("data", -0.015))
if title is not None:
axes.set_title(title)
axes.grid(linestyle="--", linewidth=1, color="gray", alpha=0.2)
plt.tight_layout()
lines = ["-", "--", "-.", ":"]
colors = [
"#1f77b4",
"#ff7f0e",
"#2ca02c",
"#d62728",
"#9467bd",
"#8c564b",
"#e377c2",
"#7f7f7f",
"#bcbd22",
"#17becf",
]
colorcycler = cycle(colors)
linecycler = cycle(lines)
legend = []
for name, value in data.items():
df = value["df"]
# plots roc curve
fpr, tpr, _ = r_curve(df["ground_truth"], df["likelihood"])
_auc = auc(fpr, tpr)
label = f"{name} (AUC={_auc:.2f})"
color = next(colorcycler)
style = next(linecycler)
(line,) = axes.plot(fpr, tpr, color=color, linestyle=style)
legend.append((line, label))
# only add a legend if at least one curve was actually plotted
if legend:
axes.legend(
[k[0] for k in legend],
[k[1] for k in legend],
loc="lower right",
fancybox=True,
framealpha=0.7,
)
return fig
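# Illustrative usage sketch (not part of the original module): ``roc_curve``
# accepts the same data layout as ``precision_recall_f1iso`` above, e.g.:
#
#     fig = roc_curve({"validation": {"df": df, "threshold": [0.5]}}, title="validation")
#     fig.savefig("roc.pdf")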
def relevance_analysis_plot(data, title=None):
"""Create an histogram plot to show the relative importance of features.
Parameters
----------
data : :py:class:`list`
The list of values (one for each feature)
title : :py:class:`str`, Optional
Optional title to add to this plot
Returns
-------
figure : matplotlib.figure.Figure
A matplotlib figure you can save or display (uses an ``agg`` backend)
"""
fig, axes = plt.subplots(1, 1, figsize=(6, 6))
# Names and bounds
axes.set_xlabel("Features")
axes.set_ylabel("Importance")
# keep the relevant spines of the axes visible
axes.spines["right"].set_visible(False)
axes.spines["top"].set_visible(False)
if title is not None:
axes.set_title(title)
# 818C2E = likely
# F2921D = could be
# 8C3503 = unlikely
labels = [
"Cardiomegaly",
"Emphysema",
"Pleural effusion",
"Hernia",
"Infiltration",
"Mass",
"Nodule",
"Atelectasis",
"Pneumothorax",
"Pleural thickening",
"Pneumonia",
"Fibrosis",
"Edema",
"Consolidation",
]
bars = axes.bar(labels, data, color="#8C3503")
bars[2].set_color("#818C2E")
bars[4].set_color("#818C2E")
bars[10].set_color("#818C2E")
bars[5].set_color("#F2921D")
bars[6].set_color("#F2921D")
bars[7].set_color("#F2921D")
bars[11].set_color("#F2921D")
bars[13].set_color("#F2921D")
for tick in axes.get_xticklabels():
tick.set_rotation(90)
fig.tight_layout()
return fig
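# Illustrative usage sketch (not part of the original module; the importance
# values are made up).  One value per feature is expected, in the same order as
# the hard-coded ``labels`` list above (14 entries):
#
#     importances = [0.02, 0.05, 0.11, 0.01, 0.07, 0.04, 0.03,
#                    0.06, 0.02, 0.05, 0.09, 0.03, 0.08, 0.04]
#     fig = relevance_analysis_plot(importances, title="Relative feature importance")
#     fig.savefig("relevance.pdf")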
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import tabulate
import torch
from sklearn.metrics import auc
from sklearn.metrics import precision_recall_curve as pr_curve
from sklearn.metrics import roc_curve as r_curve
from ..engine.evaluator import posneg
from ..utils.measure import base_measures, bayesian_measures
def performance_table(data, fmt):
"""Tables result comparison in a given format.
Parameters
----------
data : dict
A dictionary in which keys are strings defining plot labels and values
are dictionaries with two entries:
* ``df``: :py:class:`pandas.DataFrame`
A dataframe that is produced by our predictor engine containing
the following columns: ``filename``, ``likelihood``,
``ground_truth``.
* ``threshold``: :py:class:`list`
A threshold to compute measures.
fmt : str
One of the formats supported by tabulate.
Returns
-------
table : str
A table in a specific format
"""
headers = [
"Dataset",
"T",
"T Type",
"F1 (95% CI)",
"Prec (95% CI)",
"Recall/Sen (95% CI)",
"Spec (95% CI)",
"Acc (95% CI)",
"AUC (PRC)",
"AUC (ROC)",
]
table = []
for k, v in data.items():
entry = [
k,
v["threshold"],
v["threshold_type"],
]
df = v["df"]
gt = torch.tensor(df["ground_truth"].values)
pred = torch.tensor(df["likelihood"].values)
threshold = v["threshold"]
tp_tensor, fp_tensor, tn_tensor, fn_tensor = posneg(pred, gt, threshold)
# calc measures from scalars
tp_count = torch.sum(tp_tensor).item()
fp_count = torch.sum(fp_tensor).item()
tn_count = torch.sum(tn_tensor).item()
fn_count = torch.sum(fn_tensor).item()
base_m = base_measures(
tp_count,
fp_count,
tn_count,
fn_count,
)
bayes_m = bayesian_measures(
tp_count,
fp_count,
tn_count,
fn_count,
lambda_=1,
coverage=0.95,
)
# statistics based on the "assigned" threshold (a priori, less biased)
entry.append(
"{:.2f} ({:.2f}, {:.2f})".format(
base_m[5], bayes_m[5][2], bayes_m[5][3]
)
) # f1
entry.append(
"{:.2f} ({:.2f}, {:.2f})".format(
base_m[0], bayes_m[0][2], bayes_m[0][3]
)
) # precision
entry.append(
"{:.2f} ({:.2f}, {:.2f})".format(
base_m[1], bayes_m[1][2], bayes_m[1][3]
)
) # recall/sensitivity
entry.append(
"{:.2f} ({:.2f}, {:.2f})".format(
base_m[2], bayes_m[2][2], bayes_m[2][3]
)
) # specificity
entry.append(
"{:.2f} ({:.2f}, {:.2f})".format(
base_m[3], bayes_m[3][2], bayes_m[3][3]
)
) # accuracy
prec, recall, _ = pr_curve(gt, pred)
fpr, tpr, _ = r_curve(gt, pred)
entry.append(auc(recall, prec))
entry.append(auc(fpr, tpr))
table.append(entry)
return tabulate.tabulate(table, headers, tablefmt=fmt, floatfmt=".3f")
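# Illustrative usage sketch (not part of the original module; ``df`` stands for
# a pandas.DataFrame with the filename/likelihood/ground_truth columns described
# above and the threshold values are made up).  ``fmt`` accepts any table format
# known to tabulate, e.g. "rst", "github" or "latex":
#
#     table = performance_table(
#         {
#             "test": {
#                 "df": df,
#                 "threshold": 0.5,
#                 "threshold_type": "a priori",
#             },
#         },
#         fmt="rst",
#     )
#     print(table)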