diff --git a/bob/learn/tensorflow/estimators/Logits.py b/bob/learn/tensorflow/estimators/Logits.py index ad11eb4d40561aa3f36038b241001ce0dbc8d8be..499e431160b5b167c5594f9dd8a85add614b87a9 100755 --- a/bob/learn/tensorflow/estimators/Logits.py +++ b/bob/learn/tensorflow/estimators/Logits.py @@ -124,7 +124,8 @@ class Logits(estimator.Estimator): data, mode=mode, trainable_variables=trainable_variables)[0] logits = append_logits(prelogits, n_classes) - if self.embedding_validation: + if self.embedding_validation and mode != tf.estimator.ModeKeys.TRAIN: + # Compute the embeddings embeddings = tf.nn.l2_normalize(prelogits, 1) predictions = { @@ -146,7 +147,7 @@ class Logits(estimator.Estimator): return tf.estimator.EstimatorSpec( mode=mode, predictions=predictions) - if self.embedding_validation: + if self.embedding_validation and mode != tf.estimator.ModeKeys.TRAIN: predictions_op = predict_using_tensors( predictions["embeddings"], labels,