Added an option to sum the regularization losses in the general loss

parent bd3e1f09
@@ -31,7 +31,7 @@ class MeanSoftMaxLoss(object):
     Mean softmax loss. Basically it wraps the function tf.nn.sparse_softmax_cross_entropy_with_logits.
     """
-    def __init__(self, name="loss"):
+    def __init__(self, name="loss", add_regularization_losses=True):
         """
         Constructor
@@ -43,8 +43,15 @@
         """
         self.name = name
+        self.add_regularization_losses = add_regularization_losses

     def __call__(self, graph, label):
-        return tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
-            logits=graph, labels=label), name=self.name)
+        loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
+            logits=graph, labels=label), name=self.name)
+
+        if self.add_regularization_losses:
+            regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
+            return tf.add_n([loss] + regularization_losses, name='total_loss')
+        else:
+            return loss
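For context, here is a minimal usage sketch of the new flag. Everything below apart from `MeanSoftMaxLoss` itself is an assumption: the import path, the placeholder shapes, the single dense layer, and the regularizer scale are illustrative, and TF 1.x graph mode is assumed. The point is that a `kernel_regularizer` attached to a layer registers its penalty under `tf.GraphKeys.REGULARIZATION_LOSSES`, which the loss object then folds into the returned total.

```python
import tensorflow as tf

# Assumed import path for the class patched above.
from bob.learn.tensorflow.loss import MeanSoftMaxLoss

# Hypothetical graph: shapes and regularizer scale are illustrative only.
inputs = tf.placeholder(tf.float32, shape=[None, 784])
labels = tf.placeholder(tf.int64, shape=[None])

# tf.layers.dense adds the L2 penalty to tf.GraphKeys.REGULARIZATION_LOSSES.
logits = tf.layers.dense(
    inputs, units=10,
    kernel_regularizer=tf.contrib.layers.l2_regularizer(1e-4))

# Cross-entropy mean plus the sum of all collected regularization losses.
total_loss = MeanSoftMaxLoss(add_regularization_losses=True)(logits, labels)

# With the flag disabled, only the plain cross-entropy mean is returned.
plain_loss = MeanSoftMaxLoss(name="xent", add_regularization_losses=False)(logits, labels)
```

Summing via `tf.add_n([loss] + regularization_losses)` mirrors the behavior of TF 1.x's `tf.losses.get_total_loss`, which likewise adds the contents of the `REGULARIZATION_LOSSES` collection to the task loss.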