From 5e3dbf1ef613af79dc7f5c333a031c1c96ce580f Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Tue, 5 May 2020 21:25:45 +0200
Subject: [PATCH] [utils.metric] Implement proper AUC for our RecallxPrecision
 curves

---
 bob/ip/binseg/utils/metric.py | 51 +++++++++++++++++++++++++++--------
 1 file changed, 40 insertions(+), 11 deletions(-)

diff --git a/bob/ip/binseg/utils/metric.py b/bob/ip/binseg/utils/metric.py
index 903836f6..b49f4ede 100644
--- a/bob/ip/binseg/utils/metric.py
+++ b/bob/ip/binseg/utils/metric.py
@@ -2,6 +2,7 @@
 # -*- coding: utf-8 -*-
 
 from collections import deque
+import numpy
 import torch
 
 
@@ -64,24 +65,52 @@ def base_metrics(tp, fp, tn, fn):
     return [precision, recall, specificity, accuracy, jaccard, f1_score]
 
 
-def auc(precision, recall):
+def auc(x, y):
     """Calculates the area under the precision-recall curve (AUC)
 
-    .. todo:: Integrate this to metrics reporting in compare.py
-    """
+    This function requires a minimum of 2 points and will use the trapezoidal
+    method to calculate the area under a curve bound between ``[0.0, 1.0]``.
+    It interpolates missing points if required.  The input ``x`` should be
+    continuously increasing or decreasing.
+
 
-    rec_unique, rec_unique_ndx = numpy.unique(recall, return_index=True)
+    Parameters
+    ----------
 
-    prec_unique = precision[rec_unique_ndx]
+    x : numpy.ndarray
+        A 1D numpy array containing continuously increasing or decreasing
+        values for the X coordinate.
+
+    y : numpy.ndarray
+        A 1D numpy array containing the Y coordinates of the X values provided
+        in ``x``.
+
+    """
 
-    if rec_unique.shape[0] > 1:
-        prec_interp = numpy.interp(
-            numpy.arange(0, 1, 0.01),
-            rec_unique,
-            prec_unique,
+    assert len(x) == len(y)
+
+    dx = numpy.diff(x)
+    if numpy.any(dx < 0):
+        if numpy.all(dx <= 0):
+            # invert direction
+            x = x[::-1]
+            y = y[::-1]
+        else:
+            raise ValueError("x is neither increasing nor decreasing "
+                             ": {}.".format(x))
+
+    # avoids repeated sums for every y
+    y_unique, y_unique_ndx = numpy.unique(y, return_index=True)
+    x_unique = x[y_unique_ndx]
+
+    if y_unique.shape[0] > 1:
+        x_interp = numpy.interp(
+            numpy.arange(0, 1, 0.001),
+            y_unique,
+            x_unique,
             left=0.0,
             right=0.0,
         )
-        return prec_interp.sum() * 0.01
+        return x_interp.sum() * 0.001
 
     return 0.0
-- 
GitLab