diff --git a/bob/learn/tensorflow/estimators/Logits.py b/bob/learn/tensorflow/estimators/Logits.py
index ad11eb4d40561aa3f36038b241001ce0dbc8d8be..499e431160b5b167c5594f9dd8a85add614b87a9 100755
--- a/bob/learn/tensorflow/estimators/Logits.py
+++ b/bob/learn/tensorflow/estimators/Logits.py
@@ -124,7 +124,8 @@ class Logits(estimator.Estimator):
                 data, mode=mode, trainable_variables=trainable_variables)[0]
             logits = append_logits(prelogits, n_classes)
 
-            if self.embedding_validation:
+            if self.embedding_validation and mode != tf.estimator.ModeKeys.TRAIN:
+
                 # Compute the embeddings
                 embeddings = tf.nn.l2_normalize(prelogits, 1)
                 predictions = {
@@ -146,7 +147,7 @@ class Logits(estimator.Estimator):
                 return tf.estimator.EstimatorSpec(
                     mode=mode, predictions=predictions)
 
-            if self.embedding_validation:
+            if self.embedding_validation and mode != tf.estimator.ModeKeys.TRAIN:
                 predictions_op = predict_using_tensors(
                     predictions["embeddings"],
                     labels,