Skip to content
Snippets Groups Projects

Use tf.train.get_or_create_global_step (introduced in 1.2) instead of the contrib one

Merged Amir MOHAMMADI requested to merge reproducible into master
1 file
+ 8
8
Compare changes
  • Side-by-side
  • Inline
@@ -117,7 +117,7 @@ class Logits(estimator.Estimator):
@@ -117,7 +117,7 @@ class Logits(estimator.Estimator):
else:
else:
is_trainable = is_trainable_checkpoint(self.extra_checkpoint)
is_trainable = is_trainable_checkpoint(self.extra_checkpoint)
# Building the training graph
# Building the training graph
prelogits = self.architecture(data, mode=mode, trainable_variables=is_trainable)[0]
prelogits = self.architecture(data, mode=mode, trainable_variables=is_trainable)[0]
logits = append_logits(prelogits, n_classes)
logits = append_logits(prelogits, n_classes)
@@ -128,11 +128,11 @@ class Logits(estimator.Estimator):
@@ -128,11 +128,11 @@ class Logits(estimator.Estimator):
tf.contrib.framework.init_from_checkpoint(self.extra_checkpoint["checkpoint_path"],
tf.contrib.framework.init_from_checkpoint(self.extra_checkpoint["checkpoint_path"],
self.extra_checkpoint["scopes"])
self.extra_checkpoint["scopes"])
global_step = tf.contrib.framework.get_or_create_global_step()
global_step = tf.train.get_or_create_global_step()
train_op = self.optimizer.minimize(self.loss, global_step=global_step)
train_op = self.optimizer.minimize(self.loss, global_step=global_step)
return tf.estimator.EstimatorSpec(mode=mode, loss=self.loss,
return tf.estimator.EstimatorSpec(mode=mode, loss=self.loss,
train_op=train_op)
train_op=train_op)
# Building the training graph for PREDICTION OR VALIDATION
# Building the training graph for PREDICTION OR VALIDATION
prelogits = self.architecture(data, mode=mode, trainable_variables=False)[0]
prelogits = self.architecture(data, mode=mode, trainable_variables=False)[0]
@@ -162,7 +162,7 @@ class Logits(estimator.Estimator):
@@ -162,7 +162,7 @@ class Logits(estimator.Estimator):
# IF Validation
# IF Validation
self.loss = self.loss_op(logits, labels)
self.loss = self.loss_op(logits, labels)
if self.embedding_validation:
if self.embedding_validation:
predictions_op = predict_using_tensors(
predictions_op = predict_using_tensors(
predictions["embeddings"], labels,
predictions["embeddings"], labels,
@@ -286,7 +286,7 @@ class LogitsCenterLoss(estimator.Estimator):
@@ -286,7 +286,7 @@ class LogitsCenterLoss(estimator.Estimator):
else:
else:
is_trainable = is_trainable_checkpoint(self.extra_checkpoint)
is_trainable = is_trainable_checkpoint(self.extra_checkpoint)
# Building the training graph
# Building the training graph
prelogits = self.architecture(data, mode=mode, trainable_variables=is_trainable)[0]
prelogits = self.architecture(data, mode=mode, trainable_variables=is_trainable)[0]
logits = append_logits(prelogits, n_classes)
logits = append_logits(prelogits, n_classes)
@@ -334,15 +334,15 @@ class LogitsCenterLoss(estimator.Estimator):
@@ -334,15 +334,15 @@ class LogitsCenterLoss(estimator.Estimator):
# IF Validation
# IF Validation
loss_dict = mean_cross_entropy_center_loss(logits, prelogits, labels, self.n_classes,
loss_dict = mean_cross_entropy_center_loss(logits, prelogits, labels, self.n_classes,
alpha=self.alpha, factor=self.factor)
alpha=self.alpha, factor=self.factor)
self.loss = loss_dict['loss']
self.loss = loss_dict['loss']
if self.embedding_validation:
if self.embedding_validation:
predictions_op = predict_using_tensors(
predictions_op = predict_using_tensors(
predictions["embeddings"], labels, num=validation_batch_size)
predictions["embeddings"], labels, num=validation_batch_size)
eval_metric_ops = {"accuracy": tf.metrics.accuracy(
eval_metric_ops = {"accuracy": tf.metrics.accuracy(
labels=labels, predictions=predictions_op)}
labels=labels, predictions=predictions_op)}
return tf.estimator.EstimatorSpec(mode=mode, loss=self.loss, eval_metric_ops=eval_metric_ops)
return tf.estimator.EstimatorSpec(mode=mode, loss=self.loss, eval_metric_ops=eval_metric_ops)
else:
else:
# Add evaluation metrics (for EVAL mode)
# Add evaluation metrics (for EVAL mode)
eval_metric_ops = {
eval_metric_ops = {
Loading