Skip to content
Snippets Groups Projects
Commit 5e3dbf1e authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[utils.metric] Implement proper AUC for our RecallxPrecision curves

parent 10640c44
No related branches found
No related tags found
No related merge requests found
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from collections import deque from collections import deque
import numpy
import torch import torch
...@@ -64,24 +65,52 @@ def base_metrics(tp, fp, tn, fn): ...@@ -64,24 +65,52 @@ def base_metrics(tp, fp, tn, fn):
return [precision, recall, specificity, accuracy, jaccard, f1_score] return [precision, recall, specificity, accuracy, jaccard, f1_score]
def auc(precision, recall): def auc(x, y):
"""Calculates the area under the precision-recall curve (AUC) """Calculates the area under the precision-recall curve (AUC)
.. todo:: Integrate this to metrics reporting in compare.py This function requires a minimum of 2 points and will use the trapezoidal
""" method to calculate the area under a curve bound between ``[0.0, 1.0]``.
It interpolates missing points if required. The input ``x`` should be
continuously increasing or decreasing.
rec_unique, rec_unique_ndx = numpy.unique(recall, return_index=True) Parameters
----------
prec_unique = precision[rec_unique_ndx] x : numpy.ndarray
A 1D numpy array containing continuously increasing or decreasing
values for the X coordinate.
y : numpy.ndarray
A 1D numpy array containing the Y coordinates of the X values provided
in ``x``.
"""
if rec_unique.shape[0] > 1: assert len(x) == len(y)
prec_interp = numpy.interp(
numpy.arange(0, 1, 0.01), dx = numpy.diff(x)
rec_unique, if numpy.any(dx < 0):
prec_unique, if numpy.all(dx <= 0):
# invert direction
x = x[::-1]
y = y[::-1]
else:
raise ValueError("x is neither increasing nor decreasing "
": {}.".format(x))
# avoids repeated sums for every y
y_unique, y_unique_ndx = numpy.unique(y, return_index=True)
x_unique = x[y_unique_ndx]
if y_unique.shape[0] > 1:
x_interp = numpy.interp(
numpy.arange(0, 1, 0.001),
y_unique,
x_unique,
left=0.0, left=0.0,
right=0.0, right=0.0,
) )
return prec_interp.sum() * 0.01 return x_interp.sum() * 0.001
return 0.0 return 0.0
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment