Commit e752980d authored by Amir MOHAMMADI

Add log-scale variant of ROC AUC

parent d4808114
Pipeline #32910 passed with stage in 25 minutes and 28 seconds
@@ -474,7 +474,7 @@ def eer(negatives, positives, is_sorted=False, also_farfrr=False):
     return (far + frr) / 2.0


-def roc_auc_score(negatives, positives, npoints=2000, min_far=-8):
+def roc_auc_score(negatives, positives, npoints=2000, min_far=-8, log_scale=False):
     """Area Under the ROC Curve.

     Computes the area under the ROC curve. This is useful when you want to report one
     number that represents an ROC curve. For more information, see:
@@ -490,14 +490,23 @@ def roc_auc_score(negatives, positives, npoints=2000, min_far=-8):
         Number of points in the ROC curve. Higher numbers lead to a more accurate ROC.
     min_far : float, optional
         Min FAR and FRR values to consider when calculating ROC.
+    log_scale : bool, optional
+        If True, converts the x axis (FPR) to log10 scale before calculating the AUC.
+        This is useful in cases where len(negatives) >> len(positives).

     Returns
     -------
     float
-        The ROC AUC
+        The ROC AUC. If ``log_scale`` is False, the value should be between 0 and 1.
     """
     fpr, fnr = roc(negatives, positives, npoints, min_far=min_far)
     tpr = 1 - fnr
+    if log_scale:
+        fpr_pos = fpr > 0
+        fpr, tpr = fpr[fpr_pos], tpr[fpr_pos]
+        fpr = numpy.log10(fpr)
     area = -1 * numpy.trapz(tpr, fpr)
     return area
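For illustration, a minimal usage sketch of the new flag. It assumes roc_auc_score is importable from bob.measure (the import path is an assumption) and uses made-up Gaussian scores:

import numpy
from bob.measure import roc_auc_score  # import path is an assumption

# Synthetic scores for the imbalanced case the flag targets:
# many negatives, few positives.
rng = numpy.random.RandomState(0)
negatives = rng.normal(0.0, 1.0, size=100000)
positives = rng.normal(3.0, 1.0, size=100)

auc = roc_auc_score(negatives, positives)                      # between 0 and 1
auc_log = roc_auc_score(negatives, positives, log_scale=True)  # can exceed 1
print(auc, auc_log)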
@@ -183,7 +183,7 @@ class Metrics(MeasureBase):
     def __init__(self, ctx, scores, evaluation, func_load,
                  names=('False Positive Rate', 'False Negative Rate',
-                        'Precision', 'Recall', 'F1-score', 'Area Under ROC Curve')):
+                        'Precision', 'Recall', 'F1-score', 'Area Under ROC Curve', 'Area Under ROC Curve (log scale)')):
         super(Metrics, self).__init__(ctx, scores, evaluation, func_load)
         self.names = names
         self._tablefmt = ctx.meta.get('tablefmt')
@@ -229,8 +229,9 @@ class Metrics(MeasureBase):
         # AUC ROC
         auc = roc_auc_score(neg, pos)
+        auc_log = roc_auc_score(neg, pos, log_scale=True)
         return (fta, fmr, fnmr, hter, far, frr, fm, ni, fnm, nc, precision,
-                recall, f1_score, auc)
+                recall, f1_score, auc, auc_log)

    def _strings(self, metrics):
        n_dec = '.%df' % self._decimal
@@ -246,9 +247,10 @@
         recall_str = "%s" % format(metrics[11], n_dec)
         f1_str = "%s" % format(metrics[12], n_dec)
         auc_str = "%s" % format(metrics[13], n_dec)
+        auc_log_str = "%s" % format(metrics[14], n_dec)
         return (fta_str, fmr_str, fnmr_str, far_str, frr_str, hter_str,
-                prec_str, recall_str, f1_str, auc_str)
+                prec_str, recall_str, f1_str, auc_str, auc_log_str)

     def _get_all_metrics(self, idx, input_scores, input_names):
         ''' Compute all metrics for dev and eval scores'''
@@ -308,6 +310,7 @@
             [self.names[3], all_metrics[0][7]],
             [self.names[4], all_metrics[0][8]],
             [self.names[5], all_metrics[0][9]],
+            [self.names[6], all_metrics[0][10]],
         ]
         if self._eval:
@@ -325,6 +328,7 @@
             rows[3].append(all_metrics[1][7])
             rows[4].append(all_metrics[1][8])
             rows[5].append(all_metrics[1][9])
+            rows[6].append(all_metrics[1][10])

         click.echo(tabulate(rows, headers, self._tablefmt), file=self.log_file)
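To show how the extra name surfaces in the report, here is a sketch of the final tabulate call with placeholder numbers (not real metric outputs):

from tabulate import tabulate

headers = ('', 'Development')
rows = [
    ['False Positive Rate', '1.0%'],  # all values here are placeholders
    ['False Negative Rate', '2.0%'],
    ['Precision', '0.95'],
    ['Recall', '0.98'],
    ['F1-score', '0.96'],
    ['Area Under ROC Curve', '0.99'],
    ['Area Under ROC Curve (log scale)', '5.2'],  # new row from this commit
]
print(tabulate(rows, headers))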
@@ -517,3 +517,8 @@ def test_roc_auc_score():
     oracle = 0.9326
     assert numpy.allclose(auc, oracle), f"Expected {oracle} but got {auc} instead."
+
+    # test the function on log scale as well
+    auc = roc_auc_score(negatives, positives, log_scale=True)
+    oracle = 1.4183699583300993
+    assert numpy.allclose(auc, oracle), f"Expected {oracle} but got {auc} instead."
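A note on why the log-scale oracle exceeds 1: with the x axis converted to log10(FPR), the integration range spans |min_far| = 8 decades instead of the unit interval, so the area is bounded by 8 rather than 1. A quick sanity check of that bound, mirroring the descending-FPR convention implied by the -1 factor in the implementation:

import numpy

# FPR grid from 1 down to 1e-8, matching the descending order implied by
# the `-1 * numpy.trapz(...)` in roc_auc_score (an inference, not a given).
fpr = numpy.logspace(0, -8, 2000)
tpr = numpy.ones_like(fpr)  # a perfect classifier: TPR == 1 everywhere
area = -1 * numpy.trapz(tpr, numpy.log10(fpr))
print(area)  # ~= 8.0, the number of decades covered by the x axis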