# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

import tabulate
import torch

from sklearn.metrics import auc
from sklearn.metrics import precision_recall_curve as pr_curve
from sklearn.metrics import roc_curve as r_curve

from ..engine.evaluator import posneg
from ..utils.measure import base_measures, bayesian_measures


def performance_table(data, fmt):
    """Tabulates a result comparison in a given format.

    Parameters
    ----------

    data : dict
        A dictionary in which keys are strings defining plot labels and
        values are dictionaries with the following entries:

        * ``df``: :py:class:`pandas.DataFrame`
          A dataframe produced by our predictor engine containing the
          columns ``filename``, ``likelihood`` and ``ground_truth``.

        * ``threshold``: :py:class:`float`
          The threshold used to binarize likelihoods before computing
          measures.

        * ``threshold_type``: :py:class:`str`
          A short description of how the threshold was obtained (reported
          as-is in the table).

    fmt : str
        One of the output formats supported by :py:mod:`tabulate`.


    Returns
    -------

    table : str
        A table formatted according to ``fmt``.
    """

    headers = [
        "Dataset",
        "T",
        "T Type",
        "F1 (95% CI)",
        "Prec (95% CI)",
        "Recall/Sen (95% CI)",
        "Spec (95% CI)",
        "Acc (95% CI)",
        "AUC (PRC)",
        "AUC (ROC)",
    ]

    table = []
    for k, v in data.items():
        entry = [
            k,
            v["threshold"],
            v["threshold_type"],
        ]

        df = v["df"]

        gt = torch.tensor(df["ground_truth"].values)
        pred = torch.tensor(df["likelihood"].values)
        threshold = v["threshold"]

        # split samples into true/false positives/negatives at the given
        # threshold
        tp_tensor, fp_tensor, tn_tensor, fn_tensor = posneg(
            pred, gt, threshold
        )

        # reduce indicator tensors to scalar counts
        tp_count = torch.sum(tp_tensor).item()
        fp_count = torch.sum(fp_tensor).item()
        tn_count = torch.sum(tn_tensor).item()
        fn_count = torch.sum(fn_tensor).item()

        base_m = base_measures(
            tp_count,
            fp_count,
            tn_count,
            fn_count,
        )

        bayes_m = bayesian_measures(
            tp_count,
            fp_count,
            tn_count,
            fn_count,
            lambda_=1,
            coverage=0.95,
        )

        # statistics based on the "assigned" threshold (a priori, less
        # biased); each cell shows the point estimate followed by its 95%
        # credible interval
        entry.append(
            "{:.2f} ({:.2f}, {:.2f})".format(
                base_m[5], bayes_m[5][2], bayes_m[5][3]
            )
        )  # f1
        entry.append(
            "{:.2f} ({:.2f}, {:.2f})".format(
                base_m[0], bayes_m[0][2], bayes_m[0][3]
            )
        )  # precision
        entry.append(
            "{:.2f} ({:.2f}, {:.2f})".format(
                base_m[1], bayes_m[1][2], bayes_m[1][3]
            )
        )  # recall/sensitivity
        entry.append(
            "{:.2f} ({:.2f}, {:.2f})".format(
                base_m[2], bayes_m[2][2], bayes_m[2][3]
            )
        )  # specificity
        entry.append(
            "{:.2f} ({:.2f}, {:.2f})".format(
                base_m[3], bayes_m[3][2], bayes_m[3][3]
            )
        )  # accuracy

        # threshold-independent measures: areas under the precision-recall
        # and ROC curves
        prec, recall, _ = pr_curve(gt, pred)
        fpr, tpr, _ = r_curve(gt, pred)

        entry.append(auc(recall, prec))
        entry.append(auc(fpr, tpr))

        table.append(entry)

    return tabulate.tabulate(table, headers, tablefmt=fmt, floatfmt=".3f")
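

# Example usage (a minimal sketch): ``build_predictions_df`` below is a
# hypothetical helper, not part of this module -- any pandas.DataFrame with
# ``filename``, ``likelihood`` and ``ground_truth`` columns, as produced by
# the predictor engine, will do.
#
#     data = {
#         "validation": {
#             "df": build_predictions_df(...),  # predictions for one dataset
#             "threshold": 0.5,                 # decision threshold on likelihoods
#             "threshold_type": "fixed",        # free-form label shown in the "T Type" column
#         },
#     }
#     print(performance_table(data, fmt="rst"))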