Commit 4dc72f0a authored by Amir MOHAMMADI

Use tf.train.get_or_create_global_step (introduced in 1.2) instead of the contrib one

parent 82a80a54
Pipeline #14648 failed in 36 minutes and 16 seconds
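The change is mechanical: every call through the contrib alias is replaced by the core API that TensorFlow 1.2 added. A minimal sketch of the swap, with a toy loss so it runs on its own (assuming TF 1.x):

    import tensorflow as tf

    # Before (contrib alias, removed from TensorFlow in later releases):
    #   global_step = tf.contrib.framework.get_or_create_global_step()
    # After (core API, available since TensorFlow 1.2):
    global_step = tf.train.get_or_create_global_step()

    # Toy variables so the snippet is self-contained; both calls return the
    # same singleton global-step variable, so step counting is unchanged.
    w = tf.Variable(1.0)
    loss = tf.square(w)
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss, global_step=global_step)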
@@ -301,7 +301,7 @@ class LogitsCenterLoss(estimator.Estimator):
                 tf.contrib.framework.init_from_checkpoint(self.extra_checkpoint["checkpoint_path"],
                                                           self.extra_checkpoint["scopes"])

-                global_step = tf.contrib.framework.get_or_create_global_step()
+                global_step = tf.train.get_or_create_global_step()
                 train_op = tf.group(self.optimizer.minimize(self.loss, global_step=global_step),
                                     centers)
                 return tf.estimator.EstimatorSpec(mode=mode, loss=self.loss,
......
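The hunk above also shows the center-loss pattern: the training op couples the gradient step with the op that updates the class centers via tf.group. A hedged sketch of that pattern, with toy variables and a hypothetical center update standing in for the real center-loss update:

    import tensorflow as tf

    w = tf.Variable(2.0)
    loss = tf.square(w)
    centers = tf.Variable(tf.zeros([10, 128]))  # one center per class (sizes hypothetical)
    # Placeholder for the real center-loss update op.
    centers_update = tf.assign_sub(centers, 0.01 * centers)

    global_step = tf.train.get_or_create_global_step()
    # One training op that runs the gradient step and the center update together.
    train_op = tf.group(tf.train.GradientDescentOptimizer(0.1).minimize(loss, global_step=global_step),
                        centers_update)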
@@ -43,7 +43,7 @@ class Siamese(estimator.Estimator):
             return loss_set_of_ops(logits, labels)

         extra_checkpoint = {"checkpoint_path": model_dir,
                             "scopes": dict({"Dummy/": "Dummy/"}),
                             "is_trainable": False
                             }
@@ -60,35 +60,35 @@ class Siamese(estimator.Estimator):
- tf.train.GradientDescentOptimizer
- tf.train.AdagradOptimizer
- ....
    config:

    loss_op:
        Pointer to a function that computes the loss.

    embedding_validation:
        If True, run the validation using embeddings. [default: False]

    model_dir:
        Model path

    validation_batch_size:
        Size of the batch for validation. This value is used when the
        validation with embeddings is used. This is a hack.

    params:
        Extra params for the model function
        (please see https://www.tensorflow.org/extend/estimators for more info)

    extra_checkpoint: dict
        Use this in case you want to use another model to initialize some variables.
        This argument should be in the following format:
        extra_checkpoint = {"checkpoint_path": <YOUR_CHECKPOINT>,
                            "scopes": dict({"<SOURCE_SCOPE>/": "<TARGET_SCOPE>/"}),
                            "is_trainable": <IF_THOSE_LOADED_VARIABLES_ARE_TRAINABLE>
                            }
"""
def __init__(self,
@@ -99,18 +99,18 @@ class Siamese(estimator.Estimator):
model_dir="",
validation_batch_size=None,
params=None,
                 extra_checkpoint=None
):
self.architecture = architecture
        self.optimizer = optimizer
        self.loss_op = loss_op
self.loss = None
        self.extra_checkpoint = extra_checkpoint
if self.architecture is None:
raise ValueError("Please specify a function to build the architecture !!")
if self.optimizer is None:
raise ValueError("Please specify a optimizer (https://www.tensorflow.org/api_guides/python/train) !!")
@@ -119,8 +119,8 @@ class Siamese(estimator.Estimator):
def _model_fn(features, labels, mode, params, config):
            if mode == tf.estimator.ModeKeys.TRAIN:
# Building one graph, by default everything is trainable
if self.extra_checkpoint is None:
is_trainable = True
@@ -138,12 +138,12 @@ class Siamese(estimator.Estimator):
                 if self.extra_checkpoint is not None:
                     tf.contrib.framework.init_from_checkpoint(self.extra_checkpoint["checkpoint_path"],
                                                               self.extra_checkpoint["scopes"])

                 # Compute Loss (for both TRAIN and EVAL modes)
                 self.loss = self.loss_op(prelogits_left, prelogits_right, labels)

                 # Configure the Training Op (for TRAIN mode)
-                global_step = tf.contrib.framework.get_or_create_global_step()
+                global_step = tf.train.get_or_create_global_step()
                 train_op = self.optimizer.minimize(self.loss, global_step=global_step)
                 return tf.estimator.EstimatorSpec(mode=mode, loss=self.loss,
@@ -162,9 +162,9 @@ class Siamese(estimator.Estimator):
predictions_op = predict_using_tensors(predictions["embeddings"], labels, num=validation_batch_size)
eval_metric_ops = {"accuracy": tf.metrics.accuracy(labels=labels, predictions=predictions_op)}
return tf.estimator.EstimatorSpec(mode=mode, loss=tf.reduce_mean(1), eval_metric_ops=eval_metric_ops)
super(Siamese, self).__init__(model_fn=_model_fn,
model_dir=model_dir,
......
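The extra_checkpoint dict documented above feeds tf.contrib.framework.init_from_checkpoint, which copies variables from a saved model into the current graph. A minimal sketch under assumed values (the path and scope names are hypothetical placeholders):

    import tensorflow as tf

    extra_checkpoint = {"checkpoint_path": "/path/to/pretrained/model",  # hypothetical checkpoint
                        "scopes": dict({"Dummy/": "Dummy/"}),            # source scope -> target scope
                        "is_trainable": False}

    # Inside the model_fn: variables in the target scope are initialized
    # from the matching variables stored under the source scope.
    tf.contrib.framework.init_from_checkpoint(extra_checkpoint["checkpoint_path"],
                                              extra_checkpoint["scopes"])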
@@ -51,18 +51,18 @@ class Triplet(estimator.Estimator):
- tf.train.GradientDescentOptimizer
- tf.train.AdagradOptimizer
- ....
    config:

    n_classes:
        Number of classes of your problem. A logits layer with this many outputs
        will be appended by this class.

    loss_op:
        Pointer to a function that computes the loss.

    embedding_validation:
        If True, run the validation using embeddings. [default: False]

    model_dir:
        Model path
@@ -97,7 +97,7 @@ class Triplet(estimator.Estimator):
if self.architecture is None:
raise ValueError("Please specify a function to build the architecture !!")
if self.optimizer is None:
raise ValueError("Please specify a optimizer (https://www.tensorflow.org/api_guides/python/train) !!")
@@ -132,7 +132,7 @@ class Triplet(estimator.Estimator):
                 # Compute Loss (for both TRAIN and EVAL modes)
                 self.loss = self.loss_op(prelogits_anchor, prelogits_positive, prelogits_negative)

                 # Configure the Training Op (for TRAIN mode)
-                global_step = tf.contrib.framework.get_or_create_global_step()
+                global_step = tf.train.get_or_create_global_step()
                 train_op = self.optimizer.minimize(self.loss, global_step=global_step)
                 return tf.estimator.EstimatorSpec(mode=mode, loss=self.loss,
train_op=train_op)
@@ -150,7 +150,7 @@ class Triplet(estimator.Estimator):
predictions_op = predict_using_tensors(predictions["embeddings"], labels, num=validation_batch_size)
eval_metric_ops = {"accuracy": tf.metrics.accuracy(labels=labels, predictions=predictions_op)}
return tf.estimator.EstimatorSpec(mode=mode, loss=tf.reduce_mean(1), eval_metric_ops=eval_metric_ops)
super(Triplet, self).__init__(model_fn=_model_fn,
......
@@ -95,22 +95,22 @@ class SiameseTrainer(Trainer):
self.graph = None
self.validation_graph = None
self.loss = None
        self.validation_predictor = None
self.optimizer_class = None
self.learning_rate = None
# Training variables used in the fit
self.optimizer = None
self.data_ph = None
self.label_ph = None
self.validation_data_ph = None
self.validation_label_ph = None
self.saver = None
bob.core.log.set_verbosity_level(logger, verbosity_level)
@@ -140,7 +140,7 @@ class SiameseTrainer(Trainer):
self.optimizer_class = optimizer
self.learning_rate = learning_rate
-        self.global_step = tf.contrib.framework.get_or_create_global_step()
+        self.global_step = tf.train.get_or_create_global_step()
# Saving all the variables
self.saver = tf.train.Saver(var_list=tf.global_variables())
@@ -215,7 +215,7 @@ class SiameseTrainer(Trainer):
def fit(self, step):
feed_dict = self.get_feed_dict(self.train_data_shuffler)
_, l, bt_class, wt_class, lr, summary = self.session.run([
self.optimizer,
self.loss['loss'], self.loss['between_class'],
......
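The fit step above runs the optimizer and several loss tensors in a single session.run call fed through a feed dict. The generic TF 1.x pattern, with hypothetical placeholder names (the real feed dict comes from the data shuffler):

    import tensorflow as tf

    data_ph = tf.placeholder(tf.float32, shape=[None, 2], name="data")
    w = tf.Variable(tf.ones([2, 1]))
    loss = tf.reduce_mean(tf.square(tf.matmul(data_ph, w)))
    optimizer = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        # One training step: run the optimizer and fetch the loss in the same call.
        _, l = session.run([optimizer, loss], feed_dict={data_ph: [[1.0, 2.0]]})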
@@ -100,24 +100,24 @@ class Trainer(object):
self.graph = None
self.validation_graph = None
self.prelogits = None
self.loss = None
        self.validation_loss = None
        self.validate_with_embeddings = validate_with_embeddings
self.optimizer_class = None
self.learning_rate = None
# Training variables used in the fit
self.optimizer = None
self.data_ph = None
self.label_ph = None
self.validation_data_ph = None
self.validation_label_ph = None
self.saver = None
bob.core.log.set_verbosity_level(logger, verbosity_level)
@@ -125,10 +125,10 @@ class Trainer(object):
# Creating the session
self.session = Session.instance(new=True).session
self.from_scratch = True
def train(self):
"""
        Train the network.

        Here we basically have the loop that takes your graph and does a sequence of session.run calls.
"""
@@ -146,7 +146,7 @@ class Trainer(object):
#if isinstance(train_data_shuffler, OnlineSampling):
# train_data_shuffler.set_feature_extractor(self.architecture, session=self.session)
        # Start a thread to enqueue data asynchronously, and hide I/O latency.
if self.train_data_shuffler.prefetch:
self.thread_pool = tf.train.Coordinator()
tf.train.start_queue_runners(coord=self.thread_pool, sess=self.session)
@@ -178,7 +178,7 @@ class Trainer(object):
self.compute_validation(step)
# Taking snapshot
            if step % self.snapshot == 0:
logger.info("Taking snapshot")
path = os.path.join(self.temp_dir, 'model_snapshot{0}.ckp'.format(step))
self.saver.save(self.session, path, global_step=step)
@@ -189,8 +189,8 @@ class Trainer(object):
self.compute_validation_embeddings(step)
else:
self.compute_validation(step)
logger.info("Training finally finished")
self.train_summary_writter.close()
@@ -205,7 +205,7 @@ class Trainer(object):
        # now they should definitely stop
self.thread_pool.request_stop()
#if not isinstance(self.train_data_shuffler, TFRecord):
        #    self.thread_pool.join(threads)
def create_network_from_scratch(self,
graph,
@@ -221,7 +221,7 @@ class Trainer(object):
"""
Prepare all the tensorflow variables before training.
**Parameters**
graph: Input graph for training
@@ -235,9 +235,9 @@ class Trainer(object):
# Getting the pointer to the placeholders
self.data_ph = self.train_data_shuffler("data", from_queue=True)
self.label_ph = self.train_data_shuffler("label", from_queue=True)
self.graph = graph
        self.loss = loss
# TODO: SPECIFIC HACK FOR THE CENTER LOSS. I NEED TO FIND A CLEAN SOLUTION FOR THAT
self.centers = None
@@ -251,14 +251,14 @@ class Trainer(object):
self.optimizer_class = optimizer
self.learning_rate = learning_rate
-        self.global_step = tf.contrib.framework.get_or_create_global_step()
+        self.global_step = tf.train.get_or_create_global_step()
# Preparing the optimizer
self.optimizer_class._learning_rate = self.learning_rate
self.optimizer = self.optimizer_class.minimize(self.loss, global_step=self.global_step)
# Saving all the variables
        self.saver = tf.train.Saver(var_list=tf.global_variables() + tf.local_variables(),
keep_checkpoint_every_n_hours=self.keep_checkpoint_every_n_hours)
self.summaries_train = self.create_general_summary(self.loss, self.graph, self.label_ph)
@@ -283,15 +283,15 @@ class Trainer(object):
self.validation_graph = validation_graph
        if self.validate_with_embeddings:
self.validation_loss = self.validation_graph
        else:
#self.validation_predictor = self.loss(self.validation_graph, self.validation_label_ph)
self.validation_loss = validation_loss
self.summaries_validation = self.create_general_summary(self.validation_loss, self.validation_graph, self.validation_label_ph)
tf.add_to_collection("summaries_validation", self.summaries_validation)
tf.add_to_collection("validation_graph", self.validation_graph)
tf.add_to_collection("validation_data_ph", self.validation_data_ph)
tf.add_to_collection("validation_label_ph", self.validation_label_ph)
@@ -321,22 +321,22 @@ class Trainer(object):
else:
self.saver = tf.train.import_meta_graph(file_name, clear_devices=clear_devices)
self.saver.restore(self.session, tf.train.latest_checkpoint(os.path.dirname(file_name)))
def load_variables_from_external_model(self, checkpoint_path, var_list):
"""
Load a set of variables from a given model and update them in the current one
        **Parameters**
checkpoint_path:
Name of the tensorflow model to be loaded
var_list:
            List of variables to be loaded. A tensorflow exception will be raised if a variable does not exist.
"""
        assert len(var_list) > 0
tf_varlist = []
for v in var_list:
tf_varlist += tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=v)
@@ -366,14 +366,14 @@ class Trainer(object):
        # Loading other elements
self.optimizer = tf.get_collection("optimizer")[0]
self.learning_rate = tf.get_collection("learning_rate")[0]
        self.summaries_train = tf.get_collection("summaries_train")[0]
self.global_step = tf.get_collection("global_step")[0]
self.from_scratch = False
if len(tf.get_collection("centers")) > 0:
self.centers = tf.get_collection("centers")[0]
self.prelogits = tf.get_collection("prelogits")[0]
# Loading the validation bits
if self.validation_data_shuffler is not None:
self.summaries_validation = tf.get_collection("summaries_validation")[0]
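The restore path above relies on TensorFlow collections as a name registry: ops are stored with tf.add_to_collection when the graph is built and fetched with tf.get_collection after the meta graph is imported. A toy roundtrip of that mechanism (names hypothetical):

    import tensorflow as tf

    x = tf.constant(3.0, name="x")
    tf.add_to_collection("my_ops", x)
    # ... later, typically after tf.train.import_meta_graph ...
    restored = tf.get_collection("my_ops")[0]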
@@ -414,14 +414,14 @@ class Trainer(object):
"""
if self.train_data_shuffler.prefetch:
            # TODO: SPECIFIC HACK FOR THE CENTER LOSS. I NEED TO FIND A CLEAN SOLUTION FOR THAT
            if self.centers is None:
_, l, lr, summary = self.session.run([self.optimizer, self.loss,
self.learning_rate, self.summaries_train])
else:
_, l, lr, summary, _ = self.session.run([self.optimizer, self.loss,
self.learning_rate, self.summaries_train, self.centers])
else:
feed_dict = self.get_feed_dict(self.train_data_shuffler)
_, l, lr, summary = self.session.run([self.optimizer, self.loss,
@@ -451,7 +451,7 @@ class Trainer(object):
feed_dict=feed_dict)
logger.info("Loss VALIDATION set step={0} = {1}".format(step, l))
        self.validation_summary_writter.add_summary(summary, step)
def compute_validation_embeddings(self, step):
"""
@@ -463,19 +463,19 @@
step: Iteration number
"""
if self.validation_data_shuffler.prefetch:
embedding, labels = self.session.run([self.validation_loss, self.validation_label_ph])
else:
feed_dict = self.get_feed_dict(self.validation_data_shuffler)
embedding, labels = self.session.run([self.validation_loss, self.validation_label_ph],
feed_dict=feed_dict)
accuracy = compute_embedding_accuracy(embedding, labels)
summary = summary_pb2.Summary.Value(tag="accuracy", simple_value=accuracy)
logger.info("VALIDATION Accuracy set step={0} = {1}".format(step, accuracy))
        self.validation_summary_writter.add_summary(summary_pb2.Summary(value=[summary]), step)
def create_general_summary(self, average_loss, output, label):
@@ -487,16 +487,16 @@
#for var in tf.trainable_variables():
#for var in tf.global_variables():
# tf.summary.histogram(var.op.name, var)
# Train summary
tf.summary.scalar('loss', average_loss)
        tf.summary.scalar('lr', self.learning_rate)
# Computing accuracy
correct_prediction = tf.equal(tf.argmax(output, 1), label)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        tf.summary.scalar('accuracy', accuracy)
return tf.summary.merge_all()
def start_thread(self):
......
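The embeddings-validation path above logs accuracy without a summary op: it builds a Summary proto by hand and passes it straight to the writer. A minimal sketch of that pattern (log directory and accuracy value are hypothetical):

    import tensorflow as tf
    from tensorflow.core.framework import summary_pb2

    accuracy = 0.95  # hypothetical output of compute_embedding_accuracy
    value = summary_pb2.Summary.Value(tag="accuracy", simple_value=accuracy)

    writer = tf.summary.FileWriter("/tmp/validation_logs")  # hypothetical log dir
    writer.add_summary(summary_pb2.Summary(value=[value]), global_step=1)
    writer.close()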
@@ -96,23 +96,23 @@ class TripletTrainer(Trainer):
self.graph = None
self.validation_graph = None
self.loss = None
        self.validation_predictor = None
self.optimizer_class = None
self.learning_rate = None
# Training variables used in the fit
self.optimizer = None
self.data_ph = None
self.label_ph = None
self.validation_data_ph = None
self.validation_label_ph = None
self.saver = None
bob.core.log.set_verbosity_level(logger, verbosity_level)
@@ -141,7 +141,7 @@ class TripletTrainer(Trainer):
self.optimizer_class = optimizer
self.learning_rate = learning_rate
-        self.global_step = tf.contrib.framework.get_or_create_global_step()
+        self.global_step = tf.train.get_or_create_global_step()
# Saving all the variables
self.saver = tf.train.Saver(var_list=tf.global_variables())
......
@@ -18,7 +18,7 @@ def exponential_decay(base_learning_rate=0.05,
    staircase: Boolean. If True, decay the learning rate at discrete intervals.
"""
-    global_step = tf.contrib.framework.get_or_create_global_step()
+    global_step = tf.train.get_or_create_global_step()
return tf.train.exponential_decay(learning_rate=base_learning_rate,
global_step=global_step,
decay_steps=decay_steps,
......
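The helper above only wires a decayed learning-rate tensor to the shared global step. A hedged sketch of the equivalent direct call, with example rates and step counts:

    import tensorflow as tf

    global_step = tf.train.get_or_create_global_step()
    learning_rate = tf.train.exponential_decay(learning_rate=0.05,
                                               global_step=global_step,
                                               decay_steps=1000,
                                               decay_rate=0.9,
                                               staircase=True)
    # With staircase=True the rate drops by 10% once every 1000 steps.
    optimizer = tf.train.GradientDescentOptimizer(learning_rate)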