diff --git a/bob/learn/tensorflow/test/test_onegraph_model_fn.py b/bob/learn/tensorflow/test/test_onegraph_model_fn.py
index c6f4d54f81a03ea6e8ad219a197a34f1e9381ed2..fadf49dc9725ce24956f8e4d31c2577922b3ec2e 100755
--- a/bob/learn/tensorflow/test/test_onegraph_model_fn.py
+++ b/bob/learn/tensorflow/test/test_onegraph_model_fn.py
@@ -63,28 +63,28 @@ def run_cnn(embedding_validation):
                                 optimizer=tf.train.GradientDescentOptimizer(learning_rate),
                                 n_classes=10,
                                 loss_op=mean_cross_entropy_loss,
-                                embedding_validation=embedding_validation)
+                                embedding_validation=embedding_validation,
+                                validation_batch_size=validation_batch_size
+                                )
 
-        data, labels = shuffle_data_and_labels([tfrecord_train], data_shape, data_type, batch_size, epochs=epochs)                            
         def input_fn():
             return shuffle_data_and_labels(tfrecord_train, data_shape, data_type,
-                                       batch_size, epochs=epochs)
+                                           batch_size, epochs=epochs)
                                        
         def input_fn_validation():
             return batch_data_and_labels(tfrecord_validation, data_shape, data_type,
                                          validation_batch_size, epochs=epochs)                                       
 
         hooks = [LoggerHookEstimator(trainer, 16, 100)]
-        trainer.train(input_fn, steps=steps, hooks = hooks)
+        trainer.train(input_fn, steps=steps, hooks=hooks)
         
-        # TODO: REMOVE THIS HACK
         if not embedding_validation:
             acc = trainer.evaluate(input_fn_validation)
-            assert acc > 0.80
+            assert acc['accuracy'] > 0.80
         else:
-            assert True
-          
-        
+            acc = trainer.evaluate(input_fn_validation)
+            assert acc['accuracy'] > 0.80
+
     finally:
         try:
             os.unlink(tfrecord_train)
diff --git a/bob/learn/tensorflow/test/test_utils.py b/bob/learn/tensorflow/test/test_utils.py
index 966fd66c0036ce00feba0b9c82d411b8615b0f26..23f3ded505b260b9cf8f7cb3d3e224fb8853bb17 100755
--- a/bob/learn/tensorflow/test/test_utils.py
+++ b/bob/learn/tensorflow/test/test_utils.py
@@ -3,7 +3,8 @@
 # @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
 
 import numpy
-from bob.learn.tensorflow.utils import compute_embedding_accuracy, cdist, compute_embedding_accuracy_tensors
+from bob.learn.tensorflow.utils import compute_embedding_accuracy, cdist,\
+     compute_embedding_accuracy_tensors, predict_using_tensors
 
 import tensorflow as tf
 
@@ -11,6 +12,7 @@ import tensorflow as tf
 Some unit tests for the datashuffler
 """
 
+
 def test_embedding_accuracy():
 
     numpy.random.seed(10)
@@ -53,8 +55,7 @@ def test_embedding_accuracy_tensors():
     
     data = tf.convert_to_tensor(data.astype("float32"))
     labels = tf.convert_to_tensor(labels.astype("int64"))
-    
+
     sess = tf.Session()
     accuracy = sess.run(compute_embedding_accuracy_tensors(data, labels))
     assert accuracy == 1.
-
diff --git a/bob/learn/tensorflow/trainers/LogitsTrainer.py b/bob/learn/tensorflow/trainers/LogitsTrainer.py
index c4b63d24378e48c51420c8c36c534d32b354c226..7204396701590e98779e61d0f88b217641c8cf62 100755
--- a/bob/learn/tensorflow/trainers/LogitsTrainer.py
+++ b/bob/learn/tensorflow/trainers/LogitsTrainer.py
@@ -15,7 +15,7 @@ import time
 from bob.learn.tensorflow.network.utils import append_logits
 from tensorflow.python.estimator import estimator
 from bob.learn.tensorflow.utils import reproducible
-from bob.learn.tensorflow.utils import compute_embedding_accuracy_tensors
+from bob.learn.tensorflow.utils import predict_using_tensors
 
 
 import logging
@@ -48,6 +48,10 @@ class LogitsTrainer(estimator.Estimator):
       
       model_dir:
         Model path
+
+      validation_batch_size:
+        Size of the batch for validation. This value is used when the
+        validation with embeddings is used. This is a hack.
     """
 
     def __init__(self,
@@ -58,6 +62,7 @@ class LogitsTrainer(estimator.Estimator):
                  loss_op=None,
                  embedding_validation=False,
                  model_dir="",
+                 validation_batch_size=None,
               ):
 
         self.architecture = architecture
@@ -79,7 +84,6 @@ class LogitsTrainer(estimator.Estimator):
         if self.n_classes <=0:
             raise ValueError("Number of classes must be greated than 0")
 
-
         def _model_fn(features, labels, mode, params, config):
             
             # Building one graph
@@ -90,7 +94,7 @@ class LogitsTrainer(estimator.Estimator):
                 # Compute the embeddings
                 embeddings = tf.nn.l2_normalize(prelogits, 1)
                 predictions = {
-                    "embeddings":embeddings                    
+                    "embeddings": embeddings
                 }
                 
             else:
@@ -107,7 +111,6 @@ class LogitsTrainer(estimator.Estimator):
 
             # Calculate Loss (for both TRAIN and EVAL modes)
             self.loss = self.loss_op(logits, labels)
-            
 
             # Configure the Training Op (for TRAIN mode)
             if mode == tf.estimator.ModeKeys.TRAIN:
@@ -117,8 +120,8 @@ class LogitsTrainer(estimator.Estimator):
                                                   train_op=train_op)
 
             if self.embedding_validation:
-                #eval_metric_ops = {"accuracy": compute_embedding_accuracy_tensors(predictions["embeddings"], labels)}
-                eval_metric_ops = {} # TODO: I still don't know how to compute this with an unknown size
+                predictions_op = predict_using_tensors(predictions["embeddings"], labels, num=validation_batch_size)
+                eval_metric_ops = {"accuracy": tf.metrics.accuracy(labels=labels, predictions=predictions_op)}
                 return tf.estimator.EstimatorSpec(mode=mode, loss=self.loss, eval_metric_ops=eval_metric_ops)
             
             else:
diff --git a/bob/learn/tensorflow/utils/util.py b/bob/learn/tensorflow/utils/util.py
index 7aae42fa20e19b5aebad1748d809350ec5156e4c..d3a9b9943e82907629491205f0407cd149e43437 100755
--- a/bob/learn/tensorflow/utils/util.py
+++ b/bob/learn/tensorflow/utils/util.py
@@ -46,7 +46,6 @@ def load_mnist(perc_train=0.9):
     return train_data, train_labels, validation_data, validation_labels
 
 
-
 def create_mnist_tfrecord(tfrecords_filename, data, labels, n_samples=6000):
 
     def _bytes_feature(value):
@@ -150,43 +149,59 @@ def debug_embbeding(image, architecture, embbeding_dim=2, feature_layer="fc3"):
         embeddings[i] = embedding
 
     return embeddings
-    
 
 
 def cdist(A):
+    """
+    Compute a pairwise euclidean distance in the same fashion
+    as in scipy.spation.distance.cdist
+    """
     with tf.variable_scope('Pairwisedistance'):
+        #ones_1 = tf.ones(shape=(1, A.shape.as_list()[0]))
+        ones_1 = tf.reshape(tf.cast(tf.ones_like(A), tf.float32)[:, 0], [1, -1])
         p1 = tf.matmul(
             tf.expand_dims(tf.reduce_sum(tf.square(A), 1), 1),
-            tf.ones(shape=(1, A.shape.as_list()[0]))
+            ones_1
         )
+
+        #ones_2 = tf.ones(shape=(A.shape.as_list()[0], 1))
+        ones_2 = tf.reshape(tf.cast(tf.ones_like(A), tf.float32)[:, 0], [-1, 1])
         p2 = tf.transpose(tf.matmul(
             tf.reshape(tf.reduce_sum(tf.square(A), 1), shape=[-1, 1]),
-            tf.ones(shape=(A.shape.as_list()[0], 1)),
+            ones_2,
             transpose_b=True
         ))
 
         return tf.sqrt(tf.add(p1, p2) - 2 * tf.matmul(A, A, transpose_b=True))
 
 
-def compute_embedding_accuracy_tensors(embedding, labels):
+def predict_using_tensors(embedding, labels, num=None):
     """
-    Compute the accuracy through exhaustive comparisons between the embeddings using tensors
+    Compute the predictions through exhaustive comparisons between
+    embeddings using tensors
     """
-    
-    distances = cdist(embedding)
 
     # Fitting the main diagonal with infs (removing comparisons with the same sample)
-    inf = numpy.ones(10)*numpy.inf
-    inf = inf.astype("float32")
+    inf = tf.cast(tf.ones_like(labels), tf.float32) * numpy.inf
 
     distances = cdist(embedding)
     distances = tf.matrix_set_diag(distances, inf)
     indexes = tf.argmin(distances, axis=1)
+    return [labels[i] for i in tf.unstack(indexes, num=num)]
+
+
+def compute_embedding_accuracy_tensors(embedding, labels, num=None):
+    """
+    Compute the accuracy through exhaustive comparisons between the embeddings using tensors
+    """
+
+    # Fitting the main diagonal with infs (removing comparisons with the same sample)
+    predictions = predict_using_tensors(embedding, labels, num=num)
+    matching = [tf.equal(p, l) for p, l in zip(tf.unstack(predictions, num=num), tf.unstack(labels, num=num))]
+
+    return tf.reduce_sum(tf.cast(matching, tf.uint8))/len(predictions)
 
-    matching = [ tf.equal(labels[i],labels[j]) for i,j in zip(range(indexes.get_shape().as_list()[0]), tf.unstack(indexes))]
-    return tf.reduce_sum(tf.cast(matching, tf.uint8))/indexes.get_shape().as_list()[0]
 
-    
 def compute_embedding_accuracy(embedding, labels):
     """
     Compute the accuracy through exhaustive comparisons between the embeddings