Skip to content
Snippets Groups Projects
Commit 66fb1898 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Updated triplet

parent 467d4d05
1 merge request!32Organizing transfer learning
......@@ -121,13 +121,12 @@ class Triplet(estimator.Estimator):
is_trainable = is_trainable_checkpoint(self.extra_checkpoint)
# Building one graph
prelogits_anchor = self.architecture(features['anchor'], is_trainable=is_trainable)[0]
prelogits_positive = self.architecture(features['positive'], reuse=True, is_trainable=is_trainable)[0]
prelogits_negative = self.architecture(features['negative'], reuse=True, is_trainable=is_trainable)[0]
prelogits_anchor = self.architecture(features['anchor'], is_training_mode = True)[0]
prelogits_positive = self.architecture(features['positive'], reuse=True, is_training_mode = True)[0]
prelogits_negative = self.architecture(features['negative'], reuse=True, is_training_mode = True)[0]
if self.extra_checkpoint is not None:
tf.contrib.framework.init_from_checkpoint(self.extra_checkpoint["checkpoint_path"],
self.extra_checkpoint["scopes"])
# Compute Loss (for both TRAIN and EVAL modes)
self.loss = self.loss_op(prelogits_anchor, prelogits_positive, prelogits_negative)
......@@ -141,7 +140,7 @@ class Triplet(estimator.Estimator):
data = features['data']
# Compute the embeddings
prelogits = self.architecture(data)[0]
prelogits = self.architecture(data, is_training_mode = False)[0]
embeddings = tf.nn.l2_normalize(prelogits, 1)
predictions = {"embeddings": embeddings}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment