From f59d0af1ba3a5d9f967cdcb262c81fb4a22663ff Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Fri, 20 Dec 2019 13:25:48 +0100
Subject: [PATCH] [test_trainer] Be more verbose when test fails

---
 bob/learn/libsvm/test_trainer.py | 25 ++++++++++++++-----------
 1 file changed, 14 insertions(+), 11 deletions(-)

diff --git a/bob/learn/libsvm/test_trainer.py b/bob/learn/libsvm/test_trainer.py
index 8a16f55..9f4c029 100644
--- a/bob/learn/libsvm/test_trainer.py
+++ b/bob/learn/libsvm/test_trainer.py
@@ -34,6 +34,10 @@ HEART_DATA = F('heart.svmdata') #13 inputs
 HEART_MACHINE = F('heart.svmmodel') #supports probabilities
 HEART_EXPECTED = F('heart.out') #expected probabilities
 
+def _check_abs_diff(a, b, maxval):
+  assert numpy.all(abs(a - b) < maxval), "Maximum " \
+          "difference exceeded limit (%g): %g" % (maxval, abs(a - b).max())
+
 def test_initialization():
 
   # tests and examplifies some initialization parameters
@@ -118,8 +122,8 @@ def test_training():
   nose.tools.eq_(machine.kernel_type, previous.kernel_type)
   assert numpy.isclose(machine.gamma, previous.gamma)
   nose.tools.eq_(machine.shape, previous.shape)
-  assert numpy.all(abs(machine.input_subtract - previous.input_subtract) < 1e-8)
-  assert numpy.all(abs(machine.input_divide - previous.input_divide) < 1e-8)
+  _check_abs_diff(machine.input_subtract, previous.input_subtract, 1e-8)
+  _check_abs_diff(machine.input_divide, previous.input_divide, 1e-8)
 
   curr_label = machine.predict_class(data)
   prev_label = previous.predict_class(data)
@@ -131,7 +135,7 @@ def test_training():
 
   curr_scores = numpy.array(curr_scores)
   prev_scores = numpy.array(prev_scores)
-  assert numpy.all(abs(curr_scores - prev_scores) < 1e-8)
+  _check_abs_diff(curr_scores, prev_scores, 1e-8)
 
 def test_training_with_probability():
 
@@ -152,8 +156,8 @@ def test_training_with_probability():
   nose.tools.eq_(machine.kernel_type, previous.kernel_type)
   assert numpy.isclose(machine.gamma, previous.gamma)
   nose.tools.eq_(machine.shape, previous.shape)
-  assert numpy.all(abs(machine.input_subtract - previous.input_subtract) < 1e-8)
-  assert numpy.all(abs(machine.input_divide - previous.input_divide) < 1e-8)
+  _check_abs_diff(machine.input_subtract, previous.input_subtract, 1e-8)
+  _check_abs_diff(machine.input_divide, previous.input_divide, 1e-8)
 
   # check labels
   curr_label = machine.predict_class(data)
@@ -167,7 +171,7 @@ def test_training_with_probability():
 
   curr_scores = numpy.array(curr_scores)
   prev_scores = numpy.array(prev_scores)
-  assert numpy.all(abs(curr_scores - prev_scores) < 1e-8)
+  _check_abs_diff(curr_scores, prev_scores, 1e-8)
 
   # check probabilities -- probA and probB do not get the exact same values
   # as when using libsvm's svm-train.c. The reason may lie in the order in
@@ -176,7 +180,7 @@ def test_training_with_probability():
   prev_labels, prev_scores = previous.predict_class_and_probabilities(data)
   curr_scores = numpy.array(curr_scores)
   prev_scores = numpy.array(prev_scores)
-  #assert numpy.all(abs(curr_scores-prev_scores) < 1e-8)
+  #_check_abs_diff(curr_scores, prev_scores, 1e-8)
 
 def test_training_one_class():
 
@@ -200,8 +204,8 @@ def test_training_one_class():
   nose.tools.eq_(machine.kernel_type, previous.kernel_type)
   assert numpy.isclose(machine.gamma, previous.gamma)
   nose.tools.eq_(machine.shape, previous.shape)
-  assert numpy.all(abs(machine.input_subtract - previous.input_subtract) < 1e-8)
-  assert numpy.all(abs(machine.input_divide - previous.input_divide) < 1e-8)
+  _check_abs_diff(machine.input_subtract, previous.input_subtract, 1e-8)
+  _check_abs_diff(machine.input_divide, previous.input_divide, 1e-8)
 
   curr_label = machine.predict_class(data)
   prev_label = previous.predict_class(data)
@@ -213,8 +217,7 @@ def test_training_one_class():
 
   curr_scores = numpy.array(curr_scores)
   prev_scores = numpy.array(prev_scores)
-  assert numpy.all(abs(curr_scores - prev_scores) < 1e-8)
-
+  _check_abs_diff(curr_scores, prev_scores)
 
 def test_successive_training():
 
-- 
GitLab