Commit 3b93cc90 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

Use tensorflow global step variable

parent 15755a1c
...@@ -18,10 +18,10 @@ logger = logging.getLogger("bob.learn") ...@@ -18,10 +18,10 @@ logger = logging.getLogger("bob.learn")
class SiameseTrainer(Trainer): class SiameseTrainer(Trainer):
""" """
Trainer for siamese networks: Trainer for siamese networks:
Chopra, Sumit, Raia Hadsell, and Yann LeCun. "Learning a similarity metric discriminatively, with application to Chopra, Sumit, Raia Hadsell, and Yann LeCun. "Learning a similarity metric discriminatively, with application to
face verification." 2005 IEEE Computer Society Conference on Computer Vision and Pattern Recognition (CVPR'05). Vol. 1. IEEE, 2005. face verification." 2005 IEEE Computer Society Conference on Computer Vision and Pattern Recognition (CVPR'05). Vol. 1. IEEE, 2005.
**Parameters** **Parameters**
...@@ -30,10 +30,10 @@ class SiameseTrainer(Trainer): ...@@ -30,10 +30,10 @@ class SiameseTrainer(Trainer):
iterations: iterations:
Maximum number of iterations Maximum number of iterations
snapshot: snapshot:
Will take a snapshot of the network at every `n` iterations Will take a snapshot of the network at every `n` iterations
validation_snapshot: validation_snapshot:
Test with validation each `n` iterations Test with validation each `n` iterations
...@@ -127,8 +127,7 @@ class SiameseTrainer(Trainer): ...@@ -127,8 +127,7 @@ class SiameseTrainer(Trainer):
self.optimizer_class = optimizer self.optimizer_class = optimizer
self.learning_rate = learning_rate self.learning_rate = learning_rate
# TODO: find an elegant way to provide this as a parameter of the trainer self.global_step = tf.contrib.framework.get_or_create_global_step()
self.global_step = tf.Variable(0, trainable=False, name="global_step")
# Saving all the variables # Saving all the variables
self.saver = tf.train.Saver(var_list=tf.global_variables()) self.saver = tf.train.Saver(var_list=tf.global_variables())
......
...@@ -25,7 +25,7 @@ logger = logging.getLogger("bob.learn") ...@@ -25,7 +25,7 @@ logger = logging.getLogger("bob.learn")
class Trainer(object): class Trainer(object):
""" """
One graph trainer. One graph trainer.
Use this trainer when your CNN is composed by one graph Use this trainer when your CNN is composed by one graph
**Parameters** **Parameters**
...@@ -35,10 +35,10 @@ class Trainer(object): ...@@ -35,10 +35,10 @@ class Trainer(object):
iterations: iterations:
Maximum number of iterations Maximum number of iterations
snapshot: snapshot:
Will take a snapshot of the network at every `n` iterations Will take a snapshot of the network at every `n` iterations
validation_snapshot: validation_snapshot:
Test with validation each `n` iterations Test with validation each `n` iterations
...@@ -118,15 +118,15 @@ class Trainer(object): ...@@ -118,15 +118,15 @@ class Trainer(object):
""" """
Prepare all the tensorflow variables before training. Prepare all the tensorflow variables before training.
**Parameters** **Parameters**
graph: Input graph for training graph: Input graph for training
optimizer: Solver optimizer: Solver
loss: Loss function loss: Loss function
learning_rate: Learning rate learning_rate: Learning rate
""" """
...@@ -139,8 +139,7 @@ class Trainer(object): ...@@ -139,8 +139,7 @@ class Trainer(object):
self.optimizer_class = optimizer self.optimizer_class = optimizer
self.learning_rate = learning_rate self.learning_rate = learning_rate
# TODO: find an elegant way to provide this as a parameter of the trainer self.global_step = tf.contrib.framework.get_or_create_global_step()
self.global_step = tf.Variable(0, trainable=False, name="global_step")
# Saving all the variables # Saving all the variables
self.saver = tf.train.Saver(var_list=tf.global_variables()) self.saver = tf.train.Saver(var_list=tf.global_variables())
...@@ -202,8 +201,8 @@ class Trainer(object): ...@@ -202,8 +201,8 @@ class Trainer(object):
Given a data shuffler prepared the dictionary to be injected in the graph Given a data shuffler prepared the dictionary to be injected in the graph
** Parameters ** ** Parameters **
data_shuffler: Data shuffler :py:class:`bob.learn.tensorflow.datashuffler.Base` data_shuffler: Data shuffler :py:class:`bob.learn.tensorflow.datashuffler.Base`
""" """
[data, labels] = data_shuffler.get_batch() [data, labels] = data_shuffler.get_batch()
......
...@@ -19,8 +19,8 @@ logger = logging.getLogger("bob.learn") ...@@ -19,8 +19,8 @@ logger = logging.getLogger("bob.learn")
class TripletTrainer(Trainer): class TripletTrainer(Trainer):
""" """
Trainer for Triple networks: Trainer for Triple networks:
Schroff, Florian, Dmitry Kalenichenko, and James Philbin. Schroff, Florian, Dmitry Kalenichenko, and James Philbin.
"Facenet: A unified embedding for face recognition and clustering." Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2015. "Facenet: A unified embedding for face recognition and clustering." Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2015.
**Parameters** **Parameters**
...@@ -30,10 +30,10 @@ class TripletTrainer(Trainer): ...@@ -30,10 +30,10 @@ class TripletTrainer(Trainer):
iterations: iterations:
Maximum number of iterations Maximum number of iterations
snapshot: snapshot:
Will take a snapshot of the network at every `n` iterations Will take a snapshot of the network at every `n` iterations
validation_snapshot: validation_snapshot:
Test with validation each `n` iterations Test with validation each `n` iterations
...@@ -143,8 +143,7 @@ class TripletTrainer(Trainer): ...@@ -143,8 +143,7 @@ class TripletTrainer(Trainer):
self.optimizer_class = optimizer self.optimizer_class = optimizer
self.learning_rate = learning_rate self.learning_rate = learning_rate
# TODO: find an elegant way to provide this as a parameter of the trainer self.global_step = tf.contrib.framework.get_or_create_global_step()
self.global_step = tf.Variable(0, trainable=False, name="global_step")
# Saving all the variables # Saving all the variables
self.saver = tf.train.Saver(var_list=tf.global_variables()) self.saver = tf.train.Saver(var_list=tf.global_variables())
......
...@@ -18,8 +18,8 @@ def exponential_decay(base_learning_rate=0.05, ...@@ -18,8 +18,8 @@ def exponential_decay(base_learning_rate=0.05,
staircase: Boolean. It True decay the learning rate at discrete intervals staircase: Boolean. It True decay the learning rate at discrete intervals
""" """
global_step = tf.Variable(0, trainable=False) global_step = tf.contrib.framework.get_or_create_global_step()
return tf.train.exponential_decay(base_learning_rate=base_learning_rate, return tf.train.exponential_decay(learning_rate=base_learning_rate,
global_step=global_step, global_step=global_step,
decay_steps=decay_steps, decay_steps=decay_steps,
decay_rate=weight_decay, decay_rate=weight_decay,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment