diff --git a/src/ptbench/engine/evaluator.py b/src/ptbench/engine/evaluator.py index 9d14f79078dc255ac2dec738584675331c74111f..8df2b58067585604cf9ed2b4bb1497ce296eabd1 100644 --- a/src/ptbench/engine/evaluator.py +++ b/src/ptbench/engine/evaluator.py @@ -3,367 +3,516 @@ # SPDX-License-Identifier: GPL-3.0-or-later """Defines functionality for the evaluation of predictions.""" +import contextlib +import itertools import logging -import os -import re +import typing -from collections.abc import Iterable -from typing import Optional +from collections.abc import Iterable, Iterator -import matplotlib.pyplot as plt +import matplotlib.figure import numpy -import pandas as pd -import torch +import numpy.typing +import sklearn.metrics +import tabulate -from sklearn import metrics +from matplotlib import pyplot as plt -from ..utils.measure import base_measures, get_centered_maxf1 +from ..models.typing import BinaryPrediction logger = logging.getLogger(__name__) -def eer_threshold(neg: Iterable[float], pos: Iterable[float]) -> float: - """Evaluates the EER threshold from negative and positive scores. +def eer_threshold(predictions: Iterable[BinaryPrediction]) -> float: + """Calculates the (approximate) threshold leading to the equal error rate. Parameters ---------- + predictions + An iterable of multiple + :py:data:`ptbench.models.typing.BinaryPrediction`'s. - neg : - Negative scores - - pos : - Positive scores - - - Returns: + Returns + ------- The EER threshold value. """ from scipy.interpolate import interp1d from scipy.optimize import brentq - y_predictions = pd.concat((neg, pos)) - y_true = numpy.concatenate((numpy.zeros_like(neg), numpy.ones_like(pos))) + y_scores = [k[2] for k in predictions] + y_labels = [k[1] for k in predictions] - fpr, tpr, thresholds = metrics.roc_curve(y_true, y_predictions, pos_label=1) + fpr, tpr, thresholds = sklearn.metrics.roc_curve(y_labels, y_scores) eer = brentq(lambda x: 1.0 - x - interp1d(fpr, tpr)(x), 0.0, 1.0) return float(interp1d(fpr, thresholds)(eer)) -def posneg( - pred, gt, threshold -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """Calculates true and false positives and negatives. +def _get_centered_maxf1( + f1_scores: numpy.typing.NDArray, thresholds: numpy.typing.NDArray +): + """Return the centered max F1 score threshold when multiple thresholds give + the same max F1 score. Parameters ---------- + f1_scores + 1D array of f1 scores + thresholds + 1D array of thresholds - pred : - Pixel-wise predictions. + Returns + ------- + A tuple with the maximum F1-score and the "centered" threshold. + """ + maxf1 = f1_scores.max() + maxf1_indices = numpy.where(f1_scores == maxf1)[0] - gt : - Ground-truth (annotations). + # If multiple thresholds give the same max F1 score + if len(maxf1_indices) > 1: + mean_maxf1_index = int(round(numpy.mean(maxf1_indices))) + else: + mean_maxf1_index = maxf1_indices[0] + + return maxf1, thresholds[mean_maxf1_index] + + +def maxf1_threshold(predictions: Iterable[BinaryPrediction]) -> float: + """Calculates the threshold leading to the maximum F1-score on a precision- + recall curve. + + Parameters + ---------- + predictions + An iterable of multiple + :py:data:`ptbench.models.typing.BinaryPrediction`'s. - threshold : - A particular threshold in which to calculate the performance - measures. Returns ------- + The threshold value leading to the maximum F1-score on the provided set + of predictions. 
+ """ + y_scores = [k[2] for k in predictions] + y_labels = [k[1] for k in predictions] - tp_tensor: - The true positive values. + precision, recall, thresholds = sklearn.metrics.precision_recall_curve( + y_labels, y_scores + ) - fp_tensor: - The false positive values. + numerator = 2 * recall * precision + denom = recall + precision + f1_scores = numpy.divide( + numerator, denom, out=numpy.zeros_like(denom), where=(denom != 0) + ) - tn_tensor: - The true negative values. + _, maxf1_threshold = _get_centered_maxf1(f1_scores, thresholds) + return maxf1_threshold - fn_tensor: - The false negative values. - """ - # threshold - binary_pred = torch.gt(pred, threshold) +def _score_plot( + labels: numpy.typing.NDArray, + scores: numpy.typing.NDArray, + title: str, + threshold: float, +) -> matplotlib.figure.Figure: + """Plots the normalized score distributions for all systems. - # equals and not-equals - equals = torch.eq(binary_pred, gt).type(torch.uint8) - notequals = torch.ne(binary_pred, gt).type(torch.uint8) + Parameters + ---------- + labels + True labels (negatives and positives) for each entry in ``scores`` + scores + Likelihoods provided by the classification model, for each sample + title + Title of the plot. + threshold + Shows where the threshold is in the figure - # true positives - tp_tensor = (gt * binary_pred).type(torch.uint8) - # false positives - fp_tensor = torch.eq((binary_pred + tp_tensor), 1).type(torch.uint8) + Returns + ------- + A single (matplotlib) plot containing the score distribution, ready to + be saved to disk or displayed. + """ + + fig, ax = plt.subplots(1, 1) + assert isinstance(fig, matplotlib.figure.Figure) + ax = typing.cast(plt.Axes, ax) # gets editor to behave + + # Here, we configure the "style" of our plot + ax.set_xlim([0, 1]) + ax.set_title(title) + ax.set_xlabel("Score") + ax.set_ylabel("Normalized count") + ax.grid(linestyle="--", linewidth=1, color="gray", alpha=0.2) + + # Only show ticks on the left and bottom spines + ax.spines.right.set_visible(False) + ax.spines.top.set_visible(False) + ax.get_xaxis().tick_bottom() + ax.get_yaxis().tick_left() + + positives = scores[labels > 0.5] + negatives = scores[labels < 0.5] + + ax.hist(positives, bins="auto", label="positives", density=True, alpha=0.7) + ax.hist(negatives, bins="auto", label="negatives", density=True, alpha=0.7) + + # Adds threshold line (dotted red) + ax.axvline( + threshold, # type: ignore + color="red", + lw=2, + alpha=0.75, + ls="dotted", + label="threshold", + ) - # true negatives - tn_tensor = (equals - tp_tensor).type(torch.uint8) + # Adds a nice legend + ax.legend( + title="Max F1-scores", + fancybox=True, + framealpha=0.7, + ) - # false negatives - fn_tensor = notequals - fp_tensor.type(torch.uint8) + # Makes sure the figure occupies most of the possible space + fig.tight_layout() - return tp_tensor, fp_tensor, tn_tensor, fn_tensor + return fig -def sample_measures_for_threshold( - pred: torch.Tensor, gt: torch.Tensor, threshold: float -) -> tuple[float, float, float, float, float]: - """Calculates measures on one single sample, for a specific threshold. +def run_binary( + name: str, + predictions: Iterable[BinaryPrediction], + threshold_a_priori: float | None = None, +) -> tuple[ + dict[str, typing.Any], + dict[str, matplotlib.figure.Figure], + dict[str, typing.Any], +]: + """Runs inference and calculates measures for binary classification. Parameters ---------- + name + The name of subset to load. 
+ predictions + A list of predictions to consider for measurement + threshold_a_priori + A threshold to use, evaluated *a priori*, if must report single values. + If this value is not provided, a *a posteriori* threshold is calculated + on the input scores. This is a biased estimator. - pred : - Pixel-wise predictions. - - gt : - Ground-truth (annotations). - - threshold : - A particular threshold in which to calculate the performance - measures. Returns ------- + A tuple containing the following entries: - precision : float - P, AKA positive predictive value (PPV). It corresponds arithmetically - to ``tp/(tp+fp)``. In the case ``tp+fp == 0``, this function returns - zero for precision. - - recall : float - R, AKA sensitivity, hit rate, or true positive rate (TPR). It - corresponds arithmetically to ``tp/(tp+fn)``. In the special case - where ``tp+fn == 0``, this function returns zero for recall. - - specificity : float - S, AKA selectivity or true negative rate (TNR). It - corresponds arithmetically to ``tn/(tn+fp)``. In the special case - where ``tn+fp == 0``, this function returns zero for specificity. - - accuracy : float - A, see `Accuracy - <https://en.wikipedia.org/wiki/Evaluation_of_binary_classifiers>`_. is - the proportion of correct predictions (both true positives and true - negatives) among the total number of pixels examined. It corresponds - arithmetically to ``(tp+tn)/(tp+tn+fp+fn)``. This measure includes - both true-negatives and positives in the numerator, what makes it - sensitive to data or regions without annotations. - - jaccard : float - J, see `Jaccard Index or Similarity - <https://en.wikipedia.org/wiki/Jaccard_index>`_. It corresponds - arithmetically to ``tp/(tp+fp+fn)``. In the special case where - ``tn+fp+fn == 0``, this function returns zero for the Jaccard index. - The Jaccard index depends on a TP-only numerator, similarly to the F1 - score. For regions where there are no annotations, the Jaccard index - will always be zero, irrespective of the model output. Accuracy may be - a better proxy if one needs to consider the true abscence of - annotations in a region as part of the measure. - - f1_score : float - F1, see `F1-score <https://en.wikipedia.org/wiki/F1_score>`_. It - corresponds arithmetically to ``2*P*R/(P+R)`` or ``2*tp/(2*tp+fp+fn)``. - In the special case where ``P+R == (2*tp+fp+fn) == 0``, this function - returns zero for the Jaccard index. The F1 or Dice score depends on a - TP-only numerator, similarly to the Jaccard index. For regions where - there are no annotations, the F1-score will always be zero, - irrespective of the model output. Accuracy may be a better proxy if - one needs to consider the true abscence of annotations in a region as - part of the measure. + * summary: A dictionary containing the performance summary on the + specified threshold + * figures: A dictionary of generated standalone figures + * curves: A dictionary containing curves that can potentially be combined + with other prediction lists to make aggregate plots. """ - tp_tensor, fp_tensor, tn_tensor, fn_tensor = posneg(pred, gt, threshold) + y_scores = numpy.array([k[2] for k in predictions]) # likelihoods + y_labels = numpy.array([k[1] for k in predictions]) # integers + neg_label = y_labels.min() + pos_label = y_labels.max() + + use_threshold = threshold_a_priori + if use_threshold is None: + use_threshold = maxf1_threshold(predictions) + logger.warning( + f"User did not pass an *a priori* threshold for the evaluation " + f"of split `{name}`. 
Using threshold a posteriori (biased) with value " + f"`{use_threshold:.4f}`" + ) - # calc measures from scalars - tp_count = torch.sum(tp_tensor).item() - fp_count = torch.sum(fp_tensor).item() - tn_count = torch.sum(tn_tensor).item() - fn_count = torch.sum(fn_tensor).item() - return base_measures(tp_count, fp_count, tn_count, fn_count) + y_predictions = numpy.where(y_scores >= use_threshold, pos_label, neg_label) + # point measures on threshold + summary = dict( + split=name, + threshold=use_threshold, + threshold_a_posteriori=(threshold_a_priori is None), + precision=sklearn.metrics.precision_score( + y_labels, y_predictions, pos_label=pos_label + ), + recall=sklearn.metrics.recall_score( + y_labels, y_predictions, pos_label=pos_label + ), + specificity=sklearn.metrics.recall_score( + y_labels, y_predictions, pos_label=neg_label + ), + accuracy=sklearn.metrics.accuracy_score(y_labels, y_predictions), + f1_score=sklearn.metrics.f1_score( + y_labels, y_predictions, pos_label=pos_label + ), + ) -def run( - name: str, - predictions_folder: str, - f1_thresh: Optional[float] = None, - eer_thresh: Optional[float] = None, - steps: Optional[int] = 1000, -): - """Runs inference and calculates measures. + # figures: score distributions + figures = dict( + scores=_score_plot( + y_labels, + y_scores, + f"Score distribution (split: {name})", + use_threshold, + ), + ) - Parameters - --------- + # curves: ROC and precision recall + curves = dict( + roc=sklearn.metrics.roc_curve(y_labels, y_scores, pos_label=pos_label), + precision_recall=sklearn.metrics.precision_recall_curve( + y_labels, y_scores, pos_label=pos_label + ), + ) - name: - The name of subset to load. + return summary, figures, curves - predictions_folder: - Folder where predictions for the dataset images has been previously - stored. - f1_thresh: - This number should come from - the training set or a separate validation set. Using a test set value - may bias your analysis. This number is also used to print the a priori - F1-score on the evaluated set. +def aggregate_summaries( + data: typing.Sequence[typing.Mapping[str, typing.Any]], fmt: str +) -> str: + """Tabulates summaries from multiple splits. - eer_thresh: - This number should come from - the training set or a separate validation set. Using a test set value - may bias your analysis. This number is used to print the a priori - EER. + This function can properly :py:mod:`tabulate` the various summaries + produced for all the splits in a prediction database. - steps: - number of threshold steps to consider when evaluating thresholds. + + Parameters + ---------- + data + An iterable over all summary data collected + fmt + One of the formats supported by :py:mod:`tabulate`. Returns ------- + A string containing the tabulated information + """ + + headers = list(data[0].keys()) + table = [[k[h] for h in headers] for k in data] + return tabulate.tabulate(table, headers, tablefmt=fmt, floatfmt=".3f") + - pred_data: - The loaded predictions for the specified subset. +def aggregate_roc( + data: typing.Mapping[str, typing.Any], + title: str = "ROC", +) -> matplotlib.figure.Figure: + """Aggregates ROC curves from multiple splits. - fig_scores: - Figure of the histogram distributions of true-positive/true-negative scores. + This function produces a single ROC plot for multiple curves generated per + split. - maxf1_threshold: - Threshold to achieve the highest possible F1-score for this dataset. - post_eer_threshold: - Threshold achieving Equal Error Rate for this dataset. 
+ Parameters + ---------- + data + A dictionary mapping split names to ROC curve data produced by + :py:func:sklearn.metrics.roc_curve`. + + + Returns + ------- + A figure, containing the aggregated ROC plot. """ - predictions_path = os.path.join(predictions_folder, f"{name}.csv") - - if not os.path.exists(predictions_path): - predictions_path = predictions_folder - - # Load predictions - pred_data = pd.read_csv(predictions_path) - pred = torch.Tensor( - [ - eval(re.sub(" +", " ", x.replace("\n", "")).replace(" ", ",")) - for x in pred_data["likelihood"].values - ] - ).double() - gt = torch.Tensor( - [ - eval(re.sub(" +", " ", x.replace("\n", "")).replace(" ", ",")) - for x in pred_data["ground_truth"].values - ] - ).double() - - if pred.shape[1] == 1 and gt.shape[1] == 1: - pred = torch.flatten(pred) - gt = torch.flatten(gt) - - pred_data["likelihood"] = pred - pred_data["ground_truth"] = gt - - # Multiclass f1 score computation - if pred.ndim > 1: - auc = metrics.roc_auc_score(gt, pred) - logger.info("Evaluating multiclass classification") - logger.info(f"AUC: {auc}") - logger.info("F1 and EER are not implemented for multiclass") - - return None, None - - # Generate measures for each threshold - step_size = 1.0 / steps - data = [ - (index, threshold) + sample_measures_for_threshold(pred, gt, threshold) - for index, threshold in enumerate(numpy.arange(0.0, 1.0, step_size)) + fig, ax = plt.subplots(1, 1) + assert isinstance(fig, matplotlib.figure.Figure) + + # Names and bounds + ax.set_xlabel("1 - specificity") + ax.set_ylabel("Sensitivity") + ax.set_xlim([0.0, 1.0]) + ax.set_ylim([0.0, 1.0]) + ax.set_title(title) + + # we should see some of ax 1 ax + ax.spines["right"].set_visible(False) + ax.spines["top"].set_visible(False) + ax.spines["left"].set_position(("data", -0.015)) + ax.spines["bottom"].set_position(("data", -0.015)) + + ax.grid(linestyle="--", linewidth=1, color="gray", alpha=0.2) + + plt.tight_layout() + + lines = ["-", "--", "-.", ":"] + colors = [ + "#1f77b4", + "#ff7f0e", + "#2ca02c", + "#d62728", + "#9467bd", + "#8c564b", + "#e377c2", + "#7f7f7f", + "#bcbd22", + "#17becf", ] + colorcycler = itertools.cycle(colors) + linecycler = itertools.cycle(lines) + + legend = [] + + for name, (fpr, tpr, _) in data.items(): + # plots roc curve + _auc = sklearn.metrics.auc(fpr, tpr) + label = f"{name} (AUC={_auc:.2f})" + color = next(colorcycler) + style = next(linecycler) + + (line,) = ax.plot(fpr, tpr, color=color, linestyle=style) + legend.append((line, label)) + + if len(legend) > 1: + ax.legend( + [k[0] for k in legend], + [k[1] for k in legend], + loc="lower right", + fancybox=True, + framealpha=0.7, + ) - data_df = pd.DataFrame( - data, - columns=( - "index", - "threshold", - "precision", - "recall", - "specificity", - "accuracy", - "jaccard", - "f1_score", - ), - ) - data_df = data_df.set_index("index") - """# Save evaluation csv - if output_folder is not None: - fullpath = os.path.join(output_folder, f"{name}.csv") - logger.info(f"Saving {fullpath}...") - os.makedirs(os.path.dirname(fullpath), exist_ok=True) - data_df.to_csv(fullpath)""" - - # Find max F1 score - f1_scores = numpy.asarray(data_df["f1_score"]) - thresholds = numpy.asarray(data_df["threshold"]) - - maxf1, maxf1_threshold = get_centered_maxf1(f1_scores, thresholds) - - logger.info( - f"Maximum F1-score of {maxf1:.5f}, achieved at " - f"threshold {maxf1_threshold:.3f} (chosen *a posteriori*)" - ) + return fig - # Find EER - neg_gt = pred_data.loc[pred_data.loc[:, "ground_truth"] == 0, :] - pos_gt = 
pred_data.loc[pred_data.loc[:, "ground_truth"] == 1, :] - post_eer_threshold = eer_threshold( - neg_gt["likelihood"], pos_gt["likelihood"] - ) - logger.info( - f"Equal error rate achieved at " - f"threshold {post_eer_threshold:.3f} (chosen *a posteriori*)" - ) +@contextlib.contextmanager +def _precision_recall_canvas() -> ( + Iterator[tuple[matplotlib.figure.Figure, matplotlib.figure.Axes]] +): + """Generates a canvas to draw precision-recall curves. - # Generate scores fig - fig_score, axes = plt.subplots(1) - fig_score.tight_layout(pad=3.0) + Works like a context manager, yielding a figure and an axes set in which + the precision-recall curves should be added to. The figure already + contains F1-ISO lines and is preset to a 0-1 square region. Once the + context is finished, ``fig.tight_layout()`` is called. - # Names and bounds - axes.set_xlabel("Score") - axes.set_ylabel("Normalized counts") - axes.set_xlim(0.0, 1.0) - neg_weights = numpy.ones_like(neg_gt["likelihood"]) / len( - pred_data["likelihood"] - ) - pos_weights = numpy.ones_like(pos_gt["likelihood"]) / len( - pred_data["likelihood"] - ) + Yields + ------ + figure + The figure that should be finally returned to the user + axes + An axis set where to precision-recall plots should be added to + """ - axes.hist( - [neg_gt["likelihood"], pos_gt["likelihood"]], - weights=[neg_weights, pos_weights], - bins=100, - color=["tab:blue", "tab:orange"], - label=["Negatives", "Positives"], - ) - axes.legend(prop={"size": 10}, loc="upper center") - axes.set_title(f"Score table for {name} subset") + fig, axes1 = plt.subplots(1) + assert isinstance(fig, matplotlib.figure.Figure) + assert isinstance(axes1, matplotlib.figure.Axes) + + # Names and bounds + axes1.set_xlabel("Recall") + axes1.set_ylabel("Precision") + axes1.set_xlim([0.0, 1.0]) + axes1.set_ylim([0.0, 1.0]) + + axes1.grid(linestyle="--", linewidth=1, color="gray", alpha=0.2) + axes2 = axes1.twinx() + + # Annotates plot with F1-score iso-lines + f_scores = numpy.linspace(0.1, 0.9, num=9) + tick_locs = [] + tick_labels = [] + for f_score in f_scores: + x = numpy.linspace(0.01, 1) + y = f_score * x / (2 * x - f_score) + plt.plot(x[y >= 0], y[y >= 0], color="green", alpha=0.1) + tick_locs.append(y[-1]) + tick_labels.append("%.1f" % f_score) + axes2.tick_params(axis="y", which="both", pad=0, right=False, left=False) + axes2.set_ylabel("iso-F", color="green", alpha=0.3) + axes2.set_ylim([0.0, 1.0]) + axes2.yaxis.set_label_coords(1.015, 0.97) + axes2.set_yticks(tick_locs) # notice these are invisible + for k in axes2.set_yticklabels(tick_labels): + k.set_color("green") + k.set_alpha(0.3) + k.set_size(8) # we should see some of axes 1 axes - axes.spines["right"].set_visible(False) - axes.spines["top"].set_visible(False) - axes.spines["left"].set_position(("data", -0.015)) - """if f1_thresh is not None and eer_thresh is not None: - # get the closest possible threshold we have - index = int(round(steps * f1_thresh)) - f1_a_priori = data_df["f1_score"][index] - actual_threshold = data_df["threshold"][index] - - logger.info( - f"F1-score of {f1_a_priori:.5f}, at threshold " - f"{actual_threshold:.3f} (chosen *a priori*)" - ) + axes1.spines["right"].set_visible(False) + axes1.spines["top"].set_visible(False) + axes1.spines["left"].set_position(("data", -0.015)) + axes1.spines["bottom"].set_position(("data", -0.015)) + + # we shouldn't see any of axes 2 axes + axes2.spines["right"].set_visible(False) + axes2.spines["top"].set_visible(False) + axes2.spines["left"].set_visible(False) + 
axes2.spines["bottom"].set_visible(False) + + # yield execution, lets user draw precision-recall plots, and the legend + # before tighteneing the layout + yield fig, axes1 + + plt.tight_layout() + - # Print the a priori EER threshold - logger.info(f"Equal error rate (chosen *a priori*) {eer_thresh:.3f}")""" +def aggregate_pr( + data: typing.Mapping[str, typing.Any], + title: str = "Precision-Recall Curve", +) -> matplotlib.figure.Figure: + """Aggregates PR curves from multiple splits. - return pred_data, fig_score, maxf1_threshold, post_eer_threshold + This function produces a single Precision-Recall plot for multiple curves + generated per split. The plot will be annotated with F1-score iso-lines (in + which the F1-score maintains the same value). + + + Parameters + ---------- + data + A dictionary mapping split names to ROC curve data produced by + :py:func:sklearn.metrics.precision_recall_curve`. + + + Returns + ------- + A figure, containing the aggregated PR plot. + """ + + lines = ["-", "--", "-.", ":"] + colors = [ + "#1f77b4", + "#ff7f0e", + "#2ca02c", + "#d62728", + "#9467bd", + "#8c564b", + "#e377c2", + "#7f7f7f", + "#bcbd22", + "#17becf", + ] + colorcycler = itertools.cycle(colors) + linecycler = itertools.cycle(lines) + + with _precision_recall_canvas() as (fig, axes): + axes.set_title(title) + legend = [] + + for name, (prec, recall, _) in data.items(): + _auc = sklearn.metrics.auc(recall, prec) + label = f"{name} (AUC={_auc:.2f})" + color = next(colorcycler) + style = next(linecycler) + + (line,) = axes.plot(recall, prec, color=color, linestyle=style) + legend.append((line, label)) + + if len(legend) > 1: + axes.legend( + [k[0] for k in legend], + [k[1] for k in legend], + loc="lower left", + fancybox=True, + framealpha=0.7, + ) + + return fig diff --git a/src/ptbench/engine/predictor.py b/src/ptbench/engine/predictor.py index 49e5f5acdbd1866e3cd10b323720d2eda3f2d0c2..9b6c62a177fe3b187ddc947d8b6dd430ef066be5 100644 --- a/src/ptbench/engine/predictor.py +++ b/src/ptbench/engine/predictor.py @@ -7,7 +7,12 @@ import logging import lightning.pytorch import torch.utils.data -from ..models.typing import Prediction +from ..models.typing import ( + BinaryPrediction, + BinaryPredictionSplit, + MultiClassPrediction, + MultiClassPredictionSplit, +) from .device import DeviceManager logger = logging.getLogger(__name__) @@ -18,9 +23,12 @@ def run( datamodule: lightning.pytorch.LightningDataModule, device_manager: DeviceManager, ) -> ( - list[Prediction] - | list[list[Prediction]] - | dict[str, list[Prediction]] + list[BinaryPrediction] + | list[MultiClassPrediction] + | list[list[BinaryPrediction]] + | list[list[MultiClassPrediction]] + | BinaryPredictionSplit + | MultiClassPredictionSplit | None ): """Runs inference on input data, outputs csv files with predictions. 
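For orientation, here is a minimal usage sketch (not part of the patch) of the refactored evaluation API introduced above. It assumes a hypothetical `predictions.json` file mapping split names (e.g. ``validation``, ``test``) to lists of ``(name, label, score)`` entries, which is the layout the reworked ``ptbench evaluate`` command further below consumes; the file name and split names are illustrative assumptions only.

# Illustrative sketch, not part of this patch: exercising the new evaluator API
# on a hypothetical predictions file. Paths and split names are assumptions.
import json

from matplotlib.backends.backend_pdf import PdfPages

from ptbench.engine.evaluator import (
    aggregate_pr,
    aggregate_roc,
    aggregate_summaries,
    maxf1_threshold,
    run_binary,
)

with open("predictions.json") as f:
    # maps split names to lists of (sample-name, label, score) entries
    predict_data = json.load(f)

# tune the operating threshold on a validation split, a priori w.r.t. the test set
threshold = maxf1_threshold(predict_data["validation"])

# run_binary() returns a (summary, figures, curves) tuple per split
results = {
    split: run_binary(split, preds, threshold_a_priori=threshold)
    for split, preds in predict_data.items()
}

# summaries are plain dictionaries and can be tabulated together
print(aggregate_summaries([v[0] for v in results.values()], fmt="rst"))

# per-split curves can be combined into single ROC / precision-recall figures
with PdfPages("plots.pdf") as pdf:
    pdf.savefig(aggregate_roc({k: v[2]["roc"] for k, v in results.items()}))
    pdf.savefig(
        aggregate_pr({k: v[2]["precision_recall"] for k, v in results.items()})
    )

This mirrors the flow of the rewritten ``ptbench evaluate`` script in this patch, which performs the same JSON loading, threshold selection, summary tabulation, and PDF aggregation.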
diff --git a/src/ptbench/models/alexnet.py b/src/ptbench/models/alexnet.py index a664c4c6021768319e82fc0a7ff2a1e66cb8a161..ab66d1aa3ab183e29017e32d8957d0320381acda 100644 --- a/src/ptbench/models/alexnet.py +++ b/src/ptbench/models/alexnet.py @@ -205,7 +205,7 @@ class Alexnet(pl.LightningModule): return self._validation_loss(outputs, labels.float()) - def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False): + def predict_step(self, batch, batch_idx, dataloader_idx=0): outputs = self(batch[0]) probabilities = torch.sigmoid(outputs) return separate((probabilities, batch[1])) diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py index c416bb0298bd3782e864ce1b0bc26696d3dc68ea..c7def1b5b86a4a4e9e9ab474c7aa93140e0707b7 100644 --- a/src/ptbench/models/densenet.py +++ b/src/ptbench/models/densenet.py @@ -199,7 +199,7 @@ class Densenet(pl.LightningModule): return self._validation_loss(outputs, labels.float()) - def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False): + def predict_step(self, batch, batch_idx, dataloader_idx=0): outputs = self(batch[0]) probabilities = torch.sigmoid(outputs) return separate((probabilities, batch[1])) diff --git a/src/ptbench/models/logistic_regression.py b/src/ptbench/models/logistic_regression.py index 4e0338f4e0b4ceb6faeef11bedba73a9b2b0204b..6a88d9675e09d3d6f0fb1e0a4d221ac948c0ae26 100644 --- a/src/ptbench/models/logistic_regression.py +++ b/src/ptbench/models/logistic_regression.py @@ -107,7 +107,7 @@ class LogisticRegression(pl.LightningModule): else: return {f"extra_validation_loss_{dataloader_idx}": validation_loss} - def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False): + def predict_step(self, batch, batch_idx, dataloader_idx=0): outputs = self(batch[0]) probabilities = torch.sigmoid(outputs) return separate((probabilities, batch[1])) diff --git a/src/ptbench/models/mlp.py b/src/ptbench/models/mlp.py index 102b384985c1814288e17265e3265015f500aacd..ac59ad6f6302355c352995cff9b4ee7208ed46e6 100644 --- a/src/ptbench/models/mlp.py +++ b/src/ptbench/models/mlp.py @@ -112,7 +112,7 @@ class MultiLayerPerceptron(pl.LightningModule): else: return {f"extra_validation_loss_{dataloader_idx}": validation_loss} - def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False): + def predict_step(self, batch, batch_idx, dataloader_idx=0): outputs = self(batch[0]) probabilities = torch.sigmoid(outputs) return separate((probabilities, batch[1])) diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py index c89a6d328a229dd7839956e2105e42541c1df0ea..112d7ef6cebacbb514725fe2946c22cd0c452c92 100644 --- a/src/ptbench/models/pasa.py +++ b/src/ptbench/models/pasa.py @@ -265,7 +265,7 @@ class Pasa(pl.LightningModule): return self._validation_loss(outputs, labels.float()) - def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False): + def predict_step(self, batch, batch_idx, dataloader_idx=0): outputs = self(batch[0]) probabilities = torch.sigmoid(outputs) return separate((probabilities, batch[1])) diff --git a/src/ptbench/models/separate.py b/src/ptbench/models/separate.py index febb0728e3db919b119ca76402dccc516ba41306..9568721169a58a69c1f54f57bcca05398cb02e3c 100644 --- a/src/ptbench/models/separate.py +++ b/src/ptbench/models/separate.py @@ -8,10 +8,12 @@ import typing import torch from ..data.typing import Sample -from .typing import Prediction +from .typing import BinaryPrediction, MultiClassPrediction -def _as_predictions(samples: 
typing.Iterable[Sample]) -> list[Prediction]: +def _as_predictions( + samples: typing.Iterable[Sample], +) -> list[BinaryPrediction | MultiClassPrediction]: """Takes a list of separated batch predictions and transform into a list of formal predictions. @@ -28,7 +30,7 @@ def _as_predictions(samples: typing.Iterable[Sample]) -> list[Prediction]: return [(v[1]["name"], v[1]["label"].item(), v[0].item()) for v in samples] -def separate(batch: Sample) -> list[Prediction]: +def separate(batch: Sample) -> list[BinaryPrediction | MultiClassPrediction]: """Separates a collated batch reconstituting its samples. This function implements the inverse of diff --git a/src/ptbench/models/typing.py b/src/ptbench/models/typing.py index 3294cafdddc0d2415b787f8c1eb494ec31bdfdd6..6290315724d5e5fcd8f3a1d435771473838220e3 100644 --- a/src/ptbench/models/typing.py +++ b/src/ptbench/models/typing.py @@ -5,17 +5,23 @@ import typing -Checkpoint: typing.TypeAlias = typing.Mapping[str, typing.Any] +Checkpoint: typing.TypeAlias = typing.MutableMapping[str, typing.Any] """Definition of a lightning checkpoint.""" +BinaryPrediction: typing.TypeAlias = tuple[str, int, float] +"""Prediction: the sample name, the target, and the predicted value.""" -Prediction: typing.TypeAlias = tuple[ - str, int | typing.Sequence[int], float | typing.Sequence[float] +MultiClassPrediction: typing.TypeAlias = tuple[ + str, typing.Sequence[int], typing.Sequence[float] ] """Prediction: the sample name, the target, and the predicted value.""" +BinaryPredictionSplit: typing.TypeAlias = typing.Mapping[ + str, typing.Sequence[BinaryPrediction] +] +"""A series of predictions for different database splits.""" -PredictionSplit: typing.TypeAlias = typing.Mapping[ - str, typing.Sequence[Prediction] +MultiClassPredictionSplit: typing.TypeAlias = typing.Mapping[ + str, typing.Sequence[MultiClassPrediction] ] """A series of predictions for different database splits.""" diff --git a/src/ptbench/scripts/evaluate.py b/src/ptbench/scripts/evaluate.py index 7f3df261be9da8d8876709075b1d3cb52d8eb4f3..19afc037c4dfaa11a0f5fa32628cb08602b52b49 100644 --- a/src/ptbench/scripts/evaluate.py +++ b/src/ptbench/scripts/evaluate.py @@ -2,253 +2,168 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -import os - -from collections import defaultdict +import pathlib import click -from clapper.click import ConfigCommand, ResourceOption, verbosity_option +from clapper.click import ResourceOption, verbosity_option from clapper.logging import setup -from matplotlib.backends.backend_pdf import PdfPages -from ..data.datamodule import ConcatDataModule -from ..data.typing import DataLoader -from ..utils.plot import precision_recall_f1iso, roc_curve -from ..utils.table import performance_table +from .click import ConfigCommand logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") -def _validate_threshold( - threshold: int | float | str, dataloader_dict: dict[str, DataLoader] -): - """Validates the user threshold selection. - - Parameters - ---------- - - threshold: - This number is used to define positives and negatives from - probability maps, and report F1-scores (a priori). It - should either come from the training set or a separate validation set - to avoid biasing the analysis. Optionally, if you provide a multi-set - dataset as input, this may also be the name of an existing set from - which the threshold will be estimated (highest F1-score) and then - applied to the subsequent sets. 
This number is also used to print - the test set F1-score a priori performance - - dataloader_dict: - Dictionary of set_name: dataloader, there set_name is the name of a dataset split - and dataloader is the torch dataloader for that split. - - Returns - ------- - - The parsed threshold. - """ - if threshold is None: - return 0.5 - - try: - # we try to convert it to float first - threshold = float(threshold) - if threshold < 0.0 or threshold > 1.0: - raise ValueError("Float thresholds must be within range [0.0, 1.0]") - except ValueError: - # it is a bit of text - assert dataset with name is available - if not isinstance(dataloader_dict, dict): - raise ValueError( - "Threshold should be a floating-point number " - "if your provide only a single dataset for evaluation" - ) - if threshold not in dataloader_dict: - raise ValueError( - f"Text thresholds should match dataset names, " - f"but {threshold} is not available among the datasets provided (" - f"({', '.join(dataloader_dict.keys())})" - ) - - return threshold - - @click.command( entry_point_group="ptbench.config", cls=ConfigCommand, epilog="""Examples: -\b - 1. Runs evaluation on an existing dataset configuration: +1. Runs evaluation on an existing prediction output: + + .. code:: sh + + ptbench evaluate -vv --predictions=path/to/predictions.json --output-folder=path/to/results + +2. Runs evaluation on an existing prediction output, tune threshold a priori on the `validation` set: - .. code:: sh + .. code:: sh - ptbench evaluate -vv montgomery --predictions-folder=path/to/predictions --output-folder=path/to/results + ptbench evaluate -vv --predictions=path/to/predictions.json --output-folder=path/to/results --threshold=validation """, ) @click.option( - "--output-folder", - "-o", - help="Path where to store the analysis result (created if does not exist)", - required=True, - default="results", - type=click.Path(), - cls=ResourceOption, -) -@click.option( - "--predictions-folder", + "--predictions", "-p", help="Path where predictions are currently stored", required=True, - type=click.Path(exists=True, file_okay=False, dir_okay=True), + type=click.Path( + file_okay=True, dir_okay=False, writable=True, path_type=pathlib.Path + ), cls=ResourceOption, ) @click.option( - "--datamodule", - "-d", - help="A lighting data module containing the training and validation sets.", + "--output-folder", + "-o", + help="Path where to store the analysis result (created if does not exist)", required=True, + default="results", + type=click.Path(file_okay=False, dir_okay=True, path_type=pathlib.Path), cls=ResourceOption, ) @click.option( "--threshold", "-t", - help="This number is used to define positives and negatives from " - "probability maps, and report F1-scores (a priori). It " - "should either come from the training set or a separate validation set " - "to avoid biasing the analysis. Optionally, if you provide a multi-set " - "dataset as input, this may also be the name of an existing set from " - "which the threshold will be estimated (highest F1-score) and then " - "applied to the subsequent sets. This number is also used to print " - "the test set F1-score a priori performance", - default=None, + help="""This value is used to define positives and negatives from + probability outputs in predictions, and report performance measures on + **binary** classification tasks. It should either come from the training + set or a separate validation set to avoid biasing the analysis. 
+ Optionally, if you provide a multi-split set of predictions as input, this + may also be the name of an existing split (e.g. ``validation``) from which + the threshold will be estimated (by calculating the threshold leading to + the highest F1-score on that set) and then applied to the subsequent + sets. This value is not used for multi-class classification tasks.""", + default=0.5, show_default=False, required=True, - cls=ResourceOption, -) -@click.option( - "--steps", - "-S", - help="This number is used to define the number of threshold steps to " - "consider when evaluating the highest possible F1-score on test data.", - default=1000, - show_default=True, - required=True, + type=click.STRING, cls=ResourceOption, ) @verbosity_option(logger=logger, cls=ResourceOption, expose_value=False) def evaluate( - output_folder: str, - predictions_folder: str, - datamodule: ConcatDataModule, - threshold: int | float | str, - steps: int, - **_, + predictions: pathlib.Path, + output_folder: pathlib.Path, + threshold: str | float, + **_, # ignored ) -> None: - """Evaluates a CNN on a tuberculosis prediction task. - - Note: batch size of 1 is required on the predictions. - """ + """Evaluates predictions (from a model) on a binary classification task.""" - from ..engine.evaluator import run + import json + import typing - datamodule.set_chunk_size(1, 1) - datamodule.model_transforms = [] + import matplotlib.figure - datamodule.prepare_data() - datamodule.setup(stage="predict") + from matplotlib.backends.backend_pdf import PdfPages - dataloader = datamodule.predict_dataloader() + from ..engine.evaluator import ( + aggregate_pr, + aggregate_roc, + aggregate_summaries, + run_binary, + ) - threshold = _validate_threshold(threshold, dataloader) + with predictions.open("r") as f: + predict_data = json.load(f) - if isinstance(threshold, str): + if threshold in predict_data: + # it is the name of a split # first run evaluation for reference dataset - logger.info(f"Evaluating threshold on '{threshold}' set") - _, _, f1_threshold, eer_threshold = run( - name=threshold, - predictions_folder=predictions_folder, - steps=steps, - ) + from ..engine.evaluator import maxf1_threshold - if (f1_threshold is not None) and (eer_threshold is not None): - logger.info(f"Set --f1_threshold={f1_threshold:.5f}") - logger.info(f"Set --eer_threshold={eer_threshold:.5f}") + use_threshold = maxf1_threshold(predict_data[threshold]) + logger.info(f"Setting --threshold={use_threshold:.5f}") - elif isinstance(threshold, float): - f1_threshold = threshold - eer_threshold = f1_threshold else: - raise ValueError("Threshold value is neither a str or a float") - - results_dict = { # type: ignore - "pred_data": defaultdict(dict), - "fig_score": defaultdict(dict), - "maxf1_threshold": defaultdict(dict), - "post_eer_threshold": defaultdict(dict), - } - - for k in dataloader.keys(): - if k.startswith("_"): - logger.info(f"Skipping dataset '{k}' (not to be evaluated)") - continue - logger.info(f"Analyzing '{k}' set...") - pred_data, fig_score, maxf1_threshold, post_eer_threshold = run( - k, - predictions_folder, - f1_thresh=f1_threshold, - eer_thresh=eer_threshold, - steps=steps, + # we try to convert it to float and complain if that is not possible + try: + use_threshold = float(threshold) + except ValueError: + raise click.BadParameter( + f"""The value of --threshold=`{threshold}` does not match one + of the database split names ({', '.join(predict_data.keys())}) + or can be converted to float. 
Check your input.""" + ) + + results: dict[ + str, + tuple[ + dict[str, typing.Any], + dict[str, matplotlib.figure.Figure], + dict[str, typing.Any], + ], + ] = dict() + for k, v in predict_data.items(): + logger.info(f"Analyzing split `{k}`...") + results[k] = run_binary( + name=k, + predictions=v, + threshold_a_priori=use_threshold, ) - results_dict["pred_data"][k] = pred_data - results_dict["fig_score"][k] = fig_score - results_dict["maxf1_threshold"][k] = maxf1_threshold - results_dict["post_eer_threshold"][k] = post_eer_threshold + rows = [v[0] for v in results.values()] + table = aggregate_summaries(rows, fmt="rst") + click.echo(table) if output_folder is not None: - output_scores = os.path.join(output_folder, "scores.pdf") - - if output_scores is not None: - output_scores = os.path.realpath(output_scores) - logger.info(f"Creating and saving scores at {output_scores}...") - os.makedirs(os.path.dirname(output_scores), exist_ok=True) - - score_pdf = PdfPages(output_scores) - - for fig in results_dict["fig_score"].values(): - score_pdf.savefig(fig) - score_pdf.close() - - data = {} - for subset_name in dataloader.keys(): - data[subset_name] = { - "df": results_dict["pred_data"][subset_name], - "threshold": results_dict["post_eer_threshold"][threshold] - if isinstance(threshold, str) - else eer_threshold, - "threshold_type": f"posteriori [{threshold}]" - if isinstance(threshold, str) - else "priori", + output_folder.mkdir(parents=True, exist_ok=True) + + table_path = output_folder / "summary.rst" + + logger.info(f"Saving measures at `{table_path}`...") + with table_path.open("w") as f: + f.write(table) + + figure_path = output_folder / "plots.pdf" + logger.info(f"Saving figures at `{figure_path}`...") + + with PdfPages(figure_path) as pdf: + pr_curves = { + k: v[2]["precision_recall"] for k, v in results.items() + } + pr_fig = aggregate_pr(pr_curves) + pdf.savefig(pr_fig) + + roc_curves = {k: v[2]["roc"] for k, v in results.items()} + roc_fig = aggregate_roc(roc_curves) + pdf.savefig(roc_fig) + + # order ready-to-save figures by type instead of split + figures = {k: v[1] for k, v in results.items()} + keys = next(iter(figures.values())).keys() + figures_by_type = { + k: [v[k] for v in figures.values()] for k in keys } - output_figure = os.path.join(output_folder, "plots.pdf") - - if output_figure is not None: - output_figure = os.path.realpath(output_figure) - logger.info(f"Creating and saving plots at {output_figure}...") - os.makedirs(os.path.dirname(output_figure), exist_ok=True) - pdf = PdfPages(output_figure) - pdf.savefig(precision_recall_f1iso(data)) - pdf.savefig(roc_curve(data)) - pdf.close() - - output_table = os.path.join(output_folder, "table.txt") - logger.info("Tabulating performance summary...") - table = performance_table(data, "rst") - click.echo(table) - if output_table is not None: - output_table = os.path.realpath(output_table) - logger.info(f"Saving table at {output_table}...") - os.makedirs(os.path.dirname(output_table), exist_ok=True) - with open(output_table, "w") as f: - f.write(table) + for group_figures in figures_by_type.values(): + for f in group_figures: + pdf.savefig(f) diff --git a/src/ptbench/utils/download.py b/src/ptbench/utils/download.py deleted file mode 100644 index 911d5f916c84d13bc59e42a57b1bcd75c83c292e..0000000000000000000000000000000000000000 --- a/src/ptbench/utils/download.py +++ /dev/null @@ -1,58 +0,0 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# -# SPDX-License-Identifier: GPL-3.0-or-later - 
-import logging -import tempfile -import urllib.request - -from tqdm import tqdm - -logger = logging.getLogger(__name__) - - -def download_to_tempfile(url, progress=False): - """Downloads a file to a temporary named file and returns it. - - Parameters - ---------- - - url : str - The URL pointing to the file to download - - progress : :py:class:`bool`, Optional - If a progress bar should be displayed for downloading the URL. - - - Returns - ------- - - f : :py:func:`tempfile.NamedTemporaryFile` - A named temporary file that contains the downloaded URL - """ - file_size = 0 - response = urllib.request.urlopen(url) - meta = response.info() - if hasattr(meta, "getheaders"): - content_length = meta.getheaders("Content-Length") - else: - content_length = meta.get_all("Content-Length") - - if content_length is not None and len(content_length) > 0: - file_size = int(content_length[0]) - - progress &= bool(file_size) - - f = tempfile.NamedTemporaryFile() - - with tqdm(total=file_size, disable=not progress) as pbar: - while True: - buffer = response.read(8192) - if len(buffer) == 0: - break - f.write(buffer) - pbar.update(len(buffer)) - - f.flush() - f.seek(0) - return f diff --git a/src/ptbench/utils/grad_cams.py b/src/ptbench/utils/grad_cams.py deleted file mode 100644 index 11da258ae13f1468d50c32852b13a18ac1513124..0000000000000000000000000000000000000000 --- a/src/ptbench/utils/grad_cams.py +++ /dev/null @@ -1,104 +0,0 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# SPDX-FileContributor: Kazuto Nakashima <k nakashima@irvs.ait.kyushu-u.ac.jp> -# -# SPDX-License-Identifier: GPL-3.0-or-later - - -import torch - -from torch.nn import functional as F - - -class BaseWrapper: - def __init__(self, model): - super().__init__() - self.device = next(model.parameters()).device - self.model_with_norm = model - self.model = model.model - self.handlers = [] # a set of hook function handlers - - def _encode_one_hot(self, ids): - one_hot = torch.zeros_like(self.logits).to(self.device) - one_hot.scatter_(1, ids, 1.0) - return one_hot - - def forward(self, image): - self.image_shape = image.shape[2:] - self.logits = self.model_with_norm(image) - self.probs = torch.sigmoid(self.logits) - return self.probs.sort(dim=1, descending=True) # ordered results - - def backward(self, ids): - """Class-specific backpropagation.""" - one_hot = self._encode_one_hot(ids) - self.model_with_norm.zero_grad() - self.logits.backward(gradient=one_hot, retain_graph=True) - - def generate(self): - raise NotImplementedError - - def remove_hook(self): - """Remove all the forward/backward hook functions.""" - for handle in self.handlers: - handle.remove() - - -class GradCAM(BaseWrapper): - """ - "Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization" - https://arxiv.org/pdf/1610.02391.pdf - Look at Figure 2 on page 4 - """ - - def __init__(self, model, candidate_layers=None): - super().__init__(model) - self.fmap_pool = {} - self.grad_pool = {} - self.candidate_layers = candidate_layers # list - - def save_fmaps(key): - def forward_hook(module, input, output): - self.fmap_pool[key] = output.detach() - - return forward_hook - - def save_grads(key): - def backward_hook(module, grad_in, grad_out): - self.grad_pool[key] = grad_out[0].detach() - - return backward_hook - - # If any candidates are not specified, the hook is registered to all the layers. 
- for name, module in self.model.named_modules(): - if self.candidate_layers is None or name in self.candidate_layers: - self.handlers.append( - module.register_forward_hook(save_fmaps(name)) - ) - self.handlers.append( - module.register_backward_hook(save_grads(name)) - ) - - def _find(self, pool, target_layer): - if target_layer in pool.keys(): - return pool[target_layer] - else: - raise ValueError(f"Invalid layer name: {target_layer}") - - def generate(self, target_layer): - fmaps = self._find(self.fmap_pool, target_layer) - grads = self._find(self.grad_pool, target_layer) - weights = F.adaptive_avg_pool2d(grads, 1) - - gcam = torch.mul(fmaps, weights).sum(dim=1, keepdim=True) - gcam = F.relu(gcam) - gcam = F.interpolate( - gcam, self.image_shape, mode="bilinear", align_corners=False - ) - - B, C, H, W = gcam.shape - gcam = gcam.view(B, -1) - gcam -= gcam.min(dim=1, keepdim=True)[0] - gcam /= gcam.max(dim=1, keepdim=True)[0] - gcam = gcam.view(B, C, H, W) - - return gcam diff --git a/src/ptbench/utils/image.py b/src/ptbench/utils/image.py deleted file mode 100644 index 363a8309f4581dfdda42124c0562d3bf840904af..0000000000000000000000000000000000000000 --- a/src/ptbench/utils/image.py +++ /dev/null @@ -1,36 +0,0 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# -# SPDX-License-Identifier: GPL-3.0-or-later - -import os - -from typing import Union - -import torch - -from PIL.Image import Image -from torchvision import transforms - - -def save_image(img: Union[torch.Tensor, Image], filepath: str) -> None: - """Saves a PIL image or a tensor as an image at the specified destination. - - Parameters - ---------- - - img: - A torch.Tensor or PIL.Image to save - - filepath: - The file in which to save the image. The format is inferred from the file extension, or defaults to png if not specified. - """ - - if isinstance(img, torch.Tensor): - img = transforms.ToPILImage()(img) - - root, ext = os.path.splitext(filepath) - - if len(ext) == 0: - filepath = filepath + ".png" - - img.save(filepath) diff --git a/src/ptbench/utils/measure.py b/src/ptbench/utils/measure.py deleted file mode 100644 index f0031c0c0df555280b6d5e10f94df7ec9cf7fc35..0000000000000000000000000000000000000000 --- a/src/ptbench/utils/measure.py +++ /dev/null @@ -1,398 +0,0 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# -# SPDX-License-Identifier: GPL-3.0-or-later - -from collections import deque - -import numpy -import scipy.special -import torch - - -class SmoothedValue: - """Track a series of values and provide access to smoothed values over a - window or the global series average.""" - - def __init__(self, window_size=20): - self.deque = deque(maxlen=window_size) - - def update(self, value): - self.deque.append(value) - - @property - def median(self): - d = torch.tensor(list(self.deque)) - return d.median().item() - - @property - def avg(self): - d = torch.tensor(list(self.deque)) - return d.mean().item() - - -def tricky_division(n, d): - """Divides n by d. - - Returns 0.0 in case of a division by zero - """ - return n / (d + (d == 0)) - - -def base_measures(tp, fp, tn, fn): - """Calculates measures from true/false positive and negative counts. - - This function can return standard machine learning measures from true and - false positive counts of positives and negatives. 
For a thorough look into - these and alternate names for the returned values, please check Wikipedia's - entry on `Precision and Recall - <https://en.wikipedia.org/wiki/Precision_and_recall>`_. - - - Parameters - ---------- - - tp : int - True positive count, AKA "hit" - - fp : int - False positive count, AKA, "correct rejection" - - tn : int - True negative count, AKA "false alarm", or "Type I error" - - fn : int - False Negative count, AKA "miss", or "Type II error" - - - Returns - ------- - - precision : float - P, AKA positive predictive value (PPV). It corresponds arithmetically - to ``tp/(tp+fp)``. In the case ``tp+fp == 0``, this function returns - zero for precision. - - recall : float - R, AKA sensitivity, hit rate, or true positive rate (TPR). It - corresponds arithmetically to ``tp/(tp+fn)``. In the special case - where ``tp+fn == 0``, this function returns zero for recall. - - specificity : float - S, AKA selectivity or true negative rate (TNR). It - corresponds arithmetically to ``tn/(tn+fp)``. In the special case - where ``tn+fp == 0``, this function returns zero for specificity. - - accuracy : float - A, see `Accuracy - <https://en.wikipedia.org/wiki/Evaluation_of_binary_classifiers>`_. is - the proportion of correct predictions (both true positives and true - negatives) among the total number of pixels examined. It corresponds - arithmetically to ``(tp+tn)/(tp+tn+fp+fn)``. This measure includes - both true-negatives and positives in the numerator, what makes it - sensitive to data or regions without annotations. - - jaccard : float - J, see `Jaccard Index or Similarity - <https://en.wikipedia.org/wiki/Jaccard_index>`_. It corresponds - arithmetically to ``tp/(tp+fp+fn)``. In the special case where - ``tn+fp+fn == 0``, this function returns zero for the Jaccard index. - The Jaccard index depends on a TP-only numerator, similarly to the F1 - score. For regions where there are no annotations, the Jaccard index - will always be zero, irrespective of the model output. Accuracy may be - a better proxy if one needs to consider the true abscence of - annotations in a region as part of the measure. - - f1_score : float - F1, see `F1-score <https://en.wikipedia.org/wiki/F1_score>`_. It - corresponds arithmetically to ``2*P*R/(P+R)`` or ``2*tp/(2*tp+fp+fn)``. - In the special case where ``P+R == (2*tp+fp+fn) == 0``, this function - returns zero for the Jaccard index. The F1 or Dice score depends on a - TP-only numerator, similarly to the Jaccard index. For regions where - there are no annotations, the F1-score will always be zero, - irrespective of the model output. Accuracy may be a better proxy if - one needs to consider the true abscence of annotations in a region as - part of the measure. - """ - return ( - tricky_division(tp, tp + fp), # precision - tricky_division(tp, tp + fn), # recall - tricky_division(tn, fp + tn), # specificity - tricky_division(tp + tn, tp + fp + fn + tn), # accuracy - tricky_division(tp, tp + fp + fn), # jaccard index - tricky_division(2 * tp, (2 * tp) + fp + fn), # f1-score - ) - - -def beta_credible_region(successes, failures, lambda_, coverage): - """Returns the mode, upper and lower bounds of the equal-tailed credible - region of a probability estimate following Bernoulli trials. - - This implemetnation is based on [GOUTTE-2005]_. It assumes :math:`k` - successes and :math:`l` failures (:math:`n = k+l` total trials) are issued - from a series of Bernoulli trials (likelihood is binomial). 
The posterior - is derivated using the Bayes Theorem with a beta prior. As there is no - reason to favour high vs. low precision, we use a symmetric Beta prior - (:math:`\\alpha=\\beta`): - - .. math:: - - P(p|k,n) &= \\frac{P(k,n|p)P(p)}{P(k,n)} \\\\ - P(p|k,n) &= \\frac{\\frac{n!}{k!(n-k)!}p^{k}(1-p)^{n-k}P(p)}{P(k)} \\\\ - P(p|k,n) &= \\frac{1}{B(k+\\alpha, n-k+\beta)}p^{k+\\alpha-1}(1-p)^{n-k+\\beta-1} \\\\ - P(p|k,n) &= \\frac{1}{B(k+\\alpha, n-k+\\alpha)}p^{k+\\alpha-1}(1-p)^{n-k+\\alpha-1} - - The mode for this posterior (also the maximum a posteriori) is: - - .. math:: - - \\text{mode}(p) = \\frac{k+\\lambda-1}{n+2\\lambda-2} - - Concretely, the prior may be flat (all rates are equally likely, - :math:`\\lambda=1`) or we may use Jeoffrey's prior - (:math:`\\lambda=0.5`), that is invariant through re-parameterisation. - Jeffrey's prior indicate that rates close to zero or one are more likely. - - The mode above works if :math:`k+{\\alpha},n-k+{\\alpha} > 1`, which is - usually the case for a resonably well tunned system, with more than a few - samples for analysis. In the limit of the system performance, :math:`k` - may be 0, which will make the mode become zero. - - For our purposes, it may be more suitable to represent :math:`n = k + l`, - with :math:`k`, the number of successes and :math:`l`, the number of - failures in the binomial experiment, and find this more suitable - representation: - - .. math:: - - P(p|k,l) &= \\frac{1}{B(k+\\alpha, l+\\alpha)}p^{k+\\alpha-1}(1-p)^{l+\\alpha-1} \\\\ - \\text{mode}(p) &= \\frac{k+\\lambda-1}{k+l+2\\lambda-2} - - This can be mapped to most rates calculated in the context of binary - classification this way: - - * Precision or Positive-Predictive Value (PPV): p = TP/(TP+FP), so k=TP, l=FP - * Recall, Sensitivity, or True Positive Rate: r = TP/(TP+FN), so k=TP, l=FN - * Specificity or True Negative Rage: s = TN/(TN+FP), so k=TN, l=FP - * F1-score: f1 = 2TP/(2TP+FP+FN), so k=2TP, l=FP+FN - * Accuracy: acc = TP+TN/(TP+TN+FP+FN), so k=TP+TN, l=FP+FN - * Jaccard: j = TP/(TP+FP+FN), so k=TP, l=FP+FN - - Contrary to frequentist approaches, in which one can only - say that if the test were repeated an infinite number of times, - and one constructed a confidence interval each time, then X% - of the confidence intervals would contain the true rate, here - we can say that given our observed data, there is a X% probability - that the true value of :math:`k/n` falls within the provided - interval. - - - .. note:: - - For a disambiguation with Confidence Interval, read - https://en.wikipedia.org/wiki/Credible_interval. - - - Parameters - ========== - - successes : int - Number of successes observed on the experiment - - failures : int - Number of failures observed on the experiment - - lambda__ : :py:class:`float`, Optional - The parameterisation of the Beta prior to consider. Use - :math:`\\lambda=1` for a flat prior. Use :math:`\\lambda=0.5` for - Jeffrey's prior (the default). - - coverage : :py:class:`float`, Optional - A floating-point number between 0 and 1.0 indicating the - coverage you're expecting. A value of 0.95 will ensure 95% - of the area under the probability density of the posterior - is covered by the returned equal-tailed interval. 
- - - Returns - ======= - - mean : float - The mean of the posterior distribution - - mode : float - The mode of the posterior distribution - - lower, upper: float - The lower and upper bounds of the credible region - """ - # we return the equally-tailed range - right = (1.0 - coverage) / 2 # half-width in each side - lower = scipy.special.betaincinv( - successes + lambda_, failures + lambda_, right - ) - upper = scipy.special.betaincinv( - successes + lambda_, failures + lambda_, 1.0 - right - ) - - # evaluate mean and mode (https://en.wikipedia.org/wiki/Beta_distribution) - alpha = successes + lambda_ - beta = failures + lambda_ - - E = alpha / (alpha + beta) - - # the mode of a beta distribution is a bit tricky - if alpha > 1 and beta > 1: - mode = (alpha - 1) / (alpha + beta - 2) - elif alpha == 1 and beta == 1: - # In the case of precision, if the threshold is close to 1.0, both TP - # and FP can be zero, which may cause this condition to be reached, if - # the prior is exactly 1 (flat prior). This is a weird situation, - # because effectively we are trying to compute the posterior when the - # total number of experiments is zero. So, only the prior counts - but - # the prior is flat, so we should just pick a value. We choose the - # middle of the range. - mode = 0.0 # any value would do, we just pick this one - elif alpha <= 1 and beta > 1: - mode = 0.0 - elif alpha > 1 and beta <= 1: - mode = 1.0 - else: # elif alpha < 1 and beta < 1: - # in the case of precision, if the threshold is close to 1.0, both TP - # and FP can be zero, which may cause this condition to be reached, if - # the prior is smaller than 1. This is a weird situation, because - # effectively we are trying to compute the posterior when the total - # number of experiments is zero. So, only the prior counts - but the - # prior is bimodal, so we should just pick a value. We choose the - # left of the range. - mode = 0.0 # could also be 1.0 as the prior is bimodal - - return E, mode, lower, upper - - -def bayesian_measures(tp, fp, tn, fn, lambda_, coverage): - r"""Calculates mean and mode from true/false positive and negative counts - with credible regions. - - This function can return bayesian estimates of standard machine learning - measures from true and false positive counts of positives and negatives. - For a thorough look into these and alternate names for the returned values, - please check Wikipedia's entry on `Precision and Recall - <https://en.wikipedia.org/wiki/Precision_and_recall>`_. See - :py:func:`beta_credible_region` for details on the calculation of returned - values. - - - Parameters - ---------- - - tp : int - True positive count, AKA "hit" - - fp : int - False positive count, AKA "false alarm", or "Type I error" - - tn : int - True negative count, AKA "correct rejection" - - fn : int - False Negative count, AKA "miss", or "Type II error" - - lambda_ : float - The parameterisation of the Beta prior to consider. Use - :math:`\lambda=1` for a flat prior. Use :math:`\lambda=0.5` for - Jeffrey's prior. - - coverage : float - A floating-point number between 0 and 1.0 indicating the - coverage you're expecting. A value of 0.95 will ensure 95% - of the area under the probability density of the posterior - is covered by the returned equal-tailed interval. - - - - Returns - ------- - - precision : (float, float, float, float) - P, AKA positive predictive value (PPV), mean, mode and credible - intervals (95% CI). It corresponds arithmetically - to ``tp/(tp+fp)``. 
-
-    recall : (float, float, float, float)
-        R, AKA sensitivity, hit rate, or true positive rate (TPR), mean, mode
-        and credible intervals (95% CI). It corresponds arithmetically to
-        ``tp/(tp+fn)``.
-
-    specificity : (float, float, float, float)
-        S, AKA selectivity or true negative rate (TNR), mean, mode and credible
-        intervals (95% CI). It corresponds arithmetically to ``tn/(tn+fp)``.
-
-    accuracy : (float, float, float, float)
-        A, mean, mode and credible intervals (95% CI). See `Accuracy
-        <https://en.wikipedia.org/wiki/Evaluation_of_binary_classifiers>`_. It
-        is the proportion of correct predictions (both true positives and true
-        negatives) among the total number of pixels examined. It corresponds
-        arithmetically to ``(tp+tn)/(tp+tn+fp+fn)``. This measure includes
-        both true-negatives and positives in the numerator, which makes it
-        sensitive to data or regions without annotations.
-
-    jaccard : (float, float, float, float)
-        J, mean, mode and credible intervals (95% CI). See `Jaccard Index or
-        Similarity <https://en.wikipedia.org/wiki/Jaccard_index>`_. It
-        corresponds arithmetically to ``tp/(tp+fp+fn)``. The Jaccard index
-        depends on a TP-only numerator, similarly to the F1 score. For regions
-        where there are no annotations, the Jaccard index will always be zero,
-        irrespective of the model output. Accuracy may be a better proxy if
-        one needs to consider the true absence of annotations in a region as
-        part of the measure.
-
-    f1_score : (float, float, float, float)
-        F1, mean, mode and credible intervals (95% CI). See `F1-score
-        <https://en.wikipedia.org/wiki/F1_score>`_. It corresponds
-        arithmetically to ``2*P*R/(P+R)`` or ``2*tp/(2*tp+fp+fn)``. The F1 or
-        Dice score depends on a TP-only numerator, similarly to the Jaccard
-        index. For regions where there are no annotations, the F1-score will
-        always be zero, irrespective of the model output. Accuracy may be a
-        better proxy if one needs to consider the true absence of annotations
-        in a region as part of the measure.
-    """
-    return (
-        beta_credible_region(tp, fp, lambda_, coverage),  # precision
-        beta_credible_region(tp, fn, lambda_, coverage),  # recall
-        beta_credible_region(tn, fp, lambda_, coverage),  # specificity
-        beta_credible_region(tp + tn, fp + fn, lambda_, coverage),  # accuracy
-        beta_credible_region(tp, fp + fn, lambda_, coverage),  # jaccard index
-        beta_credible_region(2 * tp, fp + fn, lambda_, coverage),  # f1-score
-    )
-
-
-def get_centered_maxf1(f1_scores, thresholds):
-    """Return the centered max F1 score threshold when multiple thresholds
-    give the same max F1 score.
-
-    Parameters
-    ----------
-
-    f1_scores : numpy.ndarray
-        1D array of F1 scores
-
-    thresholds : numpy.ndarray
-        1D array of thresholds
-
-    Returns
-    -------
-
-    max F1 score : float
-
-    threshold : float
-    """
-    maxf1 = f1_scores.max()
-    maxf1_indices = numpy.where(f1_scores == maxf1)[0]
-
-    # If multiple thresholds give the same max F1 score
-    if len(maxf1_indices) > 1:
-        mean_maxf1_index = int(round(numpy.mean(maxf1_indices)))
-    else:
-        mean_maxf1_index = maxf1_indices[0]
-
-    return maxf1, thresholds[mean_maxf1_index]
diff --git a/src/ptbench/utils/model_serialization.py b/src/ptbench/utils/model_serialization.py
deleted file mode 100644
index 8523115e0119ff6ff561ac3dd132ebc69e6bcf38..0000000000000000000000000000000000000000
--- a/src/ptbench/utils/model_serialization.py
+++ /dev/null
@@ -1,91 +0,0 @@
-# SPDX-FileCopyrightText: Copyright Facebook, Inc. and its affiliates. All Rights Reserved.
-# -# SPDX-License-Identifier: GPL-3.0-or-later - -# Original code from: https://github.com/facebookresearch/maskrcnn-benchmark - -import logging - -from collections import OrderedDict - -logger = logging.getLogger(__name__) - -import torch - - -def align_and_update_state_dicts(model_state_dict, loaded_state_dict): - """ - Strategy: suppose that the models that we will create will have prefixes appended - to each of its keys, for example due to an extra level of nesting that the original - pre-trained weights from ImageNet won't contain. For example, model.state_dict() - might return backbone[0].body.res2.conv1.weight, while the pre-trained model contains - res2.conv1.weight. We thus want to match both parameters together. - For that, we look for each model weight, look among all loaded keys if there is one - that is a suffix of the current weight name, and use it if that's the case. - If multiple matches exist, take the one with longest size - of the corresponding name. For example, for the same model as before, the pretrained - weight file can contain both res2.conv1.weight, as well as conv1.weight. In this case, - we want to match backbone[0].body.conv1.weight to conv1.weight, and - backbone[0].body.res2.conv1.weight to res2.conv1.weight. - """ - current_keys = sorted(list(model_state_dict.keys())) - loaded_keys = sorted(list(loaded_state_dict.keys())) - # get a matrix of string matches, where each (i, j) entry correspond to the size of the - # loaded_key string, if it matches - match_matrix = [ - len(j) if i.endswith(j) else 0 - for i in current_keys - for j in loaded_keys - ] - match_matrix = torch.as_tensor(match_matrix).view( - len(current_keys), len(loaded_keys) - ) - max_match_size, idxs = match_matrix.max(1) - # remove indices that correspond to no-match - idxs[max_match_size == 0] = -1 - - # used for logging - max_size = max([len(key) for key in current_keys]) if current_keys else 1 - max_size_loaded = ( - max([len(key) for key in loaded_keys]) if loaded_keys else 1 - ) - log_str_template = "{: <{}} loaded from {: <{}} of shape {}" - for idx_new, idx_old in enumerate(idxs.tolist()): - if idx_old == -1: - continue - key = current_keys[idx_new] - key_old = loaded_keys[idx_old] - model_state_dict[key] = loaded_state_dict[key_old] - logger.debug( - log_str_template.format( - key, - max_size, - key_old, - max_size_loaded, - tuple(loaded_state_dict[key_old].shape), - ) - ) - - -def strip_prefix_if_present(state_dict, prefix): - keys = sorted(state_dict.keys()) - if not all(key.startswith(prefix) for key in keys): - return state_dict - stripped_state_dict = OrderedDict() - for key, value in state_dict.items(): - stripped_state_dict[key.replace(prefix, "")] = value - return stripped_state_dict - - -def load_state_dict(model, loaded_state_dict): - model_state_dict = model.state_dict() - # if the state_dict comes from a model that was wrapped in a - # DataParallel or DistributedDataParallel during serialization, - # remove the "module" prefix before performing the matching - loaded_state_dict = strip_prefix_if_present( - loaded_state_dict, prefix="module." 
- ) - align_and_update_state_dicts(model_state_dict, loaded_state_dict) - - # use strict loading - model.load_state_dict(model_state_dict) diff --git a/src/ptbench/utils/model_zoo.py b/src/ptbench/utils/model_zoo.py deleted file mode 100644 index c1c33e2dda255d7622e7b9d4c82771d0f93d7f07..0000000000000000000000000000000000000000 --- a/src/ptbench/utils/model_zoo.py +++ /dev/null @@ -1,117 +0,0 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# -# SPDX-License-Identifier: GPL-3.0-or-later - -# Adapted from: -# https://github.com/pytorch/pytorch/blob/master/torch/hub.py -# https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/utils/checkpoint.py - -import hashlib -import os -import re -import shutil -import sys -import tempfile - -from urllib.parse import urlparse -from urllib.request import urlopen - -from tqdm import tqdm - -modelurls = { - "vgg11": "https://download.pytorch.org/models/vgg11-bbd30ac9.pth", - "vgg13": "https://download.pytorch.org/models/vgg13-c768596a.pth", - "vgg16": "https://download.pytorch.org/models/vgg16-397923af.pth", - "vgg19": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth", - "vgg11_bn": "https://download.pytorch.org/models/vgg11_bn-6002323d.pth", - "vgg13_bn": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth", - "vgg16_bn": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth", - "vgg19_bn": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth", - "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth", - "resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth", - "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth", - "resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth", - "resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth", -} -"""URLs of pre-trained models (backbones)""" - - -def download_url_to_file(url, dst, hash_prefix, progress): - file_size = None - u = urlopen(url) - meta = u.info() - if hasattr(meta, "getheaders"): - content_length = meta.getheaders("Content-Length") - else: - content_length = meta.get_all("Content-Length") - if content_length is not None and len(content_length) > 0: - file_size = int(content_length[0]) - - f = tempfile.NamedTemporaryFile(delete=False) - try: - if hash_prefix is not None: - sha256 = hashlib.sha256() - with tqdm(total=file_size, disable=not progress) as pbar: - while True: - buffer = u.read(8192) - if len(buffer) == 0: - break - f.write(buffer) - if hash_prefix is not None: - sha256.update(buffer) - pbar.update(len(buffer)) - - f.close() - if hash_prefix is not None: - digest = sha256.hexdigest() - if digest[: len(hash_prefix)] != hash_prefix: - raise RuntimeError( - 'invalid hash value (expected "{}", got "{}")'.format( - hash_prefix, digest - ) - ) - shutil.move(f.name, dst) - finally: - f.close() - if os.path.exists(f.name): - os.remove(f.name) - - -HASH_REGEX = re.compile(r"-([a-f0-9]*)\.") - - -def cache_url(url, model_dir=None, progress=True): - r"""Loads the Torch serialized object at the given URL. - - If the object is already present in `model_dir`, it's deserialized and - returned. The filename part of the URL should follow the naming convention - ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more - digits of the SHA256 hash of the contents of the file. The hash is used to - ensure unique names and to verify the contents of the file. 
- The default value of `model_dir` is ``$TORCH_HOME/models`` where - ``$TORCH_HOME`` defaults to ``~/.torch``. The default directory can be - overridden with the ``$TORCH_MODEL_ZOO`` environment variable. - Args: - url (string): URL of the object to download - model_dir (string, optional): directory in which to save the object - progress (bool, optional): whether or not to display a progress bar to stderr - """ - if model_dir is None: - torch_home = os.path.expanduser(os.getenv("TORCH_HOME", "~/.torch")) - model_dir = os.getenv( - "TORCH_MODEL_ZOO", os.path.join(torch_home, "models") - ) - if not os.path.exists(model_dir): - os.makedirs(model_dir) - parts = urlparse(url) - filename = os.path.basename(parts.path) - - cached_file = os.path.join(model_dir, filename) - if not os.path.exists(cached_file): - sys.stderr.write(f'Downloading: "{url}" to {cached_file}\n') - hash_prefix = HASH_REGEX.search(filename) - if hash_prefix is not None: - hash_prefix = hash_prefix.group(1) - download_url_to_file(url, cached_file, hash_prefix, progress=progress) - - return cached_file diff --git a/src/ptbench/utils/plot.py b/src/ptbench/utils/plot.py deleted file mode 100644 index 410a086af4e6a48c6f4cc96f260569f165b82d2b..0000000000000000000000000000000000000000 --- a/src/ptbench/utils/plot.py +++ /dev/null @@ -1,339 +0,0 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# -# SPDX-License-Identifier: GPL-3.0-or-later - -import contextlib - -from itertools import cycle - -import matplotlib -import numpy - -from sklearn.metrics import auc -from sklearn.metrics import precision_recall_curve as pr_curve -from sklearn.metrics import roc_curve as r_curve - -matplotlib.use("agg") -import logging - -import matplotlib.pyplot as plt - -logger = logging.getLogger(__name__) - - -@contextlib.contextmanager -def _precision_recall_canvas(title=None): - """Generates a canvas to draw precision-recall curves. - - Works like a context manager, yielding a figure and an axes set in which - the precision-recall curves should be added to. The figure already - contains F1-ISO lines and is preset to a 0-1 square region. Once the - context is finished, ``fig.tight_layout()`` is called. 
- - - Parameters - ---------- - - title : :py:class:`str`, Optional - Optional title to add to this plot - - - Yields - ------ - - figure : matplotlib.figure.Figure - The figure that should be finally returned to the user - - axes : matplotlib.figure.Axes - An axis set where to precision-recall plots should be added to - """ - fig, axes1 = plt.subplots(1) - - # Names and bounds - axes1.set_xlabel("Recall") - axes1.set_ylabel("Precision") - axes1.set_xlim([0.0, 1.0]) - axes1.set_ylim([0.0, 1.0]) - - if title is not None: - axes1.set_title(title) - - axes1.grid(linestyle="--", linewidth=1, color="gray", alpha=0.2) - axes2 = axes1.twinx() - - # Annotates plot with F1-score iso-lines - f_scores = numpy.linspace(0.1, 0.9, num=9) - tick_locs = [] - tick_labels = [] - for f_score in f_scores: - x = numpy.linspace(0.01, 1) - y = f_score * x / (2 * x - f_score) - (l,) = plt.plot(x[y >= 0], y[y >= 0], color="green", alpha=0.1) - tick_locs.append(y[-1]) - tick_labels.append("%.1f" % f_score) - axes2.tick_params(axis="y", which="both", pad=0, right=False, left=False) - axes2.set_ylabel("iso-F", color="green", alpha=0.3) - axes2.set_ylim([0.0, 1.0]) - axes2.yaxis.set_label_coords(1.015, 0.97) - axes2.set_yticks(tick_locs) # notice these are invisible - for k in axes2.set_yticklabels(tick_labels): - k.set_color("green") - k.set_alpha(0.3) - k.set_size(8) - - # we should see some of axes 1 axes - axes1.spines["right"].set_visible(False) - axes1.spines["top"].set_visible(False) - axes1.spines["left"].set_position(("data", -0.015)) - axes1.spines["bottom"].set_position(("data", -0.015)) - - # we shouldn't see any of axes 2 axes - axes2.spines["right"].set_visible(False) - axes2.spines["top"].set_visible(False) - axes2.spines["left"].set_visible(False) - axes2.spines["bottom"].set_visible(False) - - # yield execution, lets user draw precision-recall plots, and the legend - # before tighteneing the layout - yield fig, axes1 - - plt.tight_layout() - - -def precision_recall_f1iso(data): - """Creates a precision-recall plot. - - This function creates and returns a Matplotlib figure with a - precision-recall plot. The plot will be annotated with F1-score - iso-lines (in which the F1-score maintains the same value). - - - Parameters - ---------- - - data : dict - A dictionary in which keys are strings defining plot labels and values - are dictionaries with two entries: - - * ``df``: :py:class:`pandas.DataFrame` - - A dataframe that is produced by our predictor engine containing - the following columns: ``filename``, ``likelihood``, - ``ground_truth``. - - * ``threshold``: :py:class:`list` - - A threshold for each set. Not used here. 
- - - Returns - ------- - - figure : matplotlib.figure.Figure - A matplotlib figure you can save or display (uses an ``agg`` backend) - """ - lines = ["-", "--", "-.", ":"] - colors = [ - "#1f77b4", - "#ff7f0e", - "#2ca02c", - "#d62728", - "#9467bd", - "#8c564b", - "#e377c2", - "#7f7f7f", - "#bcbd22", - "#17becf", - ] - colorcycler = cycle(colors) - linecycler = cycle(lines) - - with _precision_recall_canvas(title=None) as (fig, axes): - legend = [] - - for name, value in data.items(): - df = value["df"] - - # plots Recall/Precision curve - prec, recall, _ = pr_curve(df["ground_truth"], df["likelihood"]) - _auc = auc(recall, prec) - label = f"{name} (AUC={_auc:.2f})" - color = next(colorcycler) - style = next(linecycler) - - (line,) = axes.plot(recall, prec, color=color, linestyle=style) - legend.append((line, label)) - - if len(label) > 1: - axes.legend( - [k[0] for k in legend], - [k[1] for k in legend], - loc="lower left", - fancybox=True, - framealpha=0.7, - ) - - return fig - - -def roc_curve(data, title=None): - """Creates a ROC plot. - - This function creates and returns a Matplotlib figure with a - ROC plot. - - - Parameters - ---------- - - data : dict - A dictionary in which keys are strings defining plot labels and values - are dictionaries with two entries: - - * ``df``: :py:class:`pandas.DataFrame` - - A dataframe that is produced by our predictor engine containing - the following columns: ``filename``, ``likelihood``, - ``ground_truth``. - - * ``threshold``: :py:class:`list` - - A threshold for each set. Not used here. - - - Returns - ------- - - figure : matplotlib.figure.Figure - A matplotlib figure you can save or display (uses an ``agg`` backend) - """ - fig, axes = plt.subplots(1) - - # Names and bounds - axes.set_xlabel("1 - specificity") - axes.set_ylabel("Sensitivity") - axes.set_xlim([0.0, 1.0]) - axes.set_ylim([0.0, 1.0]) - - # we should see some of axes 1 axes - axes.spines["right"].set_visible(False) - axes.spines["top"].set_visible(False) - axes.spines["left"].set_position(("data", -0.015)) - axes.spines["bottom"].set_position(("data", -0.015)) - - if title is not None: - axes.set_title(title) - - axes.grid(linestyle="--", linewidth=1, color="gray", alpha=0.2) - - plt.tight_layout() - - lines = ["-", "--", "-.", ":"] - colors = [ - "#1f77b4", - "#ff7f0e", - "#2ca02c", - "#d62728", - "#9467bd", - "#8c564b", - "#e377c2", - "#7f7f7f", - "#bcbd22", - "#17becf", - ] - colorcycler = cycle(colors) - linecycler = cycle(lines) - - legend = [] - - for name, value in data.items(): - df = value["df"] - - # plots roc curve - fpr, tpr, _ = r_curve(df["ground_truth"], df["likelihood"]) - _auc = auc(fpr, tpr) - label = f"{name} (AUC={_auc:.2f})" - color = next(colorcycler) - style = next(linecycler) - - (line,) = axes.plot(fpr, tpr, color=color, linestyle=style) - legend.append((line, label)) - - if len(label) > 1: - axes.legend( - [k[0] for k in legend], - [k[1] for k in legend], - loc="lower right", - fancybox=True, - framealpha=0.7, - ) - - return fig - - -def relevance_analysis_plot(data, title=None): - """Create an histogram plot to show the relative importance of features. 
- - Parameters - ---------- - - data : :py:class:`list` - The list of values (one for each feature) - - - Returns - ------- - - figure : matplotlib.figure.Figure - A matplotlib figure you can save or display (uses an ``agg`` backend) - """ - fig, axes = plt.subplots(1, 1, figsize=(6, 6)) - - # Names and bounds - axes.set_xlabel("Features") - axes.set_ylabel("Importance") - - # we should see some of axes 1 axes - axes.spines["right"].set_visible(False) - axes.spines["top"].set_visible(False) - - if title is not None: - axes.set_title(title) - - # 818C2E = likely - # F2921D = could be - # 8C3503 = unlikely - - labels = [ - "Cardiomegaly", - "Emphysema", - "Pleural effusion", - "Hernia", - "Infiltration", - "Mass", - "Nodule", - "Atelectasis", - "Pneumothorax", - "Pleural thickening", - "Pneumonia", - "Fibrosis", - "Edema", - "Consolidation", - ] - bars = axes.bar(labels, data, color="#8C3503") - - bars[2].set_color("#818C2E") - bars[4].set_color("#818C2E") - bars[10].set_color("#818C2E") - bars[5].set_color("#F2921D") - bars[6].set_color("#F2921D") - bars[7].set_color("#F2921D") - bars[11].set_color("#F2921D") - bars[13].set_color("#F2921D") - - for tick in axes.get_xticklabels(): - tick.set_rotation(90) - - fig.tight_layout() - - return fig diff --git a/src/ptbench/utils/table.py b/src/ptbench/utils/table.py deleted file mode 100644 index c9d35988978df6f4a9a7721e399c27be1d2e8d68..0000000000000000000000000000000000000000 --- a/src/ptbench/utils/table.py +++ /dev/null @@ -1,133 +0,0 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# -# SPDX-License-Identifier: GPL-3.0-or-later - -import tabulate -import torch - -from sklearn.metrics import auc -from sklearn.metrics import precision_recall_curve as pr_curve -from sklearn.metrics import roc_curve as r_curve - -from ..engine.evaluator import posneg -from ..utils.measure import base_measures, bayesian_measures - - -def performance_table(data, fmt): - """Tables result comparison in a given format. - - Parameters - ---------- - - data : dict - A dictionary in which keys are strings defining plot labels and values - are dictionaries with two entries: - - * ``df``: :py:class:`pandas.DataFrame` - - A dataframe that is produced by our predictor engine containing - the following columns: ``filename``, ``likelihood``, - ``ground_truth``. - - * ``threshold``: :py:class:`list` - - A threshold to compute measures. - - - fmt : str - One of the formats supported by tabulate. 
- - - Returns - ------- - - table : str - A table in a specific format - """ - headers = [ - "Dataset", - "T", - "T Type", - "F1 (95% CI)", - "Prec (95% CI)", - "Recall/Sen (95% CI)", - "Spec (95% CI)", - "Acc (95% CI)", - "AUC (PRC)", - "AUC (ROC)", - ] - - table = [] - for k, v in data.items(): - entry = [ - k, - v["threshold"], - v["threshold_type"], - ] - - df = v["df"] - - gt = torch.tensor(df["ground_truth"].values) - pred = torch.tensor(df["likelihood"].values) - threshold = v["threshold"] - - tp_tensor, fp_tensor, tn_tensor, fn_tensor = posneg(pred, gt, threshold) - - # calc measures from scalars - tp_count = torch.sum(tp_tensor).item() - fp_count = torch.sum(fp_tensor).item() - tn_count = torch.sum(tn_tensor).item() - fn_count = torch.sum(fn_tensor).item() - - base_m = base_measures( - tp_count, - fp_count, - tn_count, - fn_count, - ) - - bayes_m = bayesian_measures( - tp_count, - fp_count, - tn_count, - fn_count, - lambda_=1, - coverage=0.95, - ) - - # statistics based on the "assigned" threshold (a priori, less biased) - entry.append( - "{:.2f} ({:.2f}, {:.2f})".format( - base_m[5], bayes_m[5][2], bayes_m[5][3] - ) - ) # f1 - entry.append( - "{:.2f} ({:.2f}, {:.2f})".format( - base_m[0], bayes_m[0][2], bayes_m[0][3] - ) - ) # precision - entry.append( - "{:.2f} ({:.2f}, {:.2f})".format( - base_m[1], bayes_m[1][2], bayes_m[1][3] - ) - ) # recall/sensitivity - entry.append( - "{:.2f} ({:.2f}, {:.2f})".format( - base_m[2], bayes_m[2][2], bayes_m[2][3] - ) - ) # specificity - entry.append( - "{:.2f} ({:.2f}, {:.2f})".format( - base_m[3], bayes_m[3][2], bayes_m[3][3] - ) - ) # accuracy - - prec, recall, _ = pr_curve(gt, pred) - fpr, tpr, _ = r_curve(gt, pred) - - entry.append(auc(recall, prec)) - entry.append(auc(fpr, tpr)) - - table.append(entry) - - return tabulate.tabulate(table, headers, tablefmt=fmt, floatfmt=".3f")
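
For downstream users of the removed ``beta_credible_region`` helper, the same equal-tailed credible region can still be computed directly with SciPy. The sketch below follows the formulas documented in the removed docstring (successes ``k``, failures ``l``, prior ``lambda_``, equal-tailed ``coverage``); the TP/FP counts are made-up illustration values, not project data.

# Illustrative sketch (not part of the patch): equal-tailed Beta credible
# region for precision, following the removed beta_credible_region() math.
# The counts below are hypothetical example values.
import scipy.special

tp, fp = 90, 10                   # hypothetical true/false positive counts
lambda_, coverage = 0.5, 0.95     # Jeffrey's prior, 95% coverage

k, l = tp, fp                     # precision: k = TP, l = FP (see mapping above)
alpha, beta = k + lambda_, l + lambda_

mean = alpha / (alpha + beta)
mode = (alpha - 1) / (alpha + beta - 2)   # valid while alpha, beta > 1
half = (1.0 - coverage) / 2
lower = scipy.special.betaincinv(alpha, beta, half)
upper = scipy.special.betaincinv(alpha, beta, 1.0 - half)

print(f"precision: mean={mean:.3f}, mode={mode:.3f}, 95% CI=[{lower:.3f}, {upper:.3f}]")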
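
The removed plotting and table helpers (``precision_recall_f1iso``, ``roc_curve``, ``performance_table``) all consumed the same ``data`` dictionary described in their docstrings. A minimal sketch of that layout, with made-up filenames, likelihoods and threshold, is shown below.

# Illustrative sketch (not part of the patch): the ``data`` dict layout the
# removed helpers expected; values are made-up examples.
import pandas

df = pandas.DataFrame(
    {
        "filename": ["a.png", "b.png", "c.png", "d.png"],
        "likelihood": [0.91, 0.15, 0.62, 0.08],
        "ground_truth": [1, 0, 1, 0],
    }
)
data = {
    "validation": {
        "df": df,
        "threshold": 0.5,             # a-priori threshold used by performance_table()
        "threshold_type": "manual",   # only consumed by performance_table()
    }
}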