Commit 8fd43642 authored by Tiago de Freitas Pereira

Added an option to sum the regularization losses in the general loss

parent bd3e1f09
1 merge request: !17 Updates
@@ -31,7 +31,7 @@ class MeanSoftMaxLoss(object):
     Mean softmax loss. It wraps the function tf.nn.sparse_softmax_cross_entropy_with_logits.
     """

-    def __init__(self, name="loss"):
+    def __init__(self, name="loss", add_regularization_losses=True):
         """
         Constructor
@@ -43,8 +43,15 @@
         """
         self.name = name
+        self.add_regularization_losses = add_regularization_losses

     def __call__(self, graph, label):
-        return tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
-            logits=graph, labels=label), name=self.name)
+        loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
+            logits=graph, labels=label), name=self.name)
+
+        if self.add_regularization_losses:
+            regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
+            return tf.add_n([loss] + regularization_losses, name='total_loss')
+        else:
+            return loss
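
Below is a minimal usage sketch of the changed loss in TensorFlow 1.x graph mode. The placeholder shapes, variable names, and the manual tf.add_to_collection call are illustrative assumptions, not part of the commit; any layer built with a kernel_regularizer would populate the same collection.

import tensorflow as tf  # TensorFlow 1.x graph-mode API

# Hypothetical placeholders for a 10-class problem (shapes are illustrative).
inputs = tf.placeholder(tf.float32, shape=[None, 128])
labels = tf.placeholder(tf.int64, shape=[None])

# A simple linear model; its weight penalty is registered in the standard
# REGULARIZATION_LOSSES collection that MeanSoftMaxLoss now reads.
w = tf.get_variable("w", shape=[128, 10])
b = tf.get_variable("b", shape=[10], initializer=tf.zeros_initializer())
logits = tf.matmul(inputs, w) + b
tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, 1e-4 * tf.nn.l2_loss(w))

# With add_regularization_losses=True (the default), the returned op is the
# mean cross-entropy plus the sum of every registered regularization term,
# combined with tf.add_n under the name 'total_loss'.
loss_op = MeanSoftMaxLoss(name="loss")(logits, labels)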