Commit 17628ba9 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

Improve accuracy of ROC, DET, and PR plots

PR - > precision_recall
Internally computes thresholds using both far and frr values for more
accurate plots.
Fixes #60
parent b57f5eba
Pipeline #31334 passed with stage
in 10 minutes and 1 second
......@@ -8,6 +8,7 @@ from .version import module as __version__
from . import plot
from . import calibration
from . import load
import numpy
......
......@@ -238,31 +238,80 @@ double bob::measure::minWeightedErrorRateThreshold(
return bob::measure::minimizingThreshold(neg, pos, predicate);
}
/**
 * Generates `points_` logarithmically spaced values, in ascending order.
 *
 * Element k (0-based) is 10^((k + 1 - points) / c), where
 * c = points / (-min_power) is the number of samples per decade. The last
 * element is therefore always 10^0 == 1 and the first one is
 * 10^(min_power * (points - 1) / points), i.e. just above 10^min_power.
 *
 * @param points_   Number of values to generate (0 yields an empty array)
 * @param min_power Exponent of the lower end of the range; must be negative
 *
 * @return A 1D array holding `points_` values
 *
 * @throws std::runtime_error if min_power is not negative
 */
blitz::Array<double, 1>
bob::measure::log_values(size_t points_, int min_power) {
  // A non-negative exponent makes the step meaningless (division by zero for
  // min_power == 0, values above 1 otherwise), so reject it explicitly.
  if (min_power >= 0)
    throw std::runtime_error("min_power must be negative.");

  const int points = static_cast<int>(points_);
  blitz::Array<double, 1> retval(points);

  // Samples per decade. This must be a floating-point division: an integer
  // division truncates the step (shifting the lower end of the range) and
  // yields 0 -- hence inf/NaN values -- whenever points < -min_power.
  const double counts_per_step =
      static_cast<double>(points) / -static_cast<double>(min_power);

  for (int i = 1 - points; i <= 0; ++i) {
    retval(i + points - 1) = std::pow(10., static_cast<double>(i) / counts_per_step);
  }
  return retval;
}
/**
 * Builds a sorted list of `points_` thresholds taken at log-spaced FRR and
 * FAR operating points.
 *
 * The first points_/2 entries are obtained from frrThreshold() at log-spaced
 * FRR targets, the remaining ones from farThreshold() at log-spaced FAR
 * targets; both target lists come from log_values() with `min_far` as the
 * exponent. The collected thresholds are sorted before being returned.
 *
 * @param negatives Negative scores
 * @param positives Positive scores
 * @param points_   Total number of thresholds to compute
 * @param min_far   Exponent passed to log_values() for both target lists
 * @param is_sorted Forwarded to the sort() helper for both score arrays
 *
 * @return A 1D array with `points_` sorted thresholds
 */
blitz::Array<double, 1>
bob::measure::meaningfulThresholds(
    const blitz::Array<double, 1> &negatives,
    const blitz::Array<double, 1> &positives,
    size_t points_, int min_far, bool is_sorted) {
  const int total = (int)points_;
  const int n_frr = total / 2;     // thresholds derived from FRR targets
  const int n_far = total - n_frr; // thresholds derived from FAR targets

  blitz::Array<double, 1> result(total);

  // work on sorted versions of the score arrays
  blitz::Array<double, 1> sorted_neg, sorted_pos;
  sort(negatives, sorted_neg, is_sorted);
  sort(positives, sorted_pos, is_sorted);

  // log-spaced operating points of interest
  auto frr_targets = bob::measure::log_values(n_frr, min_far);
  auto far_targets = bob::measure::log_values(n_far, min_far);

  // first part of the result: thresholds at the requested FRR values
  for (int k = 0; k < n_frr; ++k) {
    result(k) =
        bob::measure::frrThreshold(sorted_neg, sorted_pos, frr_targets(k), true);
  }

  // second part of the result: thresholds at the requested FAR values
  for (int k = 0; k < n_far; ++k) {
    result(n_frr + k) =
        bob::measure::farThreshold(sorted_neg, sorted_pos, far_targets(k), true);
  }

  // callers expect monotonically ordered thresholds
  bob::core::array::sort(result);
  return result;
}
blitz::Array<double, 2>
bob::measure::roc(const blitz::Array<double, 1> &negatives,
const blitz::Array<double, 1> &positives, size_t points) {
// Uses roc_for_far internally
// Create an far_list
blitz::Array<double, 1> far_list((int)points);
int min_far = -8; // minimum FAR in terms of 10^(min_far)
double counts_per_step = points / (-min_far) ;
for (int i = 1-(int)points; i <= 0; ++i) {
far_list(i+(int)points-1) = std::pow(10., (double)i/counts_per_step);
const blitz::Array<double, 1> &positives,
size_t points_, int min_far) {
int points = (int)points_;
blitz::Array<double, 2> retval(2, points);
auto thresholds = bob::measure::meaningfulThresholds(
negatives, positives, points_, min_far);
// compute far and frr based on these thresholds
for (int i = 0; i < points; ++i) {
auto temp = bob::measure::farfrr(negatives, positives, thresholds(i));
retval(0, i) = temp.first;
retval(1, i) = temp.second;
}
return bob::measure::roc_for_far(negatives, positives, far_list, false);
return retval;
}
blitz::Array<double, 2>
bob::measure::precision_recall_curve(const blitz::Array<double, 1> &negatives,
const blitz::Array<double, 1> &positives,
size_t points) {
double min = std::min(blitz::min(negatives), blitz::min(positives));
double max = std::max(blitz::max(negatives), blitz::max(positives));
double step = (max - min) / ((double)points - 1.0);
blitz::Array<double, 2> retval(2, points);
auto thresholds = bob::measure::meaningfulThresholds(
negatives, positives, points);
for (int i = 0; i < (int)points; ++i) {
std::pair<double, double> ratios =
bob::measure::precision_recall(negatives, positives, min + i * step);
auto ratios = bob::measure::precision_recall(negatives, positives, thresholds(i));
retval(0, i) = ratios.first;
retval(1, i) = ratios.second;
}
......@@ -441,6 +490,54 @@ bob::measure::roc_for_far(const blitz::Array<double, 1> &negatives,
return retval;
}
/**
 * Computes ROC coordinates at a given list of FNR (FRR) operating points.
 *
 * For every requested FRR value, the matching threshold is obtained with
 * frrThreshold() and both error rates are then re-evaluated at that
 * threshold with farfrr().
 *
 * @param negatives Impostor scores
 * @param positives Client scores
 * @param frr_list  The FNR (FRR) values at which the curve is evaluated
 * @param is_sorted Forwarded to the sort() helper for both score arrays
 *
 * @return A 2 x frr_list.extent(0) array with the FPR in the first row and
 * the FNR in the second.
 *
 * @throws std::runtime_error if `negatives` or `positives` is empty
 */
blitz::Array<double, 2>
bob::measure::roc_for_frr(const blitz::Array<double, 1> &negatives,
                          const blitz::Array<double, 1> &positives,
                          const blitz::Array<double, 1> &frr_list,
                          bool is_sorted) {
  const int count = frr_list.extent(0);

  if (negatives.extent(0) == 0)
    throw std::runtime_error("The given set of negatives is empty.");
  if (positives.extent(0) == 0)
    throw std::runtime_error("The given set of positives is empty.");

  // ascending-sorted views of both score sets
  blitz::Array<double, 1> sorted_neg, sorted_pos;
  sort(negatives, sorted_neg, is_sorted);
  sort(positives, sorted_pos, is_sorted);

  blitz::Array<double, 2> curve(2, count);

  // walk the FRR list from its last entry down to its first
  for (int k = count - 1; k >= 0; --k) {
    // threshold that realises the requested FRR
    const auto thr =
        bob::measure::frrThreshold(sorted_neg, sorted_pos, frr_list(k), true);
    // actual error rates observed at that threshold
    const auto rates = bob::measure::farfrr(sorted_neg, sorted_pos, thr);
    curve(0, k) = rates.first;
    curve(1, k) = rates.second;
  }
  return curve;
}
/**
* The input to this function is a cumulative probability. The output from
* this function is the Normal deviate that corresponds to that probability.
......@@ -511,9 +608,9 @@ double bob::measure::ppndf(double value) { return _ppndf(value); }
blitz::Array<double, 2>
bob::measure::det(const blitz::Array<double, 1> &negatives,
const blitz::Array<double, 1> &positives, size_t points) {
const blitz::Array<double, 1> &positives, size_t points, int min_far) {
blitz::Array<double, 2> retval(2, points);
retval = blitz::_ppndf(bob::measure::roc(negatives, positives, points));
retval = blitz::_ppndf(bob::measure::roc(negatives, positives, points, min_far));
return retval;
}
......
......@@ -314,18 +314,43 @@ double farThreshold(const blitz::Array<double, 1> &negatives,
double frrThreshold(const blitz::Array<double, 1> &negatives,
const blitz::Array<double, 1> &positives, double frr_value,
bool is_sorted = false);
/**
 * Computes log-scaled values between :math:`10^{min_power}` and 1.
 *
 * The values are returned in ascending order and the last one is always
 * 1 (:math:`10^0`); the first one lies close to :math:`10^{min_power}`.
 *
 * @param points The number of values to generate
 * @param min_power Exponent of the lower end of the range. It is expected to
 * be negative.
 *
 * @return A one dimensional array with ``points`` elements
 */
blitz::Array<double, 1> log_values(size_t points, int min_power);
/**
 * This function creates a list of FAR and FRR values that we are interested
 * to see on the curve, computes the thresholds for those points, sorts the
 * thresholds so that the curve is plotted from ordered numbers, and returns
 * them. Some points will be duplicated, but in terms of resolution and
 * accuracy this is better than sweeping the threshold from the minimum to the
 * maximum score in equally spaced steps.
 *
 * @param negatives The negative scores
 * @param positives The positive scores
 * @param points_ The total number of thresholds to compute. The first half
 * (points_/2) is computed from log-spaced FRR values, the remainder from
 * log-spaced FAR values.
 * @param min_far Exponent handed to log_values() for both the FAR and the FRR
 * list, i.e. the rates of interest start around 10^(min_far).
 * @param is_sorted Forwarded to the internal sorting helper for both score
 * arrays (presumably lets it skip sorting already sorted input -- confirm
 * against the helper).
 *
 * @return A one dimensional array with points_ sorted thresholds
 */
blitz::Array<double, 1> meaningfulThresholds(
const blitz::Array<double, 1> &negatives,
const blitz::Array<double, 1> &positives,
size_t points_, int min_far = -8, bool is_sorted = false);
/**
* Calculates the ROC curve given a set of positive and negative scores and a
* number of desired points. Returns a two-dimensional blitz::Array of
* doubles that express the X (FNR) and Y (FPR) coordinates in this order.
* The points in which the ROC curve are calculated are distributed
* uniformly in the range [min(negatives, positives), max(negatives,
* positives)].
* doubles that express the X (FPR) and Y (FNR) coordinates in this order.
* Internally it uses roc_for_far and roc_for_frr to compute X and Y. This will make
* sure all corner cases are addressed. It is recommended to use a very large number of
* points (say 2000) to get a smooth output.
*
* @param negatives The impostor scores
* @param positives The genuine scores
* @param points The number of points in X and Y. It should be a multiple of 2.
* @param min_far Minimum FAR in terms of 10^(min_far). This value is also used for
* min_frr.
*
* @return A two dimensional array with shape of <2, 2*(points//2)>
*/
blitz::Array<double, 2> roc(const blitz::Array<double, 1> &negatives,
const blitz::Array<double, 1> &positives,
size_t points);
size_t points, int min_far = -8);
/**
* Calculates the precision-recall curve given a set of positive and negative
......@@ -374,6 +399,16 @@ blitz::Array<double, 2> roc_for_far(const blitz::Array<double, 1> &negatives,
const blitz::Array<double, 1> &far_list,
bool is_sorted = false);
/**
 * Calculates the ROC curve given a set of positive and negative scores at
 * the given FNR coordinates. Returns a two-dimensional blitz::Array of
 * doubles that express the X (FPR) and Y (FNR) coordinates in this order,
 * with one column per entry of frr_list.
 *
 * @param negatives The impostor scores
 * @param positives The genuine scores
 * @param frr_list The FNR (FRR) values at which the curve is evaluated
 * @param is_sorted Forwarded to the internal sorting helper for both score
 * arrays
 */
blitz::Array<double, 2> roc_for_frr(const blitz::Array<double, 1> &negatives,
const blitz::Array<double, 1> &positives,
const blitz::Array<double, 1> &frr_list,
bool is_sorted = false);
/**
* Returns the Deviate Scale equivalent of a false rejection/acceptance
* ratio.
......@@ -406,7 +441,7 @@ double ppndf(double value);
*/
blitz::Array<double, 2> det(const blitz::Array<double, 1> &negatives,
const blitz::Array<double, 1> &positives,
size_t points);
size_t points, int min_far = -8);
/**
* Calculates the EPC curve given a set of positive and negative scores and a
......
......@@ -151,13 +151,17 @@ static auto det_doc =
"recommended by NIST. "
"The derivative scales are computed with the "
":py:func:`bob.measure.ppndf` function.")
.add_prototype("negatives, positives, n_points", "curve")
.add_prototype("negatives, positives, n_points, [min_far]", "curve")
.add_parameter(
"negatives, positives", "array_like(1D, float)",
"The list of negative and positive scores to compute the DET for")
.add_parameter("n_points", "int", "The number of points on the DET "
"curve, for which the DET should be "
"evaluated")
.add_parameter("min_far", "int", "Minimum FAR in terms of 10^(min_far). "
"This value is also used for min_frr. "
"Default value is -8. Values should be "
"negative.")
.add_return("curve", "array_like(2D, float)",
"The DET curve, with the FPR in the first and the FNR in "
"the second row");
......@@ -168,10 +172,11 @@ static PyObject *det(PyObject *, PyObject *args, PyObject *kwds) {
PyBlitzArrayObject *neg;
PyBlitzArrayObject *pos;
Py_ssize_t n_points;
int min_far = -8;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&O&n", kwlist,
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&O&n|i", kwlist,
&double1d_converter, &neg,
&double1d_converter, &pos, &n_points))
&double1d_converter, &pos, &n_points, &min_far))
return 0;
// protects acquired resources through this scope
......@@ -180,7 +185,7 @@ static PyObject *det(PyObject *, PyObject *args, PyObject *kwds) {
auto result =
bob::measure::det(*PyBlitzArrayCxx_AsBlitz<double, 1>(neg),
*PyBlitzArrayCxx_AsBlitz<double, 1>(pos), n_points);
*PyBlitzArrayCxx_AsBlitz<double, 1>(pos), n_points, min_far);
return PyBlitzArrayCxx_AsNumpy(result);
BOB_CATCH_FUNCTION("det", 0)
......@@ -218,7 +223,7 @@ static auto roc_doc =
"Calculates points of an Receiver Operating Characteristic (ROC)",
"Calculates the ROC curve given a set of negative and positive scores "
"and a desired number of points. ")
.add_prototype("negatives, positives, n_points", "curve")
.add_prototype("negatives, positives, n_points, [min_far]", "curve")
.add_parameter("negatives, positives", "array_like(1D, float)",
"The negative and positive scores, for which the ROC "
"curve should be calculated")
......@@ -227,6 +232,10 @@ static auto roc_doc =
"distributed uniformly in the range "
"``[min(negatives, positives), "
"max(negatives, positives)]``")
.add_parameter("min_far", "int", "Minimum FAR in terms of 10^(min_far). "
"This value is also used for min_frr. "
"Default value is -8. Values should be "
"negative.")
.add_return("curve", "array_like(2D, float)",
"A two-dimensional array of doubles that express the X "
"(FPR) and Y (FNR) coordinates in this order");
......@@ -237,10 +246,12 @@ static PyObject *roc(PyObject *, PyObject *args, PyObject *kwds) {
PyBlitzArrayObject *neg;
PyBlitzArrayObject *pos;
Py_ssize_t n_points;
int min_far = -8;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&O&n", kwlist,
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&O&n|i", kwlist,
&double1d_converter, &neg,
&double1d_converter, &pos, &n_points))
&double1d_converter, &pos, &n_points, &min_far))
return 0;
// protects acquired resources through this scope
......@@ -249,7 +260,7 @@ static PyObject *roc(PyObject *, PyObject *args, PyObject *kwds) {
auto result =
bob::measure::roc(*PyBlitzArrayCxx_AsBlitz<double, 1>(neg),
*PyBlitzArrayCxx_AsBlitz<double, 1>(pos), n_points);
*PyBlitzArrayCxx_AsBlitz<double, 1>(pos), n_points, min_far);
return PyBlitzArrayCxx_AsNumpy(result);
BOB_CATCH_FUNCTION("roc", 0)
......
......@@ -47,7 +47,7 @@ def _semilogx(x, y, **kwargs):
return pyplot.semilogx(x, y, **kwargs)
def roc(negatives, positives, npoints=100, CAR=False, **kwargs):
def roc(negatives, positives, npoints=2000, CAR=False, min_far=-8, **kwargs):
"""Plots Receiver Operating Characteristic (ROC) curve.
This method will call ``matplotlib`` to plot the ROC curve for a system which
......@@ -98,7 +98,7 @@ def roc(negatives, positives, npoints=100, CAR=False, **kwargs):
from matplotlib import pyplot
from . import roc as calc
out = calc(negatives, positives, npoints)
out = calc(negatives, positives, npoints, min_far)
if not CAR:
return pyplot.plot(out[0, :], out[1, :], **kwargs)
else:
......@@ -164,7 +164,7 @@ def roc_for_far(negatives, positives, far_values=log_values(), CAR=True,
return _semilogx(out[0, :], (1 - out[1, :]), **kwargs)
def precision_recall_curve(negatives, positives, npoints=100, **kwargs):
def precision_recall_curve(negatives, positives, npoints=2000, **kwargs):
"""Plots a Precision-Recall curve.
This method will call ``matplotlib`` to plot the precision-recall curve for a
......@@ -279,7 +279,7 @@ def epc(dev_negatives, dev_positives, test_negatives, test_positives,
return pyplot.plot(out[0, :], 100.0 * out[1, :], **kwargs)
def det(negatives, positives, npoints=100, **kwargs):
def det(negatives, positives, npoints=2000, min_far=-8, **kwargs):
"""Plots Detection Error Trade-off (DET) curve as defined in the paper:
Martin, A., Doddington, G., Kamm, T., Ordowski, M., & Przybocki, M. (1997).
......@@ -395,7 +395,7 @@ def det(negatives, positives, npoints=100, **kwargs):
from . import det as calc
from . import ppndf
out = calc(negatives, positives, npoints)
out = calc(negatives, positives, npoints, min_far)
retval = pyplot.plot(out[0, :], out[1, :], **kwargs)
# now the trick: we must plot the tick marks by hand using the PPNDF method
......
......@@ -246,7 +246,7 @@ def points_curve_option(**kwargs):
ctx.meta['points'] = value
return value
return click.option(
'-n', '--points', type=INT, default=100, show_default=True,
'-n', '--points', type=INT, default=2000, show_default=True,
help='The number of points use to draw curves in plots',
callback=callback, **kwargs)(func)
return custom_points_curve_option
......
......@@ -427,7 +427,7 @@ class PlotBase(MeasureBase):
def __init__(self, ctx, scores, evaluation, func_load):
super(PlotBase, self).__init__(ctx, scores, evaluation, func_load)
self._output = ctx.meta.get('output')
self._points = ctx.meta.get('points', 100)
self._points = ctx.meta.get('points', 2000)
self._split = ctx.meta.get('split')
self._axlim = ctx.meta.get('axlim')
self._disp_legend = ctx.meta.get('disp_legend', True)
......@@ -550,6 +550,7 @@ class Roc(PlotBase):
# custom defaults
if self._axlim is None:
self._axlim = [None, None, -0.05, 1.05]
self._min_dig = -4 if self._min_dig is None else self._min_dig
def compute(self, idx, input_scores, input_names):
''' Plot ROC for dev and eval data using
......@@ -564,10 +565,11 @@ class Roc(PlotBase):
mpl.figure(1)
if self._eval:
LOGGER.info("ROC dev. curve using %s", dev_file)
plot.roc_for_far(
plot.roc(
dev_neg, dev_pos,
far_values=plot.log_values(self._min_dig or -4),
npoints=self._points,
CAR=self._semilogx,
min_far=self._min_dig,
color=self._colors[idx], linestyle=self._linestyles[idx],
label=self._label('dev', idx)
)
......@@ -576,10 +578,11 @@ class Roc(PlotBase):
linestyle = '--' if not self._split else self._linestyles[idx]
LOGGER.info("ROC eval. curve using %s", eval_file)
plot.roc_for_far(
plot.roc(
eval_neg, eval_pos, linestyle=linestyle,
far_values=plot.log_values(self._min_dig or -4),
npoints=self._points,
CAR=self._semilogx,
min_far=self._min_dig,
color=self._colors[idx],
label=self._label('eval.', idx)
)
......@@ -595,10 +598,11 @@ class Roc(PlotBase):
self._eval_points[line].append((eval_fmr, eval_fnmr))
else:
LOGGER.info("ROC dev. curve using %s", dev_file)
plot.roc_for_far(
plot.roc(
dev_neg, dev_pos,
far_values=plot.log_values(self._min_dig or -4),
npoints=self._points,
CAR=self._semilogx,
min_far=self._min_dig,
color=self._colors[idx], linestyle=self._linestyles[idx],
label=self._label('dev', idx)
)
......@@ -625,6 +629,8 @@ class Det(PlotBase):
if self._min_dig is not None:
self._axlim[0] = math.pow(10, self._min_dig) * 100
self._min_dig = -4 if self._min_dig is None else self._min_dig
def compute(self, idx, input_scores, input_names):
''' Plot DET for dev and eval data using
:py:func:`bob.measure.plot.det`'''
......@@ -639,7 +645,8 @@ class Det(PlotBase):
if self._eval and eval_neg is not None:
LOGGER.info("DET dev. curve using %s", dev_file)
plot.det(
dev_neg, dev_pos, self._points, color=self._colors[idx],
dev_neg, dev_pos, self._points, min_far=self._min_dig,
color=self._colors[idx],
linestyle=self._linestyles[idx],
label=self._label('dev.', idx)
)
......@@ -648,7 +655,8 @@ class Det(PlotBase):
linestyle = '--' if not self._split else self._linestyles[idx]
LOGGER.info("DET eval. curve using %s", eval_file)
plot.det(
eval_neg, eval_pos, self._points, color=self._colors[idx],
eval_neg, eval_pos, self._points, min_far=self._min_dig,
color=self._colors[idx],
linestyle=linestyle,
label=self._label('eval.', idx)
)
......@@ -664,7 +672,8 @@ class Det(PlotBase):
else:
LOGGER.info("DET dev. curve using %s", dev_file)
plot.det(
dev_neg, dev_pos, self._points, color=self._colors[idx],
dev_neg, dev_pos, self._points, min_far=self._min_dig,
color=self._colors[idx],
linestyle=self._linestyles[idx],
label=self._label('dev.', idx)
)
......
......@@ -492,8 +492,7 @@ def test_open_set_rates():
def test_mindcf():
""" Test outlier scores in negative set
"""
# Test outlier scores in negative set
from bob.measure import min_weighted_error_rate_threshold, farfrr
cost = 0.99
negatives = [-3, -2, -1, -0.5, 4]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment