Commit 1d4fc91a authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Pass logits and labels to losses using kwargs

parent 90b4835f
Pipeline #20863 canceled with stage
in 1 minute and 3 seconds
......@@ -80,12 +80,12 @@ class Logits(estimator.Estimator):
"scopes": dict({"<SOURCE_SCOPE>/": "<TARGET_SCOPE>/"}),
"trainable_variables": [<LIST OF VARIABLES OR SCOPES THAT YOU WANT TO RETRAIN>]
}
apply_moving_averages: bool
Apply exponential moving average in the training variables and in the loss.
https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
By default the decay for the variable averages is 0.9999 and for the loss is 0.9
"""
def __init__(self,
......@@ -144,7 +144,7 @@ class Logits(estimator.Estimator):
with tf.control_dependencies([variable_averages_op]):
# Compute Loss (for both TRAIN and EVAL modes)
self.loss = self.loss_op(logits, labels)
self.loss = self.loss_op(logits=logits, labels=labels)
# Compute the moving average of all individual losses and the total loss.
loss_averages = tf.train.ExponentialMovingAverage(0.9, name='avg')
......@@ -187,7 +187,7 @@ class Logits(estimator.Estimator):
mode=mode, predictions=predictions)
# IF Validation
self.loss = self.loss_op(logits, labels)
self.loss = self.loss_op(logits=logits, labels=labels)
if self.embedding_validation:
predictions_op = predict_using_tensors(
......@@ -266,7 +266,7 @@ class LogitsCenterLoss(estimator.Estimator):
params:
Extra params for the model function (please see
https://www.tensorflow.org/extend/estimators for more info)
extra_checkpoint: dict
In case you want to use other model to initialize some variables.
This argument should be in the following format
......@@ -275,13 +275,13 @@ class LogitsCenterLoss(estimator.Estimator):
"scopes": dict({"<SOURCE_SCOPE>/": "<TARGET_SCOPE>/"}),
"trainable_variables": [<LIST OF VARIABLES OR SCOPES THAT YOU WANT TO TRAIN>]
}
apply_moving_averages: bool
Apply exponential moving average in the training variables and in the loss.
https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
By default the decay for the variable averages is 0.9999 and for the loss is 0.9
"""
......@@ -365,7 +365,7 @@ class LogitsCenterLoss(estimator.Estimator):
# Compute the moving average of all individual losses and the total loss.
loss_averages = tf.train.ExponentialMovingAverage(0.9, name='avg')
loss_averages_op = loss_averages.apply(tf.get_collection(tf.GraphKeys.LOSSES))
for l in tf.get_collection(tf.GraphKeys.LOSSES):
tf.summary.scalar(l.op.name, loss_averages.average(l))
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment