Skip to content
Snippets Groups Projects
Commit ccac186e authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Fixed training accuracy issue

parent 896d0772
No related branches found
No related tags found
1 merge request!57Updates to the logits estimator
...@@ -124,7 +124,8 @@ class Logits(estimator.Estimator): ...@@ -124,7 +124,8 @@ class Logits(estimator.Estimator):
data, mode=mode, trainable_variables=trainable_variables)[0] data, mode=mode, trainable_variables=trainable_variables)[0]
logits = append_logits(prelogits, n_classes) logits = append_logits(prelogits, n_classes)
if self.embedding_validation: if self.embedding_validation and mode != tf.estimator.ModeKeys.TRAIN:
# Compute the embeddings # Compute the embeddings
embeddings = tf.nn.l2_normalize(prelogits, 1) embeddings = tf.nn.l2_normalize(prelogits, 1)
predictions = { predictions = {
...@@ -146,7 +147,7 @@ class Logits(estimator.Estimator): ...@@ -146,7 +147,7 @@ class Logits(estimator.Estimator):
return tf.estimator.EstimatorSpec( return tf.estimator.EstimatorSpec(
mode=mode, predictions=predictions) mode=mode, predictions=predictions)
if self.embedding_validation: if self.embedding_validation and mode != tf.estimator.ModeKeys.TRAIN:
predictions_op = predict_using_tensors( predictions_op = predict_using_tensors(
predictions["embeddings"], predictions["embeddings"],
labels, labels,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment