Commit 3b93cc90 authored by Amir MOHAMMADI

Use tensorflow global step variable

parent 15755a1c
1 merge request: !9 Resolve "exponential decay learning rate is not working"
@@ -18,10 +18,10 @@ logger = logging.getLogger("bob.learn")
class SiameseTrainer(Trainer):
    """
    Trainer for Siamese networks:

    Chopra, Sumit, Raia Hadsell, and Yann LeCun. "Learning a similarity metric discriminatively, with application to
    face verification." 2005 IEEE Computer Society Conference on Computer Vision and Pattern Recognition (CVPR'05). Vol. 1. IEEE, 2005.

    **Parameters**
@@ -30,10 +30,10 @@ class SiameseTrainer(Trainer):
    iterations:
      Maximum number of iterations

    snapshot:
      Take a snapshot of the network every `n` iterations

    validation_snapshot:
      Run validation every `n` iterations
@@ -127,8 +127,7 @@ class SiameseTrainer(Trainer):
         self.optimizer_class = optimizer
         self.learning_rate = learning_rate

-        # TODO: find an elegant way to provide this as a parameter of the trainer
-        self.global_step = tf.Variable(0, trainable=False, name="global_step")
+        self.global_step = tf.contrib.framework.get_or_create_global_step()

         # Saving all the variables
         self.saver = tf.train.Saver(var_list=tf.global_variables())
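This is the substantive change of the commit, repeated in each trainer below: previously every trainer (and the learning-rate helper) built its own `tf.Variable(0, trainable=False)` as a step counter, so the exponential-decay schedule read a counter that the optimizer never incremented. `tf.contrib.framework.get_or_create_global_step()` instead returns the single graph-wide global step, so the optimizer and the decay schedule share one counter. A minimal TF 1.x sketch of that behavior; the toy loss and the hyperparameter values are illustrative, not taken from the repository:

```python
import tensorflow as tf

# get_or_create_global_step() is idempotent: a second call returns the
# same graph-wide counter instead of creating a new variable.
global_step = tf.contrib.framework.get_or_create_global_step()
assert global_step is tf.contrib.framework.get_or_create_global_step()

# The decayed learning rate reads the shared counter...
learning_rate = tf.train.exponential_decay(learning_rate=0.05,
                                           global_step=global_step,
                                           decay_steps=100,
                                           decay_rate=0.9,
                                           staircase=True)

# Toy loss, purely for illustration.
w = tf.Variable(1.0)
loss = tf.square(w)

# ...and minimize(global_step=...) increments that same counter on every
# training step, which is what actually drives the decay.
train_op = tf.train.GradientDescentOptimizer(learning_rate).minimize(
    loss, global_step=global_step)
```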
@@ -25,7 +25,7 @@ logger = logging.getLogger("bob.learn")
class Trainer(object):
    """
    One-graph trainer.

    Use this trainer when your CNN is composed of a single graph.

    **Parameters**
@@ -35,10 +35,10 @@ class Trainer(object):
    iterations:
      Maximum number of iterations

    snapshot:
      Take a snapshot of the network every `n` iterations

    validation_snapshot:
      Run validation every `n` iterations
@@ -118,15 +118,15 @@ class Trainer(object):
"""
Prepare all the tensorflow variables before training.
**Parameters**
graph: Input graph for training
optimizer: Solver
loss: Loss function
learning_rate: Learning rate
"""
@@ -139,8 +139,7 @@ class Trainer(object):
         self.optimizer_class = optimizer
         self.learning_rate = learning_rate

-        # TODO: find an elegant way to provide this as a parameter of the trainer
-        self.global_step = tf.Variable(0, trainable=False, name="global_step")
+        self.global_step = tf.contrib.framework.get_or_create_global_step()

         # Saving all the variables
         self.saver = tf.train.Saver(var_list=tf.global_variables())
@@ -202,8 +201,8 @@ class Trainer(object):
        Given a data shuffler, prepare the dictionary to be injected into the graph.

        **Parameters**

        data_shuffler: Data shuffler :py:class:`bob.learn.tensorflow.datashuffler.Base`
        """
        [data, labels] = data_shuffler.get_batch()
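For context, `get_feed_dict` pairs a batch from the shuffler with the graph's input placeholders. A generic TF 1.x sketch of that pattern; the placeholder names and shapes here are hypothetical, not the repository's actual graph inputs:

```python
import numpy as np
import tensorflow as tf

# Hypothetical placeholders standing in for the trainer's graph inputs.
data_ph = tf.placeholder(tf.float32, shape=(None, 28, 28, 1), name="data")
labels_ph = tf.placeholder(tf.int64, shape=(None,), name="labels")

def get_feed_dict(data, labels):
    """Map one batch from the data shuffler onto the placeholders."""
    return {data_ph: data, labels_ph: labels}

# Typical use inside a training loop:
#   session.run(train_op, feed_dict=get_feed_dict(data, labels))
feed = get_feed_dict(np.zeros((8, 28, 28, 1), np.float32),
                     np.zeros((8,), np.int64))
```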
@@ -19,8 +19,8 @@ logger = logging.getLogger("bob.learn")
class TripletTrainer(Trainer):
    """
    Trainer for triplet networks:

    Schroff, Florian, Dmitry Kalenichenko, and James Philbin.
    "FaceNet: A unified embedding for face recognition and clustering." Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2015.

    **Parameters**
@@ -30,10 +30,10 @@ class TripletTrainer(Trainer):
    iterations:
      Maximum number of iterations

    snapshot:
      Take a snapshot of the network every `n` iterations

    validation_snapshot:
      Run validation every `n` iterations
@@ -143,8 +143,7 @@ class TripletTrainer(Trainer):
         self.optimizer_class = optimizer
         self.learning_rate = learning_rate

-        # TODO: find an elegant way to provide this as a parameter of the trainer
-        self.global_step = tf.Variable(0, trainable=False, name="global_step")
+        self.global_step = tf.contrib.framework.get_or_create_global_step()

         # Saving all the variables
         self.saver = tf.train.Saver(var_list=tf.global_variables())
@@ -18,8 +18,8 @@ def exponential_decay(base_learning_rate=0.05,
    staircase: Boolean. If True, decay the learning rate at discrete intervals.
    """
-    global_step = tf.Variable(0, trainable=False)
-    return tf.train.exponential_decay(base_learning_rate=base_learning_rate,
+    global_step = tf.contrib.framework.get_or_create_global_step()
+    return tf.train.exponential_decay(learning_rate=base_learning_rate,
                                       global_step=global_step,
                                       decay_steps=decay_steps,
                                       decay_rate=weight_decay,
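This hunk fixes two problems at once: `tf.train.exponential_decay` has no `base_learning_rate` keyword (the argument is `learning_rate`), and the locally created `tf.Variable(0, trainable=False)` was never incremented, so the rate never decayed. With the shared global step the schedule behaves as expected. A small sketch that checks the decayed value; the decay_steps/decay_rate numbers are illustrative:

```python
import tensorflow as tf

global_step = tf.contrib.framework.get_or_create_global_step()
lr = tf.train.exponential_decay(learning_rate=0.05,
                                global_step=global_step,
                                decay_steps=100,
                                decay_rate=0.9,
                                staircase=True)

# Simulate 100 optimizer steps by advancing the shared counter directly.
advance = tf.assign_add(global_step, 100)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(lr))   # 0.05  = 0.05 * 0.9**0 at step 0
    sess.run(advance)
    print(sess.run(lr))   # 0.045 = 0.05 * 0.9**1 at step 100
```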