# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

import tabulate
import torch

from sklearn.metrics import auc
from sklearn.metrics import precision_recall_curve as pr_curve
from sklearn.metrics import roc_curve as r_curve

from ..engine.evaluator import posneg
from ..utils.measure import base_measures, bayesian_measures


def performance_table(data, fmt):
"""Tables result comparison in a given format.
Parameters
----------
data : dict
A dictionary in which keys are strings defining plot labels and values
are dictionaries with two entries:
* ``df``: :py:class:`pandas.DataFrame`
A dataframe that is produced by our predictor engine containing
the following columns: ``filename``, ``likelihood``,
``ground_truth``.
* ``threshold``: :py:class:`list`
A threshold to compute measures.
fmt : str
One of the formats supported by tabulate.
Returns
-------
table : str
A table in a specific format
"""
    headers = [
        "Dataset",
        "T",
        "T Type",
        "F1 (95% CI)",
        "Prec (95% CI)",
        "Recall/Sen (95% CI)",
        "Spec (95% CI)",
        "Acc (95% CI)",
        "AUC (PRC)",
        "AUC (ROC)",
    ]

    table = []
    for k, v in data.items():
        entry = [
            k,
            v["threshold"],
            v["threshold_type"],
        ]

        df = v["df"]
        gt = torch.tensor(df["ground_truth"].values)
        pred = torch.tensor(df["likelihood"].values)
        threshold = v["threshold"]
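        # split predictions into true/false positive/negative indicator
        # tensors, according to the given threshold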
        tp_tensor, fp_tensor, tn_tensor, fn_tensor = posneg(
            pred, gt, threshold
        )
        # reduce the indicator tensors to scalar counts
        tp_count = torch.sum(tp_tensor).item()
        fp_count = torch.sum(fp_tensor).item()
        tn_count = torch.sum(tn_tensor).item()
        fn_count = torch.sum(fn_tensor).item()
        base_m = base_measures(
            tp_count,
            fp_count,
            tn_count,
            fn_count,
        )
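        # credible intervals for the same measures: entries [2] and [3] of
        # each tuple returned below hold the lower and upper bounds of the
        # 95% interval requested via ``coverage=0.95``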
        bayes_m = bayesian_measures(
            tp_count,
            fp_count,
            tn_count,
            fn_count,
            lambda_=1,
            coverage=0.95,
        )
        # statistics based on the "assigned" threshold (a priori, less biased)
        # indices into base_m/bayes_m follow the header order: f1-score (5),
        # precision (0), recall/sensitivity (1), specificity (2), accuracy (3)
        for index in (5, 0, 1, 2, 3):
            entry.append(
                "{:.2f} ({:.2f}, {:.2f})".format(
                    base_m[index], bayes_m[index][2], bayes_m[index][3]
                )
            )
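        # threshold-independent measures: areas under the precision-recall
        # and ROC curves, computed from the raw likelihoods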
        prec, recall, _ = pr_curve(gt, pred)
        fpr, tpr, _ = r_curve(gt, pred)
        entry.append(auc(recall, prec))
        entry.append(auc(fpr, tpr))

        table.append(entry)

    return tabulate.tabulate(table, headers, tablefmt=fmt, floatfmt=".3f")
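

# A minimal usage sketch (hypothetical values, for illustration only;
# ``pandas`` is assumed available, since the predictor engine already
# produces dataframes):
#
#     import pandas as pd
#
#     df = pd.DataFrame(
#         {
#             "filename": ["a.png", "b.png", "c.png", "d.png"],
#             "likelihood": [0.9, 0.6, 0.4, 0.1],
#             "ground_truth": [1, 1, 0, 0],
#         }
#     )
#     data = {"test": {"df": df, "threshold": 0.5, "threshold_type": "fixed"}}
#     print(performance_table(data, "rst"))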