Skip to content
Snippets Groups Projects
Commit fe91ec3f authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

FIxed the order of the loss computation

parent 7205709b
Branches
Tags
1 merge request!27Estimator that loads variables from another model
Pipeline #
......@@ -140,20 +140,6 @@ class Logits(estimator.Estimator):
prelogits = self.architecture(data, is_trainable=is_trainable)[0]
logits = append_logits(prelogits, n_classes)
# Compute Loss (for both TRAIN and EVAL modes)
self.loss = self.loss_op(logits, labels)
# Configure the Training Op (for TRAIN mode)
if mode == tf.estimator.ModeKeys.TRAIN:
if self.extra_checkpoint is not None:
tf.contrib.framework.init_from_checkpoint(self.extra_checkpoint["checkpoint_path"],
self.extra_checkpoint["scopes"])
global_step = tf.contrib.framework.get_or_create_global_step()
train_op = self.optimizer.minimize(self.loss, global_step=global_step)
return tf.estimator.EstimatorSpec(mode=mode, loss=self.loss,
train_op=train_op)
if self.embedding_validation:
# Compute the embeddings
embeddings = tf.nn.l2_normalize(prelogits, 1)
......@@ -172,6 +158,21 @@ class Logits(estimator.Estimator):
if mode == tf.estimator.ModeKeys.PREDICT:
return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
# Compute Loss (for both TRAIN and EVAL modes)
self.loss = self.loss_op(logits, labels)
# Configure the Training Op (for TRAIN mode)
if mode == tf.estimator.ModeKeys.TRAIN:
if self.extra_checkpoint is not None:
tf.contrib.framework.init_from_checkpoint(self.extra_checkpoint["checkpoint_path"],
self.extra_checkpoint["scopes"])
global_step = tf.contrib.framework.get_or_create_global_step()
train_op = self.optimizer.minimize(self.loss, global_step=global_step)
return tf.estimator.EstimatorSpec(mode=mode, loss=self.loss,
train_op=train_op)
# Validation
if self.embedding_validation:
predictions_op = predict_using_tensors(predictions["embeddings"], labels, num=validation_batch_size)
......@@ -289,23 +290,6 @@ class LogitsCenterLoss(estimator.Estimator):
# Compute Loss (for both TRAIN and EVAL modes)
loss_dict = mean_cross_entropy_center_loss(logits, prelogits, labels, self.n_classes,
alpha=self.alpha, factor=self.factor)
self.loss = loss_dict['loss']
centers = loss_dict['centers']
# Configure the Training Op (for TRAIN mode)
if mode == tf.estimator.ModeKeys.TRAIN:
# Loading variables from some model just in case
if self.extra_checkpoint is not None:
tf.contrib.framework.init_from_checkpoint(self.extra_checkpoint["checkpoint_path"],
self.extra_checkpoint["scopes"])
global_step = tf.contrib.framework.get_or_create_global_step()
# backprop and updating the centers
train_op = tf.group(self.optimizer.minimize(self.loss, global_step=global_step),
centers)
return tf.estimator.EstimatorSpec(mode=mode, loss=self.loss,
train_op=train_op)
if self.embedding_validation:
# Compute the embeddings
......@@ -326,6 +310,25 @@ class LogitsCenterLoss(estimator.Estimator):
return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
self.loss = loss_dict['loss']
centers = loss_dict['centers']
# Configure the Training Op (for TRAIN mode)
if mode == tf.estimator.ModeKeys.TRAIN:
# Loading variables from some model just in case
if self.extra_checkpoint is not None:
tf.contrib.framework.init_from_checkpoint(self.extra_checkpoint["checkpoint_path"],
self.extra_checkpoint["scopes"])
global_step = tf.contrib.framework.get_or_create_global_step()
# backprop and updating the centers
train_op = tf.group(self.optimizer.minimize(self.loss, global_step=global_step),
centers)
return tf.estimator.EstimatorSpec(mode=mode, loss=self.loss,
train_op=train_op)
if self.embedding_validation:
predictions_op = predict_using_tensors(predictions["embeddings"], labels, num=validation_batch_size)
eval_metric_ops = {"accuracy": tf.metrics.accuracy(labels=labels, predictions=predictions_op)}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment