Skip to content
Snippets Groups Projects

Pass logits and labels to losses using kwargs

Merged Amir MOHAMMADI requested to merge loss into master
All threads resolved!
1 file
+ 9
− 9
Compare changes
  • Side-by-side
  • Inline
@@ -80,12 +80,12 @@ class Logits(estimator.Estimator):
"scopes": dict({"<SOURCE_SCOPE>/": "<TARGET_SCOPE>/"}),
"trainable_variables": [<LIST OF VARIABLES OR SCOPES THAT YOU WANT TO RETRAIN>]
}
apply_moving_averages: bool
Apply exponential moving average in the training variables and in the loss.
https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
By default the decay for the variable averages is 0.9999 and for the loss is 0.9
"""
def __init__(self,
@@ -144,7 +144,7 @@ class Logits(estimator.Estimator):
with tf.control_dependencies([variable_averages_op]):
# Compute Loss (for both TRAIN and EVAL modes)
self.loss = self.loss_op(logits, labels)
self.loss = self.loss_op(logits=logits, labels=labels)
# Compute the moving average of all individual losses and the total loss.
loss_averages = tf.train.ExponentialMovingAverage(0.9, name='avg')
@@ -187,7 +187,7 @@ class Logits(estimator.Estimator):
mode=mode, predictions=predictions)
# IF Validation
self.loss = self.loss_op(logits, labels)
self.loss = self.loss_op(logits=logits, labels=labels)
if self.embedding_validation:
predictions_op = predict_using_tensors(
@@ -266,7 +266,7 @@ class LogitsCenterLoss(estimator.Estimator):
params:
Extra params for the model function (please see
https://www.tensorflow.org/extend/estimators for more info)
extra_checkpoint: dict
In case you want to use other model to initialize some variables.
This argument should be in the following format
@@ -275,13 +275,13 @@ class LogitsCenterLoss(estimator.Estimator):
"scopes": dict({"<SOURCE_SCOPE>/": "<TARGET_SCOPE>/"}),
"trainable_variables": [<LIST OF VARIABLES OR SCOPES THAT YOU WANT TO TRAIN>]
}
apply_moving_averages: bool
Apply exponential moving average in the training variables and in the loss.
https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
By default the decay for the variable averages is 0.9999 and for the loss is 0.9
"""
@@ -365,7 +365,7 @@ class LogitsCenterLoss(estimator.Estimator):
# Compute the moving average of all individual losses and the total loss.
loss_averages = tf.train.ExponentialMovingAverage(0.9, name='avg')
loss_averages_op = loss_averages.apply(tf.get_collection(tf.GraphKeys.LOSSES))
for l in tf.get_collection(tf.GraphKeys.LOSSES):
tf.summary.scalar(l.op.name, loss_averages.average(l))
Loading