From 77f78c042e8c494866ff0cf52f584fbc87048292 Mon Sep 17 00:00:00 2001
From: Tiago Freitas Pereira <tiagofrepereira@gmail.com>
Date: Sat, 28 Jul 2018 16:37:02 +0200
Subject: [PATCH] For some reason tensorflow 1.8 was casting to uint8 the
 integer division.

---
 bob/learn/tensorflow/test/test_utils.py | 7 ++++---
 bob/learn/tensorflow/utils/util.py      | 2 +-
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/bob/learn/tensorflow/test/test_utils.py b/bob/learn/tensorflow/test/test_utils.py
index e4f61241..1e897f96 100644
--- a/bob/learn/tensorflow/test/test_utils.py
+++ b/bob/learn/tensorflow/test/test_utils.py
@@ -27,7 +27,6 @@ def test_embedding_accuracy():
 
     data = numpy.vstack((class_a, class_b))
     labels = numpy.concatenate((labels_a, labels_b))
-
     assert compute_embedding_accuracy(data, labels) == 1.
 
     # Adding noise
@@ -51,7 +50,9 @@ def test_embedding_accuracy_tensors():
 
     class_b = numpy.random.normal(
         loc=10, scale=0.1, size=(samples_per_class, 2))
-    labels_b = numpy.ones(samples_per_class)
+    class_b = numpy.vstack((class_b, numpy.array([0,0.])))# Adding outlier
+    labels_b = numpy.ones(samples_per_class + 1)
+    
 
     data = numpy.vstack((class_a, class_b))
     labels = numpy.concatenate((labels_a, labels_b))
@@ -61,4 +62,4 @@ def test_embedding_accuracy_tensors():
 
     sess = tf.Session()
     accuracy = sess.run(compute_embedding_accuracy_tensors(data, labels))
-    assert accuracy == 1.
+    assert abs(accuracy-7/11.) < 10e-3
diff --git a/bob/learn/tensorflow/utils/util.py b/bob/learn/tensorflow/utils/util.py
index a2fd7ebe..9f772bb7 100644
--- a/bob/learn/tensorflow/utils/util.py
+++ b/bob/learn/tensorflow/utils/util.py
@@ -226,7 +226,7 @@ def compute_embedding_accuracy_tensors(embedding, labels, num=None):
             tf.unstack(predictions, num=num), tf.unstack(labels, num=num))
     ]
 
-    return tf.reduce_sum(tf.cast(matching, tf.uint8)) / len(predictions)
+    return tf.reduce_sum(tf.cast(matching, tf.float32)) / len(predictions)
 
 
 def compute_embedding_accuracy(embedding, labels):
-- 
GitLab