Commit 7e6a0704 authored by Manuel Günther's avatar Manuel Günther
Browse files

Changed test case for different corner cases

parent 89511eeb
...@@ -105,7 +105,7 @@ double bob::measure::eerRocch(const blitz::Array<double, 1> &negatives, ...@@ -105,7 +105,7 @@ double bob::measure::eerRocch(const blitz::Array<double, 1> &negatives,
} }
double bob::measure::farThreshold(const blitz::Array<double, 1> &negatives, double bob::measure::farThreshold(const blitz::Array<double, 1> &negatives,
const blitz::Array<double, 1> &, const blitz::Array<double, 1> &positives,
double far_value, bool is_sorted) { double far_value, bool is_sorted) {
// check the parameters are valid // check the parameters are valid
if (far_value < 0. || far_value > 1.) { if (far_value < 0. || far_value > 1.) {
...@@ -155,21 +155,18 @@ double bob::measure::farThreshold(const blitz::Array<double, 1> &negatives, ...@@ -155,21 +155,18 @@ double bob::measure::farThreshold(const blitz::Array<double, 1> &negatives,
// move to the left of array changing the threshold until we pass the desired // move to the left of array changing the threshold until we pass the desired
// FAR value. // FAR value.
double threshold; double threshold = neg(index);
double future_far; double future_far;
while (index >= 0) { while (index > 0) {
threshold = neg(index);
if (index == 0)
break;
future_far = blitz::count(neg >= neg(index-1)) / (double)neg.extent(0); future_far = blitz::count(neg >= neg(index-1)) / (double)neg.extent(0);
if (future_far > far_value) if (future_far > far_value)
break; break;
--index; threshold = neg(--index);
} }
return threshold; return threshold;
} }
double bob::measure::frrThreshold(const blitz::Array<double, 1> &, double bob::measure::frrThreshold(const blitz::Array<double, 1> &negatives,
const blitz::Array<double, 1> &positives, const blitz::Array<double, 1> &positives,
double frr_value, bool is_sorted) { double frr_value, bool is_sorted) {
...@@ -221,16 +218,13 @@ double bob::measure::frrThreshold(const blitz::Array<double, 1> &, ...@@ -221,16 +218,13 @@ double bob::measure::frrThreshold(const blitz::Array<double, 1> &,
// move to the right of array changing the threshold until we pass the // move to the right of array changing the threshold until we pass the
// desired FRR value. // desired FRR value.
double threshold; double threshold = pos(index);
double future_frr; double future_frr;
while (index < pos.extent(0)) { while (index < pos.extent(0)-1) {
threshold = pos(index);
if (index == pos.extent(0)-1)
break;
future_frr = blitz::count(pos < pos(index+1)) / (double)pos.extent(0); future_frr = blitz::count(pos < pos(index+1)) / (double)pos.extent(0);
if (future_frr > frr_value) if (future_frr > frr_value)
break; break;
++index; threshold = pos(++index);
} }
return threshold; return threshold;
} }
......
...@@ -169,10 +169,10 @@ def test_obvious_thresholds(): ...@@ -169,10 +169,10 @@ def test_obvious_thresholds():
neg = numpy.array(range(M), dtype=float) neg = numpy.array(range(M), dtype=float)
pos = numpy.array(range(M, 2 * M), dtype=float) pos = numpy.array(range(M, 2 * M), dtype=float)
for far, frr in zip(numpy.array(range(0, M + 1), dtype=float) / neg.size, for far, frr in zip(numpy.array(range(0, 2*M + 1), dtype=float) / neg.size/2,
numpy.array(range(0, M + 1), dtype=float) / pos.size): numpy.array(range(0, 2*M + 1), dtype=float) / pos.size/2):
far = round(far, int(M / 10)) far, expected_far = round(far, 2), math.floor(far*10)/10
frr = round(frr, int(M / 10)) frr, expected_frr = round(frr, 2), math.floor(frr*10)/10
calculated_far_threshold = far_threshold(neg, pos, far) calculated_far_threshold = far_threshold(neg, pos, far)
predicted_far, _ = farfrr( predicted_far, _ = farfrr(
neg, pos, calculated_far_threshold) neg, pos, calculated_far_threshold)
...@@ -180,8 +180,10 @@ def test_obvious_thresholds(): ...@@ -180,8 +180,10 @@ def test_obvious_thresholds():
calculated_frr_threshold = frr_threshold(neg, pos, frr) calculated_frr_threshold = frr_threshold(neg, pos, frr)
_, predicted_frr = farfrr( _, predicted_frr = farfrr(
neg, pos, calculated_frr_threshold) neg, pos, calculated_frr_threshold)
assert predicted_far <= far, (far, calculated_far_threshold, predicted_far) assert predicted_far <= far, (predicted_far, far, calculated_far_threshold)
assert predicted_frr <= frr, (frr, calculated_frr_threshold, predicted_frr) assert predicted_far == expected_far
assert predicted_frr <= frr, (predicted_frr, frr, calculated_frr_threshold)
assert predicted_frr == expected_frr
def test_thresholding(): def test_thresholding():
...@@ -220,11 +222,11 @@ def test_thresholding(): ...@@ -220,11 +222,11 @@ def test_thresholding():
frr = farfrr(negatives, positives, threshold_frr)[1] frr = farfrr(negatives, positives, threshold_frr)[1]
if not math.isnan(threshold_far): if not math.isnan(threshold_far):
assert far <= t, (far, t) assert far <= t, (far, t)
assert t - far <= 0.1, (far, t) assert far - t <= 0.1
if not math.isnan(threshold_frr): if not math.isnan(threshold_frr):
assert frr <= t, (frr, t) assert frr <= t, (frr, t)
# test that the values are at least somewhere in the range # test that the values are at least somewhere in the range
assert t - frr <= 0.1, (frr, t) assert frr - t <= 0.1
# If the set is separable, the calculation of the threshold is a little bit # If the set is separable, the calculation of the threshold is a little bit
# trickier, as you have no points in the middle of the range to compare # trickier, as you have no points in the middle of the range to compare
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment