Commit ccac186e authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Fixed training accuracy issue

parent 896d0772
......@@ -124,7 +124,8 @@ class Logits(estimator.Estimator):
data, mode=mode, trainable_variables=trainable_variables)[0]
logits = append_logits(prelogits, n_classes)
if self.embedding_validation:
if self.embedding_validation and mode != tf.estimator.ModeKeys.TRAIN:
# Compute the embeddings
embeddings = tf.nn.l2_normalize(prelogits, 1)
predictions = {
......@@ -146,7 +147,7 @@ class Logits(estimator.Estimator):
return tf.estimator.EstimatorSpec(
mode=mode, predictions=predictions)
if self.embedding_validation:
if self.embedding_validation and mode != tf.estimator.ModeKeys.TRAIN:
predictions_op = predict_using_tensors(
predictions["embeddings"],
labels,
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment