diff --git a/bob/learn/tensorflow/script/train_mobio.py b/bob/learn/tensorflow/script/train_mobio.py index 6c45616191c386d2e143d519ccd7dd96f42e39dd..8b46e5dcdff452f508114d518984b9b983dc1a4b 100644 --- a/bob/learn/tensorflow/script/train_mobio.py +++ b/bob/learn/tensorflow/script/train_mobio.py @@ -101,7 +101,7 @@ def main(): #optimizer = optimizer, trainer = TripletTrainer(architecture=architecture, loss=loss, iterations=ITERATIONS, - base_learning_rate=0.0001, + base_learning_rate=0.1, prefetch=False, temp_dir="./LOGS_MOBIO/triplet-cnn") diff --git a/bob/learn/tensorflow/trainers/SiameseTrainer.py b/bob/learn/tensorflow/trainers/SiameseTrainer.py index e7e6900e951f7012f56ace7e49599ad05a4954ee..46784419654f562a4ad3bd4b75f6b6e24534a51f 100644 --- a/bob/learn/tensorflow/trainers/SiameseTrainer.py +++ b/bob/learn/tensorflow/trainers/SiameseTrainer.py @@ -50,6 +50,7 @@ class SiameseTrainer(Trainer): # Learning rate base_learning_rate=0.001, weight_decay=0.9, + decay_steps=1000, ###### training options ########## convergence_threshold=0.01, @@ -72,6 +73,7 @@ class SiameseTrainer(Trainer): # Learning rate base_learning_rate=base_learning_rate, weight_decay=weight_decay, + decay_steps=decay_steps, ###### training options ########## convergence_threshold=convergence_threshold, diff --git a/bob/learn/tensorflow/trainers/Trainer.py b/bob/learn/tensorflow/trainers/Trainer.py index affd918831dc1066e0195a045ec8520a76982f9d..9db18bc9bb6b4cab06f1937ecbc85d15686c8074 100644 --- a/bob/learn/tensorflow/trainers/Trainer.py +++ b/bob/learn/tensorflow/trainers/Trainer.py @@ -51,8 +51,9 @@ class Trainer(object): temp_dir="cnn", # Learning rate - base_learning_rate=0.001, + base_learning_rate=0.1, weight_decay=0.9, + decay_steps=1000, ###### training options ########## convergence_threshold=0.01, @@ -76,6 +77,7 @@ class Trainer(object): self.base_learning_rate = base_learning_rate self.weight_decay = weight_decay + self.decay_steps = decay_steps self.iterations = iterations self.snapshot = snapshot @@ -101,6 +103,7 @@ class Trainer(object): self.thread_pool = None self.enqueue_op = None + self.global_step = None bob.core.log.set_verbosity_level(logger, verbosity_level) @@ -257,19 +260,20 @@ class Trainer(object): self.train_data_shuffler = train_data_shuffler # TODO: find an elegant way to provide this as a parameter of the trainer + self.global_step = tf.Variable(0, trainable=False) self.learning_rate = tf.train.exponential_decay( - self.base_learning_rate, # Learning rate - train_data_shuffler.batch_size, - train_data_shuffler.n_samples, - self.weight_decay # Decay step + learning_rate=self.base_learning_rate, # Learning rate + global_step=self.global_step, + decay_steps=self.decay_steps, + decay_rate=self.weight_decay, # Decay step + staircase=False ) - self.training_graph = self.compute_graph(train_data_shuffler, prefetch=self.prefetch, name="train") # Preparing the optimizer self.optimizer_class._learning_rate = self.learning_rate - #self.optimizer = self.optimizer_class.minimize(self.training_graph, global_step=tf.Variable(0)) - self.optimizer = self.optimizer_class.minimize(self.training_graph) + self.optimizer = self.optimizer_class.minimize(self.training_graph, global_step=self.global_step) + # Train summary self.summaries_train = self.create_general_summary() diff --git a/bob/learn/tensorflow/trainers/TripletTrainer.py b/bob/learn/tensorflow/trainers/TripletTrainer.py index 5ae2e8f073dea83c87037a6ba6fc6d9734a09a18..4e693142841e43720b159f61113b50fc238145bd 100644 --- a/bob/learn/tensorflow/trainers/TripletTrainer.py +++ b/bob/learn/tensorflow/trainers/TripletTrainer.py @@ -50,6 +50,7 @@ class TripletTrainer(Trainer): # Learning rate base_learning_rate=0.001, weight_decay=0.9, + decay_steps=1000, ###### training options ########## convergence_threshold=0.01, @@ -72,6 +73,7 @@ class TripletTrainer(Trainer): # Learning rate base_learning_rate=base_learning_rate, weight_decay=weight_decay, + decay_steps=decay_steps, ###### training options ########## convergence_threshold=convergence_threshold, @@ -188,6 +190,7 @@ class TripletTrainer(Trainer): self.within_class_graph_train, self.learning_rate, self.summaries_train], feed_dict=feed_dict) + print "LEARNING {0}".format(lr) logger.info("Loss training set step={0} = {1}".format(step, l)) self.train_summary_writter.add_summary(summary, step)