diff --git a/bob/measure/__init__.py b/bob/measure/__init__.py
index 4ea55fa1bac3e5b9be9a2c96b7252d39cd12b8d1..c61e710dfbe7fc659697d21e0df10020a1f3441d 100644
--- a/bob/measure/__init__.py
+++ b/bob/measure/__init__.py
@@ -474,6 +474,44 @@ def eer(negatives, positives, is_sorted=False, also_farfrr=False):
     return (far + frr) / 2.0
 
 
+def roc_auc_score(negatives, positives, npoints=2000, min_far=-8, log_scale=False):
+    """Area Under the ROC Curve.
+    Computes the area under the ROC curve. This is useful when you want to report one
+    number that represents an ROC curve. This implementation uses the trapezoidal rule
+    for the integration of the ROC curve. For more information, see:
+    https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve
+
+    Parameters
+    ----------
+    negatives : array_like
+        The negative scores.
+    positives : array_like
+        The positive scores.
+    npoints : int, optional
+        Number of points in the ROC curve. Higher numbers lead to a more accurate ROC.
+    min_far : float, optional
+        Minimum FAR and FRR values to consider when calculating the ROC.
+    log_scale : bool, optional
+        If True, converts the x axis (FPR) to log10 scale before calculating the AUC.
+        This is useful in cases where ``len(negatives) >> len(positives)``.
+
+    Returns
+    -------
+    float
+        The ROC AUC. If ``log_scale`` is False, the value should be between 0 and 1.
+    """
+    fpr, fnr = roc(negatives, positives, npoints, min_far=min_far)
+    tpr = 1 - fnr
+
+    if log_scale:
+        fpr_pos = fpr > 0
+        fpr, tpr = fpr[fpr_pos], tpr[fpr_pos]
+        fpr = numpy.log10(fpr)
+
+    area = -1 * numpy.trapz(tpr, fpr)
+    return area
+
+
 def get_config():
     """Returns a string containing the configuration information.
     """
diff --git a/bob/measure/script/commands.py b/bob/measure/script/commands.py
index a20d6d947f1474257609068cb6f27f28e2117d02..18b553673fd8ac88f258e982e823ae4b227cdc17 100644
--- a/bob/measure/script/commands.py
+++ b/bob/measure/script/commands.py
@@ -13,7 +13,7 @@ CRITERIA = ('eer', 'min-hter', 'far')
 
 @common_options.metrics_command(
     common_options.METRICS_HELP.format(
-        names='FPR, FNR, precision, recall, F1-score',
+        names='FPR, FNR, precision, recall, F1-score, AUC ROC',
         criteria=CRITERIA, score_format=SCORE_FORMAT,
         hter_note=' ',
         command='bob measure metrics'),
diff --git a/bob/measure/script/figure.py b/bob/measure/script/figure.py
index aae0c04b652a5fea2b353cfff66e40c4fc0ce2a3..a4e5081c11c286984b1a8d1f5167fa000662ec6a 100644
--- a/bob/measure/script/figure.py
+++ b/bob/measure/script/figure.py
@@ -183,7 +183,7 @@ class Metrics(MeasureBase):
 
     def __init__(self, ctx, scores, evaluation, func_load,
                  names=('False Positive Rate', 'False Negative Rate',
-                        'Precision', 'Recall', 'F1-score')):
+                        'Precision', 'Recall', 'F1-score', 'Area Under ROC Curve', 'Area Under ROC Curve (log scale)')):
         super(Metrics, self).__init__(ctx, scores, evaluation, func_load)
         self.names = names
         self._tablefmt = ctx.meta.get('tablefmt')
@@ -209,7 +209,7 @@ class Metrics(MeasureBase):
         return utils.get_thres(criterion, dev_neg, dev_pos, far)
 
     def _numbers(self, neg, pos, threshold, fta):
-        from .. import (farfrr, precision_recall, f_score)
+        from .. import (farfrr, precision_recall, f_score, roc_auc_score)
         # fpr and fnr
         fmr, fnmr = farfrr(neg, pos, threshold)
         hter = (fmr + fnmr) / 2.0
@@ -226,8 +226,12 @@ class Metrics(MeasureBase):
 
         # f_score
         f1_score = f_score(neg, pos, threshold, 1)
+
+        # AUC ROC
+        auc = roc_auc_score(neg, pos)
+        auc_log = roc_auc_score(neg, pos, log_scale=True)
         return (fta, fmr, fnmr, hter, far, frr, fm, ni, fnm, nc, precision,
-                recall, f1_score)
+                recall, f1_score, auc, auc_log)
 
     def _strings(self, metrics):
         n_dec = '.%df' % self._decimal
@@ -242,9 +246,11 @@ class Metrics(MeasureBase):
         prec_str = "%s" % format(metrics[10], n_dec)
         recall_str = "%s" % format(metrics[11], n_dec)
         f1_str = "%s" % format(metrics[12], n_dec)
+        auc_str = "%s" % format(metrics[13], n_dec)
+        auc_log_str = "%s" % format(metrics[14], n_dec)
         return (fta_str, fmr_str, fnmr_str, far_str, frr_str, hter_str,
-                prec_str, recall_str, f1_str)
+                prec_str, recall_str, f1_str, auc_str, auc_log_str)
 
     def _get_all_metrics(self, idx, input_scores, input_names):
         ''' Compute all metrics for dev and eval scores'''
@@ -297,11 +303,15 @@ class Metrics(MeasureBase):
             LOGGER.warn("NaNs scores (%s) were found in %s amd removed",
                         all_metrics[0][0], dev_file)
         headers = [' ' or title, 'Development']
-        rows = [[self.names[0], all_metrics[0][1]],
-                [self.names[1], all_metrics[0][2]],
-                [self.names[2], all_metrics[0][6]],
-                [self.names[3], all_metrics[0][7]],
-                [self.names[4], all_metrics[0][8]]]
+        rows = [
+            [self.names[0], all_metrics[0][1]],
+            [self.names[1], all_metrics[0][2]],
+            [self.names[2], all_metrics[0][6]],
+            [self.names[3], all_metrics[0][7]],
+            [self.names[4], all_metrics[0][8]],
+            [self.names[5], all_metrics[0][9]],
+            [self.names[6], all_metrics[0][10]],
+        ]
 
         if self._eval:
             eval_file = input_names[1]
@@ -317,6 +327,8 @@ class Metrics(MeasureBase):
             rows[2].append(all_metrics[1][6])
             rows[3].append(all_metrics[1][7])
             rows[4].append(all_metrics[1][8])
+            rows[5].append(all_metrics[1][9])
+            rows[6].append(all_metrics[1][10])
 
         click.echo(tabulate(rows, headers, self._tablefmt),
                    file=self.log_file)
diff --git a/bob/measure/test_error.py b/bob/measure/test_error.py
index 8f294e070265cc51a4e699f729df3eb32d38e020..32d3a90ad231da6937893170cd64d31cfd32dc2f 100644
--- a/bob/measure/test_error.py
+++ b/bob/measure/test_error.py
@@ -503,3 +503,22 @@ def test_mindcf():
     assert mindcf< 1.0 + 1e-8
 
 
+def test_roc_auc_score():
+    from bob.measure import roc_auc_score
+    positives = bob.io.base.load(F('nonsep-positives.hdf5'))
+    negatives = bob.io.base.load(F('nonsep-negatives.hdf5'))
+    auc = roc_auc_score(negatives, positives)
+
+    # commented out sklearn computation to avoid adding an extra test dependency
+    # from sklearn.metrics import roc_auc_score as oracle_auc
+    # y_true = numpy.concatenate([numpy.ones_like(positives), numpy.zeros_like(negatives)], axis=0)
+    # y_score = numpy.concatenate([positives, negatives], axis=0)
+    # oracle = oracle_auc(y_true, y_score)
+    oracle = 0.9326
+
+    assert numpy.allclose(auc, oracle), f"Expected {oracle} but got {auc} instead."
+
+    # test the function on log scale as well
+    auc = roc_auc_score(negatives, positives, log_scale=True)
+    oracle = 1.4183699583300993
+    assert numpy.allclose(auc, oracle), f"Expected {oracle} but got {auc} instead."
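A minimal usage sketch of the new API, not part of the patch itself: the scores below are synthetic Gaussians (a hypothetical setup), so the resulting AUC values are illustrative only and depend on the seed.

    import numpy
    import bob.measure

    rng = numpy.random.RandomState(0)
    negatives = rng.normal(0.0, 1.0, size=1000)  # impostor scores
    positives = rng.normal(2.0, 1.0, size=100)   # genuine scores

    # Plain AUC is bounded in [0, 1]: ~0.5 for random scores, close to 1.0
    # when the score distributions are well separated.
    auc = bob.measure.roc_auc_score(negatives, positives)

    # The log-scale variant stresses the low-FPR region, useful when
    # negatives vastly outnumber positives; its value is not bounded by 1.
    auc_log = bob.measure.roc_auc_score(negatives, positives, log_scale=True)
    print(auc, auc_log)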
diff --git a/bob/measure/utils.py b/bob/measure/utils.py
index 1bfefb0c6a254e6cf61bee38432b687ecdfa5690..8cfcca35a66bd978c4216859b05061c7531a6dd8 100644
--- a/bob/measure/utils.py
+++ b/bob/measure/utils.py
@@ -115,7 +115,7 @@ def get_thres(criter, neg, pos, far=None):
     elif criter == 'far':
         if far is None:
             raise ValueError("FAR value must be provided through "
-                             "``--far-value`` option.")
+                             "``--far-value`` or ``--fpr-value`` option.")
         from . import far_threshold
         return far_threshold(neg, pos, far)
     else:
diff --git a/doc/guide.rst b/doc/guide.rst
index f8481c98d20505c96c00dcbd0dd4879a38f8d34e..1a93c5f6bace02ad35dde6945a30dd472e01682f 100644
--- a/doc/guide.rst
+++ b/doc/guide.rst
@@ -284,6 +284,9 @@ town. To plot an ROC curve, in possession of your **negatives** and
    >>> pyplot.ylabel('FNR (%)')  # doctest: +SKIP
    >>> pyplot.grid(True)
    >>> pyplot.show()  # doctest: +SKIP
+   >>> # You can also compute the area under the ROC curve:
+   >>> bob.measure.roc_auc_score(negatives, positives)
+   0.8958
 
 You should see an image like the following one:
 
diff --git a/doc/nitpick-exceptions.txt b/doc/nitpick-exceptions.txt
index ebab580e5f6e5ef3d980f00efe0e81bba3720d99..d41d8a73910aad36b9a0fa2efb2077ebf3fd77f1 100644
--- a/doc/nitpick-exceptions.txt
+++ b/doc/nitpick-exceptions.txt
@@ -1,2 +1,7 @@
 # ignores stuff that does not exist in Python 2.7 manual
 py:class list
+# ignores stuff that does not exist but makes sense
+py:class array
+py:class array_like
+py:class optional
+py:class callable
diff --git a/doc/py_api.rst b/doc/py_api.rst
index 45e0872b8f2c82519454ff0f71885e4c72703599..6ccdb6fba501d6d85ccbc70b517ec4a68607b116 100644
--- a/doc/py_api.rst
+++ b/doc/py_api.rst
@@ -49,6 +49,7 @@ Curves
 
 .. autosummary::
    bob.measure.roc
+   bob.measure.roc_auc_score
    bob.measure.rocch
    bob.measure.roc_for_far
    bob.measure.det
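For intuition on the integration step: the `-1 *` factor in `roc_auc_score` suggests that `roc` returns FPR values in decreasing order as the threshold sweeps, and `numpy.trapz` yields a negative area when its x coordinates decrease. A self-contained sketch using a hypothetical four-point ROC polyline:

    import numpy

    # Hypothetical ROC polyline with FPR in decreasing order, mirroring the
    # ordering implied by the sign flip in roc_auc_score.
    fpr = numpy.array([1.0, 0.5, 0.1, 0.0])
    tpr = numpy.array([1.0, 0.9, 0.7, 0.0])

    # The trapezoidal rule over decreasing x gives a negative result;
    # negating it recovers the usual AUC, identical to integrating the same
    # points over increasing FPR.
    area = -1 * numpy.trapz(tpr, fpr)
    assert numpy.isclose(area, numpy.trapz(tpr[::-1], fpr[::-1]))
    print(area)  # 0.83 for this toy curve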