Commit 3b93cc90 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

Use tensorflow global step variable

parent 15755a1c
......@@ -18,10 +18,10 @@ logger = logging.getLogger("bob.learn")
class SiameseTrainer(Trainer):
"""
Trainer for siamese networks:
Chopra, Sumit, Raia Hadsell, and Yann LeCun. "Learning a similarity metric discriminatively, with application to
face verification." 2005 IEEE Computer Society Conference on Computer Vision and Pattern Recognition (CVPR'05). Vol. 1. IEEE, 2005.
**Parameters**
......@@ -30,10 +30,10 @@ class SiameseTrainer(Trainer):
iterations:
Maximum number of iterations
snapshot:
Will take a snapshot of the network at every `n` iterations
validation_snapshot:
Test with validation each `n` iterations
......@@ -127,8 +127,7 @@ class SiameseTrainer(Trainer):
self.optimizer_class = optimizer
self.learning_rate = learning_rate
# TODO: find an elegant way to provide this as a parameter of the trainer
self.global_step = tf.Variable(0, trainable=False, name="global_step")
self.global_step = tf.contrib.framework.get_or_create_global_step()
# Saving all the variables
self.saver = tf.train.Saver(var_list=tf.global_variables())
......
......@@ -25,7 +25,7 @@ logger = logging.getLogger("bob.learn")
class Trainer(object):
"""
One graph trainer.
Use this trainer when your CNN is composed by one graph
**Parameters**
......@@ -35,10 +35,10 @@ class Trainer(object):
iterations:
Maximum number of iterations
snapshot:
Will take a snapshot of the network at every `n` iterations
validation_snapshot:
Test with validation each `n` iterations
......@@ -118,15 +118,15 @@ class Trainer(object):
"""
Prepare all the tensorflow variables before training.
**Parameters**
graph: Input graph for training
optimizer: Solver
loss: Loss function
learning_rate: Learning rate
"""
......@@ -139,8 +139,7 @@ class Trainer(object):
self.optimizer_class = optimizer
self.learning_rate = learning_rate
# TODO: find an elegant way to provide this as a parameter of the trainer
self.global_step = tf.Variable(0, trainable=False, name="global_step")
self.global_step = tf.contrib.framework.get_or_create_global_step()
# Saving all the variables
self.saver = tf.train.Saver(var_list=tf.global_variables())
......@@ -202,8 +201,8 @@ class Trainer(object):
Given a data shuffler prepared the dictionary to be injected in the graph
** Parameters **
data_shuffler: Data shuffler :py:class:`bob.learn.tensorflow.datashuffler.Base`
data_shuffler: Data shuffler :py:class:`bob.learn.tensorflow.datashuffler.Base`
"""
[data, labels] = data_shuffler.get_batch()
......
......@@ -19,8 +19,8 @@ logger = logging.getLogger("bob.learn")
class TripletTrainer(Trainer):
"""
Trainer for Triple networks:
Schroff, Florian, Dmitry Kalenichenko, and James Philbin.
Schroff, Florian, Dmitry Kalenichenko, and James Philbin.
"Facenet: A unified embedding for face recognition and clustering." Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2015.
**Parameters**
......@@ -30,10 +30,10 @@ class TripletTrainer(Trainer):
iterations:
Maximum number of iterations
snapshot:
Will take a snapshot of the network at every `n` iterations
validation_snapshot:
Test with validation each `n` iterations
......@@ -143,8 +143,7 @@ class TripletTrainer(Trainer):
self.optimizer_class = optimizer
self.learning_rate = learning_rate
# TODO: find an elegant way to provide this as a parameter of the trainer
self.global_step = tf.Variable(0, trainable=False, name="global_step")
self.global_step = tf.contrib.framework.get_or_create_global_step()
# Saving all the variables
self.saver = tf.train.Saver(var_list=tf.global_variables())
......
......@@ -18,8 +18,8 @@ def exponential_decay(base_learning_rate=0.05,
staircase: Boolean. It True decay the learning rate at discrete intervals
"""
global_step = tf.Variable(0, trainable=False)
return tf.train.exponential_decay(base_learning_rate=base_learning_rate,
global_step = tf.contrib.framework.get_or_create_global_step()
return tf.train.exponential_decay(learning_rate=base_learning_rate,
global_step=global_step,
decay_steps=decay_steps,
decay_rate=weight_decay,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment