Skip to content
Snippets Groups Projects

Updates

Merged Tiago de Freitas Pereira requested to merge updates into master
3 files
+ 51
53
Compare changes
  • Side-by-side
  • Inline
Files
3
@@ -61,7 +61,7 @@ class MeanSoftMaxLossCenterLoss(object):
Mean softmax loss. Basically it wraps the function tf.nn.sparse_softmax_cross_entropy_with_logits.
"""
def __init__(self, name="loss", add_regularization_losses=True, alpha=0.9, factor=0.01, n_classes=10):
def __init__(self, name="loss", alpha=0.9, factor=0.01, n_classes=10):
"""
Constructor
@@ -73,46 +73,36 @@ class MeanSoftMaxLossCenterLoss(object):
"""
self.name = name
self.add_regularization_losses = add_regularization_losses
self.n_classes = n_classes
self.alpha = alpha
self.factor = factor
def append_center_loss(self, features, label):
    """
    Build the center-loss term from *features* and their class *label*.

    A non-trainable ``centers`` variable of shape
    ``[self.n_classes, n_features]`` holds one center per class; the
    loss is the mean squared distance between each feature vector and
    the center of its class.

    **Parameters**

      features: 2D tensor of pre-logit features; ``n_features`` is
        taken from its static shape — assumes dimension 1 is defined
        at graph-construction time (TODO confirm with callers).
      label: integer tensor of class indices, reshaped to rank 1.

    **Returns**

      Scalar tensor with the center loss.
    """
    nrof_features = features.get_shape()[1]
    # One row per class; not trainable — updated manually below.
    centers = tf.get_variable('centers', [self.n_classes, nrof_features], dtype=tf.float32,
                              initializer=tf.constant_initializer(0), trainable=False)
    label = tf.reshape(label, [-1])
    # Center of each sample's class, gathered per batch entry.
    centers_batch = tf.gather(centers, label)
    # Move each used center a fraction (1 - alpha) toward its batch features.
    diff = (1 - self.alpha) * (centers_batch - features)
    centers = tf.scatter_sub(centers, label, diff)
    # NOTE(review): the scatter_sub update op is reassigned to a local name
    # but not returned; without a control dependency on it, the centers may
    # never actually be updated when only `loss` is run — verify in training.
    loss = tf.reduce_mean(tf.square(features - centers_batch))
    return loss
def __call__(self, logits, prelogits, label):
    """
    Build the total loss: softmax cross entropy plus center loss and
    any other collected regularization losses.

    **Parameters**

      logits: 2D tensor fed to
        ``tf.nn.sparse_softmax_cross_entropy_with_logits``.
      prelogits: 2D tensor of features preceding the logits layer,
        used for the center loss; assumes dimension 1 is statically
        defined (TODO confirm with callers).
      label: integer tensor of class indices, reshaped to rank 1.

    **Returns**

      total_loss: scalar tensor — cross entropy plus every entry of
        the ``REGULARIZATION_LOSSES`` collection (which includes the
        weighted center loss added here).
      centers: the ``tf.scatter_sub`` op updating the class centers;
        returned so the caller can add a control dependency that
        actually runs the update during training.
    """
    # Cross entropy
    with tf.variable_scope('cross_entropy_loss'):
        loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=label), name=self.name)

    # Appending center loss
    with tf.variable_scope('center_loss'):
        n_features = prelogits.get_shape()[1]
        # One non-trainable center per class, updated manually below.
        centers = tf.get_variable('centers', [self.n_classes, n_features], dtype=tf.float32,
                                  initializer=tf.constant_initializer(0), trainable=False)
        label = tf.reshape(label, [-1])
        centers_batch = tf.gather(centers, label)
        # Move each used center a fraction (1 - alpha) toward its batch features.
        diff = (1 - self.alpha) * (centers_batch - prelogits)
        centers = tf.scatter_sub(centers, label, diff)
        center_loss = tf.reduce_mean(tf.square(prelogits - centers_batch))
        # Weighted center loss joins the regularization losses so it is
        # folded into the total below.
        tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, center_loss * self.factor)

    # Adding the regularizers in the loss
    with tf.variable_scope('total_loss'):
        regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        total_loss = tf.add_n([loss] + regularization_losses, name='total_loss')

    return total_loss, centers
Loading