Skip to content
Snippets Groups Projects

Improve threshold calculation in ROC, DET, and Precision and Recall plots

Merged Amir MOHAMMADI requested to merge roc-fixes into master
Files
13
+ 66
16
@@ -238,31 +238,80 @@ double bob::measure::minWeightedErrorRateThreshold(
return bob::measure::minimizingThreshold(neg, pos, predicate);
}
blitz::Array<double, 1>
bob::measure::log_values(size_t points_, int min_power) {
int points = (int)points_;
blitz::Array<double, 1> retval(points);
double counts_per_step = points / (-min_power) ;
for (int i = 1-points; i <= 0; ++i) {
retval(i+points-1) = std::pow(10., (double)i/counts_per_step);
}
return retval;
}
blitz::Array<double, 1>
bob::measure::meaningfulThresholds(
const blitz::Array<double, 1> &negatives,
const blitz::Array<double, 1> &positives,
size_t points_, int min_far, bool is_sorted) {
int points = (int)points_;
int half_points = points / 2;
blitz::Array<double, 1> thresholds(points);
blitz::Array<double, 1> neg, pos;
// sort negatives and positives
sort(negatives, neg, is_sorted);
sort(positives, pos, is_sorted);
// Create an far_list and frr_list
auto frr_list = bob::measure::log_values(half_points, min_far);
auto far_list = bob::measure::log_values(points - half_points, min_far);
// Compute thresholds for far_list and frr_list
for (int i = 0; i < points; ++i) {
if (i < half_points)
thresholds(i) = bob::measure::frrThreshold(neg, pos, frr_list(i), true);
else
thresholds(i) = bob::measure::farThreshold(neg, pos, far_list(i-half_points), true);
}
// Sort the thresholds
bob::core::array::sort(thresholds);
return thresholds;
}
blitz::Array<double, 2>
bob::measure::roc(const blitz::Array<double, 1> &negatives,
const blitz::Array<double, 1> &positives, size_t points) {
// Uses roc_for_far internally
// Create an far_list
blitz::Array<double, 1> far_list((int)points);
int min_far = -8; // minimum FAR in terms of 10^(min_far)
double counts_per_step = points / (-min_far) ;
for (int i = 1-(int)points; i <= 0; ++i) {
far_list(i+(int)points-1) = std::pow(10., (double)i/counts_per_step);
const blitz::Array<double, 1> &positives,
size_t points_, int min_far) {
int points = (int)points_;
blitz::Array<double, 2> retval(2, points);
auto thresholds = bob::measure::meaningfulThresholds(
negatives, positives, points_, min_far);
// compute far and frr based on these thresholds
for (int i = 0; i < points; ++i) {
auto temp = bob::measure::farfrr(negatives, positives, thresholds(i));
retval(0, i) = temp.first;
retval(1, i) = temp.second;
}
return bob::measure::roc_for_far(negatives, positives, far_list, false);
return retval;
}
blitz::Array<double, 2>
bob::measure::precision_recall_curve(const blitz::Array<double, 1> &negatives,
const blitz::Array<double, 1> &positives,
size_t points) {
double min = std::min(blitz::min(negatives), blitz::min(positives));
double max = std::max(blitz::max(negatives), blitz::max(positives));
double step = (max - min) / ((double)points - 1.0);
blitz::Array<double, 2> retval(2, points);
auto thresholds = bob::measure::meaningfulThresholds(
negatives, positives, points);
for (int i = 0; i < (int)points; ++i) {
std::pair<double, double> ratios =
bob::measure::precision_recall(negatives, positives, min + i * step);
auto ratios = bob::measure::precision_recall(negatives, positives, thresholds(i));
retval(0, i) = ratios.first;
retval(1, i) = ratios.second;
}
@@ -441,6 +490,7 @@ bob::measure::roc_for_far(const blitz::Array<double, 1> &negatives,
return retval;
}
/**
* The input to this function is a cumulative probability. The output from
* this function is the Normal deviate that corresponds to that probability.
@@ -511,9 +561,9 @@ double bob::measure::ppndf(double value) { return _ppndf(value); }
blitz::Array<double, 2>
bob::measure::det(const blitz::Array<double, 1> &negatives,
const blitz::Array<double, 1> &positives, size_t points) {
const blitz::Array<double, 1> &positives, size_t points, int min_far) {
blitz::Array<double, 2> retval(2, points);
retval = blitz::_ppndf(bob::measure::roc(negatives, positives, points));
retval = blitz::_ppndf(bob::measure::roc(negatives, positives, points, min_far));
return retval;
}
Loading