Commit 82a80a54 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Use tf.train.get_or_create_global_step (introduced in 1.2) instead of the contrib one

parent 20d5c20d
Pipeline #14647 failed with stages
in 30 minutes and 7 seconds
......@@ -117,7 +117,7 @@ class Logits(estimator.Estimator):
else:
is_trainable = is_trainable_checkpoint(self.extra_checkpoint)
# Building the training graph
# Building the training graph
prelogits = self.architecture(data, mode=mode, trainable_variables=is_trainable)[0]
logits = append_logits(prelogits, n_classes)
......@@ -128,11 +128,11 @@ class Logits(estimator.Estimator):
tf.contrib.framework.init_from_checkpoint(self.extra_checkpoint["checkpoint_path"],
self.extra_checkpoint["scopes"])
global_step = tf.contrib.framework.get_or_create_global_step()
global_step = tf.train.get_or_create_global_step()
train_op = self.optimizer.minimize(self.loss, global_step=global_step)
return tf.estimator.EstimatorSpec(mode=mode, loss=self.loss,
train_op=train_op)
# Building the training graph for PREDICTION OR VALIDATION
prelogits = self.architecture(data, mode=mode, trainable_variables=False)[0]
......@@ -162,7 +162,7 @@ class Logits(estimator.Estimator):
# IF Validation
self.loss = self.loss_op(logits, labels)
if self.embedding_validation:
predictions_op = predict_using_tensors(
predictions["embeddings"], labels,
......@@ -286,7 +286,7 @@ class LogitsCenterLoss(estimator.Estimator):
else:
is_trainable = is_trainable_checkpoint(self.extra_checkpoint)
# Building the training graph
# Building the training graph
prelogits = self.architecture(data, mode=mode, trainable_variables=is_trainable)[0]
logits = append_logits(prelogits, n_classes)
......@@ -334,15 +334,15 @@ class LogitsCenterLoss(estimator.Estimator):
# IF Validation
loss_dict = mean_cross_entropy_center_loss(logits, prelogits, labels, self.n_classes,
alpha=self.alpha, factor=self.factor)
self.loss = loss_dict['loss']
self.loss = loss_dict['loss']
if self.embedding_validation:
predictions_op = predict_using_tensors(
predictions["embeddings"], labels, num=validation_batch_size)
eval_metric_ops = {"accuracy": tf.metrics.accuracy(
labels=labels, predictions=predictions_op)}
return tf.estimator.EstimatorSpec(mode=mode, loss=self.loss, eval_metric_ops=eval_metric_ops)
else:
# Add evaluation metrics (for EVAL mode)
eval_metric_ops = {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment