Commit e43aec3d authored by Amir MOHAMMADI

Updates to the logits estimator

parent d9b4b431
Pipeline #20949 failed with stage in 27 minutes and 27 seconds
@@ -99,7 +99,8 @@ class Logits(estimator.Estimator):
                  validation_batch_size=None,
                  params=None,
                  extra_checkpoint=None,
-                 apply_moving_averages=True):
+                 apply_moving_averages=True,
+                 add_histograms=None):

         self.architecture = architecture
         self.optimizer = optimizer
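The hunk above introduces the `add_histograms` option. Below is a minimal usage sketch, not part of the commit: the toy architecture, optimizer, and loss are stand-ins, and the import path is an assumption; only the keyword names visible in this diff (`apply_moving_averages`, `add_histograms`, plus `n_classes` and `loss_op`, which the model function uses later) come from the file.

    import tensorflow as tf
    from bob.learn.tensorflow.estimators import Logits  # assumed import path


    def architecture(data, mode, trainable_variables=None):
        # Stand-in network; the estimator only uses element [0] of the result.
        prelogits = tf.layers.dense(data, 128, activation=tf.nn.relu)
        return prelogits, None


    estimator = Logits(
        architecture=architecture,
        optimizer=tf.train.GradientDescentOptimizer(1e-3),
        n_classes=10,
        loss_op=tf.losses.sparse_softmax_cross_entropy,
        apply_moving_averages=True,  # keep EMA shadow variables (default)
        add_histograms='train',      # new: 'all', 'train', or None
    )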
@@ -114,61 +115,13 @@ class Logits(estimator.Estimator):
             check_features(features)
             data = features['data']
             key = features['key']

-            # Configure the Training Op (for TRAIN mode)
-            if mode == tf.estimator.ModeKeys.TRAIN:
-                # Building the training graph
-                # Checking if we have some variables/scope that we may want to shut down
-                trainable_variables = get_trainable_variables(
-                    self.extra_checkpoint)
-                prelogits = self.architecture(
-                    data, mode=mode,
-                    trainable_variables=trainable_variables)[0]
-                logits = append_logits(prelogits, n_classes)
-
-                if self.extra_checkpoint is not None:
-                    tf.contrib.framework.init_from_checkpoint(
-                        self.extra_checkpoint["checkpoint_path"],
-                        self.extra_checkpoint["scopes"])
-
-                global_step = tf.train.get_or_create_global_step()
-
-                # Compute the moving average of all individual losses and the total loss.
-                if apply_moving_averages:
-                    variable_averages = tf.train.ExponentialMovingAverage(
-                        0.9999, global_step)
-                    variable_averages_op = variable_averages.apply(
-                        tf.trainable_variables())
-                else:
-                    variable_averages_op = tf.no_op(name='noop')
-
-                with tf.control_dependencies([variable_averages_op]):
-                    # Compute Loss (for both TRAIN and EVAL modes)
-                    self.loss = self.loss_op(logits=logits, labels=labels)
-
-                    # Compute the moving average of all individual losses and the total loss.
-                    loss_averages = tf.train.ExponentialMovingAverage(
-                        0.9, name='avg')
-                    loss_averages_op = loss_averages.apply(
-                        tf.get_collection(tf.GraphKeys.LOSSES))
-
-                    for l in tf.get_collection(tf.GraphKeys.LOSSES):
-                        tf.summary.scalar(l.op.name + "_averaged",
-                                          loss_averages.average(l))
-
-                    global_step = tf.train.get_or_create_global_step()
-                    train_op = tf.group(
-                        self.optimizer.minimize(
-                            self.loss, global_step=global_step),
-                        variable_averages_op, loss_averages_op)
-                    return tf.estimator.EstimatorSpec(
-                        mode=mode, loss=self.loss, train_op=train_op)
-
-            # Building the training graph for PREDICTION OR VALIDATION
-            prelogits = self.architecture(data, mode=mode)[0]
+            # Checking if we have some variables/scope that we may want to shut
+            # down
+            trainable_variables = get_trainable_variables(
+                self.extra_checkpoint, mode=mode)
+            prelogits = self.architecture(
+                data, mode=mode, trainable_variables=trainable_variables)[0]
             logits = append_logits(prelogits, n_classes)

             if self.embedding_validation:
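Since `trainable_variables` is now forwarded to the architecture in every mode, an architecture can use it to freeze layers. An illustrative sketch of that contract follows; the layer names and sizes are assumptions, not from this repository.

    import tensorflow as tf


    def architecture(data, mode, trainable_variables=None):
        # None means "no restriction": every layer stays trainable.
        def is_trainable(name):
            return trainable_variables is None or name in trainable_variables

        net = tf.layers.dense(data, 256, activation=tf.nn.relu,
                              name='fc1', trainable=is_trainable('fc1'))
        prelogits = tf.layers.dense(net, 128, name='prelogits',
                                    trainable=is_trainable('prelogits'))
        return prelogits, None

With the `mode=mode` change, `get_trainable_variables` returns None outside of TRAIN (see the second file below), which is harmless here because the restriction only matters when gradients are computed.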
@@ -179,13 +132,13 @@ class Logits(estimator.Estimator):
                     "key": key,
                 }
             else:
-                probabilities = tf.nn.softmax(logits, name="softmax_tensor")
                 predictions = {
                     # Generate predictions (for PREDICT and EVAL mode)
                     "classes": tf.argmax(input=logits, axis=1),
                     # Add `softmax_tensor` to the graph. It is used for PREDICT
                     # and by the `logging_hook`.
-                    "probabilities": probabilities,
+                    "probabilities": tf.nn.softmax(
+                        logits, name="softmax_tensor"),
                     "key": key,
                 }
@@ -193,31 +146,88 @@ class Logits(estimator.Estimator):
                 return tf.estimator.EstimatorSpec(
                     mode=mode, predictions=predictions)

-            # IF Validation
-            self.loss = self.loss_op(logits=logits, labels=labels)
-
             if self.embedding_validation:
                 predictions_op = predict_using_tensors(
                     predictions["embeddings"],
                     labels,
                     num=validation_batch_size)
-                eval_metric_ops = {
-                    "accuracy":
-                    tf.metrics.accuracy(
-                        labels=labels, predictions=predictions_op)
-                }
-                return tf.estimator.EstimatorSpec(
-                    mode=mode, loss=self.loss, eval_metric_ops=eval_metric_ops)
-            else:
-                # Add evaluation metrics (for EVAL mode)
-                eval_metric_ops = {
-                    "accuracy":
-                    tf.metrics.accuracy(
-                        labels=labels, predictions=predictions["classes"])
-                }
-                return tf.estimator.EstimatorSpec(
-                    mode=mode, loss=self.loss, eval_metric_ops=eval_metric_ops)
+            else:
+                predictions_op = predictions["classes"]
+
+            accuracy = tf.metrics.accuracy(
+                labels=labels, predictions=predictions_op)
+            metrics = {'accuracy': accuracy}
+
+            if mode == tf.estimator.ModeKeys.EVAL:
+                self.loss = self.loss_op(logits=logits, labels=labels)
+                return tf.estimator.EstimatorSpec(
+                    mode=mode,
+                    predictions=predictions,
+                    loss=self.loss,
+                    train_op=None,
+                    eval_metric_ops=metrics)
+
+            # restore the model from an extra_checkpoint
+            if extra_checkpoint is not None:
+                tf.train.init_from_checkpoint(
+                    ckpt_dir_or_file=extra_checkpoint["checkpoint_path"],
+                    assignment_map=extra_checkpoint["scopes"],
+                )
+
+            global_step = tf.train.get_or_create_global_step()
+
+            # Compute the moving average of all individual losses and the
+            # total loss.
+            if apply_moving_averages:
+                variable_averages = tf.train.ExponentialMovingAverage(
+                    0.9999, global_step)
+                variable_averages_op = variable_averages.apply(
+                    tf.trainable_variables())
+            else:
+                variable_averages_op = tf.no_op(name='noop')
+
+            # Some layer like tf.layers.batch_norm need this:
+            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
+
+            with tf.control_dependencies([variable_averages_op] + update_ops):
+
+                # Calculate Loss
+                self.loss = self.loss_op(logits=logits, labels=labels)
+
+                # Compute the moving average of all individual losses
+                # and the total loss.
+                loss_averages = tf.train.ExponentialMovingAverage(
+                    0.9, name='avg')
+                loss_averages_op = loss_averages.apply(
+                    tf.get_collection(tf.GraphKeys.LOSSES))
+
+                train_op = tf.group(
+                    self.optimizer.minimize(
+                        self.loss, global_step=global_step),
+                    variable_averages_op, loss_averages_op)
+
+                # Log accuracy and loss
+                with tf.name_scope('train_metrics'):
+                    tf.summary.scalar('accuracy', accuracy[1])
+                    tf.summary.scalar('loss', self.loss)
+                    for l in tf.get_collection(tf.GraphKeys.LOSSES):
+                        tf.summary.scalar(l.op.name + "_averaged",
+                                          loss_averages.average(l))
+
+                # add histograms summaries
+                if add_histograms == 'all':
+                    for v in tf.all_variables():
+                        tf.summary.histogram(v.name, v)
+                elif add_histograms == 'train':
+                    for v in tf.trainable_variables():
+                        tf.summary.histogram(v.name, v)
+
+                return tf.estimator.EstimatorSpec(
+                    mode=mode,
+                    predictions=predictions,
+                    loss=self.loss,
+                    train_op=train_op,
+                    eval_metric_ops=metrics)

         super(Logits, self).__init__(
             model_fn=_model_fn,
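The training branch above ties the optimizer step to the moving-average updates through a control dependency. A self-contained sketch of that pattern, with a toy variable and loss that are not part of the commit:

    import tensorflow as tf

    x = tf.Variable(0.0, name='x')
    loss = tf.square(x - 3.0)

    global_step = tf.train.get_or_create_global_step()
    ema = tf.train.ExponentialMovingAverage(0.9999, global_step)
    ema_op = ema.apply(tf.trainable_variables())

    # The dependency ensures the shadow variables refresh on every step.
    with tf.control_dependencies([ema_op]):
        opt_op = tf.train.GradientDescentOptimizer(0.1).minimize(
            loss, global_step=global_step)
        train_op = tf.group(opt_op, ema_op)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for _ in range(5):
            sess.run(train_op)
        # The averaged value trails the raw variable.
        print(sess.run([x, ema.average(x)]))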
@@ -13,7 +13,8 @@ def check_features(features):
     return True


-def get_trainable_variables(extra_checkpoint):
+def get_trainable_variables(extra_checkpoint,
+                            mode=tf.estimator.ModeKeys.TRAIN):
     """
     Given the extra_checkpoint dictionary provided to the estimator,
     extract the content of "trainable_variables" e.
@@ -24,8 +25,11 @@ def get_trainable_variables(extra_checkpoint):
     Parameters
     ----------
-    extra_checkpoint: dict
+    extra_checkpoint : dict
         The `extra_checkpoint dictionary provided to the estimator
+    mode
+        The estimator mode. TRAIN, EVAL, and PREDICT. If not TRAIN, None is
+        returned.

     Returns
     -------
@@ -33,6 +37,8 @@ def get_trainable_variables(extra_checkpoint):
     otherwise returns the content of `extra_checkpoint

     """
+    if mode != tf.estimator.ModeKeys.TRAIN:
+        return None
     # If you don't set anything, everything is trainable
     if extra_checkpoint is None or "trainable_variables" not in extra_checkpoint:
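A hypothetical behavior check for the amended function. It assumes the unrestricted case returns None and the configured case returns extra_checkpoint["trainable_variables"], which is what the guard and the truncated docstring above suggest; the dictionary values are placeholders.

    import tensorflow as tf

    extra_checkpoint = {
        "checkpoint_path": "/path/to/checkpoint",  # placeholder
        "scopes": {"MyScope/": "MyScope/"},        # placeholder
        "trainable_variables": ["conv1"],
    }

    # New guard: outside of TRAIN, nothing is reported as trainable.
    assert get_trainable_variables(
        extra_checkpoint, mode=tf.estimator.ModeKeys.PREDICT) is None

    # Nothing configured: None here means "everything is trainable".
    assert get_trainable_variables(None) is None

    # Otherwise the configured list comes back (assumed return value).
    assert get_trainable_variables(extra_checkpoint) == ["conv1"]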