Skip to content
Snippets Groups Projects
Commit 94b53903 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Fixed issue #7

Conflicts:
	bob/learn/tensorflow/script/train_mobio.py
parents a8e61de8 5143af4e
Branches
Tags v0.0.1b10
No related merge requests found
......@@ -101,7 +101,7 @@ def main():
#optimizer = optimizer,
trainer = TripletTrainer(architecture=architecture, loss=loss,
iterations=ITERATIONS,
base_learning_rate=0.0001,
base_learning_rate=0.1,
prefetch=False,
temp_dir="./LOGS_MOBIO/triplet-cnn")
......
......@@ -50,6 +50,7 @@ class SiameseTrainer(Trainer):
# Learning rate
base_learning_rate=0.001,
weight_decay=0.9,
decay_steps=1000,
###### training options ##########
convergence_threshold=0.01,
......@@ -72,6 +73,7 @@ class SiameseTrainer(Trainer):
# Learning rate
base_learning_rate=base_learning_rate,
weight_decay=weight_decay,
decay_steps=decay_steps,
###### training options ##########
convergence_threshold=convergence_threshold,
......
......@@ -51,8 +51,9 @@ class Trainer(object):
temp_dir="cnn",
# Learning rate
base_learning_rate=0.001,
base_learning_rate=0.1,
weight_decay=0.9,
decay_steps=1000,
###### training options ##########
convergence_threshold=0.01,
......@@ -76,6 +77,7 @@ class Trainer(object):
self.base_learning_rate = base_learning_rate
self.weight_decay = weight_decay
self.decay_steps = decay_steps
self.iterations = iterations
self.snapshot = snapshot
......@@ -101,6 +103,7 @@ class Trainer(object):
self.thread_pool = None
self.enqueue_op = None
self.global_step = None
bob.core.log.set_verbosity_level(logger, verbosity_level)
......@@ -257,19 +260,20 @@ class Trainer(object):
self.train_data_shuffler = train_data_shuffler
# TODO: find an elegant way to provide this as a parameter of the trainer
self.global_step = tf.Variable(0, trainable=False)
self.learning_rate = tf.train.exponential_decay(
self.base_learning_rate, # Learning rate
train_data_shuffler.batch_size,
train_data_shuffler.n_samples,
self.weight_decay # Decay step
learning_rate=self.base_learning_rate, # Learning rate
global_step=self.global_step,
decay_steps=self.decay_steps,
decay_rate=self.weight_decay, # Decay step
staircase=False
)
self.training_graph = self.compute_graph(train_data_shuffler, prefetch=self.prefetch, name="train")
# Preparing the optimizer
self.optimizer_class._learning_rate = self.learning_rate
#self.optimizer = self.optimizer_class.minimize(self.training_graph, global_step=tf.Variable(0))
self.optimizer = self.optimizer_class.minimize(self.training_graph)
self.optimizer = self.optimizer_class.minimize(self.training_graph, global_step=self.global_step)
# Train summary
self.summaries_train = self.create_general_summary()
......
......@@ -50,6 +50,7 @@ class TripletTrainer(Trainer):
# Learning rate
base_learning_rate=0.001,
weight_decay=0.9,
decay_steps=1000,
###### training options ##########
convergence_threshold=0.01,
......@@ -72,6 +73,7 @@ class TripletTrainer(Trainer):
# Learning rate
base_learning_rate=base_learning_rate,
weight_decay=weight_decay,
decay_steps=decay_steps,
###### training options ##########
convergence_threshold=convergence_threshold,
......@@ -188,6 +190,7 @@ class TripletTrainer(Trainer):
self.within_class_graph_train,
self.learning_rate, self.summaries_train],
feed_dict=feed_dict)
print "LEARNING {0}".format(lr)
logger.info("Loss training set step={0} = {1}".format(step, l))
self.train_summary_writter.add_summary(summary, step)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment