From ccac186ef9350bec406a9e16621767c824812827 Mon Sep 17 00:00:00 2001 From: Tiago Freitas Pereira <tiagofrepereira@gmail.com> Date: Tue, 10 Jul 2018 11:53:48 +0200 Subject: [PATCH] Fixed training accuracy issue --- bob/learn/tensorflow/estimators/Logits.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/bob/learn/tensorflow/estimators/Logits.py b/bob/learn/tensorflow/estimators/Logits.py index ad11eb4d..499e4311 100755 --- a/bob/learn/tensorflow/estimators/Logits.py +++ b/bob/learn/tensorflow/estimators/Logits.py @@ -124,7 +124,8 @@ class Logits(estimator.Estimator): data, mode=mode, trainable_variables=trainable_variables)[0] logits = append_logits(prelogits, n_classes) - if self.embedding_validation: + if self.embedding_validation and mode != tf.estimator.ModeKeys.TRAIN: + # Compute the embeddings embeddings = tf.nn.l2_normalize(prelogits, 1) predictions = { @@ -146,7 +147,7 @@ class Logits(estimator.Estimator): return tf.estimator.EstimatorSpec( mode=mode, predictions=predictions) - if self.embedding_validation: + if self.embedding_validation and mode != tf.estimator.ModeKeys.TRAIN: predictions_op = predict_using_tensors( predictions["embeddings"], labels, -- GitLab