diff --git a/bob/learn/tensorflow/trainers/SiameseTrainer.py b/bob/learn/tensorflow/trainers/SiameseTrainer.py
index 791c317f229d39e06ebe2a9e97bf732592412228..6f512eb5717ed3dba05c5c4b12579ad6a0faccb8 100644
--- a/bob/learn/tensorflow/trainers/SiameseTrainer.py
+++ b/bob/learn/tensorflow/trainers/SiameseTrainer.py
@@ -18,10 +18,10 @@ logger = logging.getLogger("bob.learn")
 class SiameseTrainer(Trainer):
     """
     Trainer for siamese networks:
-    
+
     Chopra, Sumit, Raia Hadsell, and Yann LeCun.
     "Learning a similarity metric discriminatively, with application to face verification."
     2005 IEEE Computer Society Conference on Computer Vision and Pattern Recognition (CVPR'05). Vol. 1. IEEE, 2005.
-    
+

     **Parameters**
@@ -30,10 +30,10 @@ class SiameseTrainer(Trainer):

     iterations:
       Maximum number of iterations
-    
+
     snapshot:
       Will take a snapshot of the network at every `n` iterations
-    
+
     validation_snapshot:
       Test with validation each `n` iterations

@@ -127,8 +127,7 @@ class SiameseTrainer(Trainer):
         self.optimizer_class = optimizer
         self.learning_rate = learning_rate
-        # TODO: find an elegant way to provide this as a parameter of the trainer
-        self.global_step = tf.Variable(0, trainable=False, name="global_step")
+        self.global_step = tf.contrib.framework.get_or_create_global_step()

         # Saving all the variables
         self.saver = tf.train.Saver(var_list=tf.global_variables())

diff --git a/bob/learn/tensorflow/trainers/Trainer.py b/bob/learn/tensorflow/trainers/Trainer.py
index 4f26b1d7572acf203bedbf25db4e919f0bd01015..e3a5c6d7d6bfe1d6946d2b8b70dbadcae3ccb61a 100644
--- a/bob/learn/tensorflow/trainers/Trainer.py
+++ b/bob/learn/tensorflow/trainers/Trainer.py
@@ -25,7 +25,7 @@ logger = logging.getLogger("bob.learn")
 class Trainer(object):
     """
     One graph trainer.
-    
+
     Use this trainer when your CNN is composed by one graph

     **Parameters**
@@ -35,10 +35,10 @@ class Trainer(object):

     iterations:
       Maximum number of iterations
-    
+
     snapshot:
       Will take a snapshot of the network at every `n` iterations
-    
+
     validation_snapshot:
       Test with validation each `n` iterations

@@ -118,15 +118,15 @@ class Trainer(object):
         """
         Prepare all the tensorflow variables before training.
-        
+
         **Parameters**
-        
+
         graph: Input graph for training
-        
+
         optimizer: Solver
-        
+
         loss: Loss function
-        
+
         learning_rate: Learning rate
         """
@@ -139,8 +139,7 @@ class Trainer(object):
         self.optimizer_class = optimizer
         self.learning_rate = learning_rate
-        # TODO: find an elegant way to provide this as a parameter of the trainer
-        self.global_step = tf.Variable(0, trainable=False, name="global_step")
+        self.global_step = tf.contrib.framework.get_or_create_global_step()

         # Saving all the variables
         self.saver = tf.train.Saver(var_list=tf.global_variables())
@@ -202,8 +201,8 @@ class Trainer(object):
         Given a data shuffler prepared the dictionary to be injected in the graph

         ** Parameters **
-        
-        data_shuffler: Data shuffler :py:class:`bob.learn.tensorflow.datashuffler.Base` 
+
+        data_shuffler: Data shuffler :py:class:`bob.learn.tensorflow.datashuffler.Base`
         """

         [data, labels] = data_shuffler.get_batch()
diff --git a/bob/learn/tensorflow/trainers/TripletTrainer.py b/bob/learn/tensorflow/trainers/TripletTrainer.py
index 651baa11fd7cc3d2248849d99185f3bf5c3f2e7c..e15e90731023dded1b3841632dc6712d2c9e038d 100644
--- a/bob/learn/tensorflow/trainers/TripletTrainer.py
+++ b/bob/learn/tensorflow/trainers/TripletTrainer.py
@@ -19,8 +19,8 @@ logger = logging.getLogger("bob.learn")
 class TripletTrainer(Trainer):
     """
     Trainer for Triple networks:
-    
-    Schroff, Florian, Dmitry Kalenichenko, and James Philbin. 
+
+    Schroff, Florian, Dmitry Kalenichenko, and James Philbin.
     "Facenet: A unified embedding for face recognition and clustering."
     Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2015.
     **Parameters**
@@ -30,10 +30,10 @@ class TripletTrainer(Trainer):

     iterations:
       Maximum number of iterations
-    
+
     snapshot:
       Will take a snapshot of the network at every `n` iterations
-    
+
     validation_snapshot:
       Test with validation each `n` iterations

@@ -143,8 +143,7 @@ class TripletTrainer(Trainer):
         self.optimizer_class = optimizer
         self.learning_rate = learning_rate
-        # TODO: find an elegant way to provide this as a parameter of the trainer
-        self.global_step = tf.Variable(0, trainable=False, name="global_step")
+        self.global_step = tf.contrib.framework.get_or_create_global_step()

         # Saving all the variables
         self.saver = tf.train.Saver(var_list=tf.global_variables())
diff --git a/bob/learn/tensorflow/trainers/learning_rate.py b/bob/learn/tensorflow/trainers/learning_rate.py
index 9a206f75c1b0adaec02ec7f1b7b37c5fd265dd85..a8b0c46725259112a1882691934216256dafb5e1 100644
--- a/bob/learn/tensorflow/trainers/learning_rate.py
+++ b/bob/learn/tensorflow/trainers/learning_rate.py
@@ -18,8 +18,8 @@ def exponential_decay(base_learning_rate=0.05,
       staircase: Boolean. It True decay the learning rate at discrete intervals
     """

-    global_step = tf.Variable(0, trainable=False)
-    return tf.train.exponential_decay(base_learning_rate=base_learning_rate,
+    global_step = tf.contrib.framework.get_or_create_global_step()
+    return tf.train.exponential_decay(learning_rate=base_learning_rate,
                                       global_step=global_step,
                                       decay_steps=decay_steps,
                                       decay_rate=weight_decay,
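
A quick sanity check of the two TensorFlow APIs this patch relies on: `tf.contrib.framework.get_or_create_global_step()` returns the `global_step` variable already registered in the graph and only creates one if none exists, so the trainers and the learning-rate helper now share a single step counter instead of each allocating its own `tf.Variable`. Separately, `tf.train.exponential_decay` has no `base_learning_rate` keyword; its first argument is named `learning_rate`, which is what the last hunk fixes. Below is a minimal sketch of both behaviors, assuming TensorFlow 1.x with `tf.contrib` available; the decay constants are illustrative, not values from this repository:

```python
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    # Repeated calls return the same variable, so a trainer and a
    # learning-rate schedule built in the same graph share one counter.
    step_a = tf.contrib.framework.get_or_create_global_step()
    step_b = tf.contrib.framework.get_or_create_global_step()
    assert step_a is step_b

    # The base rate is passed as `learning_rate`; passing
    # `base_learning_rate=` raises a TypeError, which is why the old
    # call in learning_rate.py could not work.
    lr = tf.train.exponential_decay(learning_rate=0.05,
                                    global_step=step_a,
                                    decay_steps=1000,
                                    decay_rate=0.9,
                                    staircase=True)
```

If the optimizers are built with `minimize(..., global_step=self.global_step)` (not shown in this diff), the shared variable is incremented once per training step, which keeps the decay schedule in sync with the trainer's iteration count.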