Commit 3b93cc90 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

Use tensorflow global step variable

parent 15755a1c
......@@ -127,8 +127,7 @@ class SiameseTrainer(Trainer):
self.optimizer_class = optimizer
self.learning_rate = learning_rate
# TODO: find an elegant way to provide this as a parameter of the trainer
self.global_step = tf.Variable(0, trainable=False, name="global_step")
self.global_step = tf.contrib.framework.get_or_create_global_step()
# Saving all the variables
self.saver = tf.train.Saver(var_list=tf.global_variables())
......
......@@ -139,8 +139,7 @@ class Trainer(object):
self.optimizer_class = optimizer
self.learning_rate = learning_rate
# TODO: find an elegant way to provide this as a parameter of the trainer
self.global_step = tf.Variable(0, trainable=False, name="global_step")
self.global_step = tf.contrib.framework.get_or_create_global_step()
# Saving all the variables
self.saver = tf.train.Saver(var_list=tf.global_variables())
......
......@@ -143,8 +143,7 @@ class TripletTrainer(Trainer):
self.optimizer_class = optimizer
self.learning_rate = learning_rate
# TODO: find an elegant way to provide this as a parameter of the trainer
self.global_step = tf.Variable(0, trainable=False, name="global_step")
self.global_step = tf.contrib.framework.get_or_create_global_step()
# Saving all the variables
self.saver = tf.train.Saver(var_list=tf.global_variables())
......
......@@ -18,8 +18,8 @@ def exponential_decay(base_learning_rate=0.05,
staircase: Boolean. It True decay the learning rate at discrete intervals
"""
global_step = tf.Variable(0, trainable=False)
return tf.train.exponential_decay(base_learning_rate=base_learning_rate,
global_step = tf.contrib.framework.get_or_create_global_step()
return tf.train.exponential_decay(learning_rate=base_learning_rate,
global_step=global_step,
decay_steps=decay_steps,
decay_rate=weight_decay,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment