Skip to content
Snippets Groups Projects
Commit ccac186e authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Fixed training accuracy issue

parent 896d0772
Branches
Tags
1 merge request!57Updates to the logits estimator
......@@ -124,7 +124,8 @@ class Logits(estimator.Estimator):
data, mode=mode, trainable_variables=trainable_variables)[0]
logits = append_logits(prelogits, n_classes)
if self.embedding_validation:
if self.embedding_validation and mode != tf.estimator.ModeKeys.TRAIN:
# Compute the embeddings
embeddings = tf.nn.l2_normalize(prelogits, 1)
predictions = {
......@@ -146,7 +147,7 @@ class Logits(estimator.Estimator):
return tf.estimator.EstimatorSpec(
mode=mode, predictions=predictions)
if self.embedding_validation:
if self.embedding_validation and mode != tf.estimator.ModeKeys.TRAIN:
predictions_op = predict_using_tensors(
predictions["embeddings"],
labels,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment