Skip to content
Snippets Groups Projects

Updates

Merged Tiago de Freitas Pereira requested to merge updates into master
1 file
+ 10
3
Compare changes
  • Side-by-side
  • Inline
@@ -31,7 +31,7 @@ class MeanSoftMaxLoss(object):
Mean softmax loss. Basically it wrapps the function tf.nn.sparse_softmax_cross_entropy_with_logits.
"""
def __init__(self, name="loss"):
def __init__(self, name="loss", add_regularization_losses=True):
"""
Constructor
@@ -43,8 +43,15 @@ class MeanSoftMaxLoss(object):
"""
self.name = name
self.add_regularization_losses = add_regularization_losses
def __call__(self, graph, label):
return tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
logits=graph, labels=label), name=self.name)
loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
logits=graph, labels=label), name=self.name)
if self.add_regularization_losses:
regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
return tf.add_n([loss] + regularization_losses, name='total_loss')
else:
return loss
Loading