Skip to content
Snippets Groups Projects

Estimator that loads variables from another model

Merged Tiago de Freitas Pereira requested to merge issue-44 into master
All threads resolved!
1 file
+ 34
31
Compare changes
  • Side-by-side
  • Inline
@@ -140,20 +140,6 @@ class Logits(estimator.Estimator):
prelogits = self.architecture(data, is_trainable=is_trainable)[0]
logits = append_logits(prelogits, n_classes)
# Compute Loss (for both TRAIN and EVAL modes)
self.loss = self.loss_op(logits, labels)
# Configure the Training Op (for TRAIN mode)
if mode == tf.estimator.ModeKeys.TRAIN:
if self.extra_checkpoint is not None:
tf.contrib.framework.init_from_checkpoint(self.extra_checkpoint["checkpoint_path"],
self.extra_checkpoint["scopes"])
global_step = tf.contrib.framework.get_or_create_global_step()
train_op = self.optimizer.minimize(self.loss, global_step=global_step)
return tf.estimator.EstimatorSpec(mode=mode, loss=self.loss,
train_op=train_op)
if self.embedding_validation:
# Compute the embeddings
embeddings = tf.nn.l2_normalize(prelogits, 1)
@@ -172,6 +158,21 @@ class Logits(estimator.Estimator):
if mode == tf.estimator.ModeKeys.PREDICT:
return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
# Compute Loss (for both TRAIN and EVAL modes)
self.loss = self.loss_op(logits, labels)
# Configure the Training Op (for TRAIN mode)
if mode == tf.estimator.ModeKeys.TRAIN:
if self.extra_checkpoint is not None:
tf.contrib.framework.init_from_checkpoint(self.extra_checkpoint["checkpoint_path"],
self.extra_checkpoint["scopes"])
global_step = tf.contrib.framework.get_or_create_global_step()
train_op = self.optimizer.minimize(self.loss, global_step=global_step)
return tf.estimator.EstimatorSpec(mode=mode, loss=self.loss,
train_op=train_op)
# Validation
if self.embedding_validation:
predictions_op = predict_using_tensors(predictions["embeddings"], labels, num=validation_batch_size)
@@ -289,23 +290,6 @@ class LogitsCenterLoss(estimator.Estimator):
# Compute Loss (for both TRAIN and EVAL modes)
loss_dict = mean_cross_entropy_center_loss(logits, prelogits, labels, self.n_classes,
alpha=self.alpha, factor=self.factor)
self.loss = loss_dict['loss']
centers = loss_dict['centers']
# Configure the Training Op (for TRAIN mode)
if mode == tf.estimator.ModeKeys.TRAIN:
# Loading variables from some model just in case
if self.extra_checkpoint is not None:
tf.contrib.framework.init_from_checkpoint(self.extra_checkpoint["checkpoint_path"],
self.extra_checkpoint["scopes"])
global_step = tf.contrib.framework.get_or_create_global_step()
# backprop and updating the centers
train_op = tf.group(self.optimizer.minimize(self.loss, global_step=global_step),
centers)
return tf.estimator.EstimatorSpec(mode=mode, loss=self.loss,
train_op=train_op)
if self.embedding_validation:
# Compute the embeddings
@@ -326,6 +310,25 @@ class LogitsCenterLoss(estimator.Estimator):
return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
self.loss = loss_dict['loss']
centers = loss_dict['centers']
# Configure the Training Op (for TRAIN mode)
if mode == tf.estimator.ModeKeys.TRAIN:
# Loading variables from some model just in case
if self.extra_checkpoint is not None:
tf.contrib.framework.init_from_checkpoint(self.extra_checkpoint["checkpoint_path"],
self.extra_checkpoint["scopes"])
global_step = tf.contrib.framework.get_or_create_global_step()
# backprop and updating the centers
train_op = tf.group(self.optimizer.minimize(self.loss, global_step=global_step),
centers)
return tf.estimator.EstimatorSpec(mode=mode, loss=self.loss,
train_op=train_op)
if self.embedding_validation:
predictions_op = predict_using_tensors(predictions["embeddings"], labels, num=validation_batch_size)
eval_metric_ops = {"accuracy": tf.metrics.accuracy(labels=labels, predictions=predictions_op)}
Loading