From ef026be17a6c75ef84046f335776d762556bbb20 Mon Sep 17 00:00:00 2001 From: Tiago Freitas Pereira Date: Tue, 27 Mar 2018 17:38:46 +0200 Subject: [PATCH] Organized losses --- bob/learn/tensorflow/loss/ContrastiveLoss.py | 17 ++++++++--- bob/learn/tensorflow/loss/TripletLoss.py | 28 ++++++++++++++----- .../tensorflow/network/InceptionResnetV2.py | 2 +- 3 files changed, 35 insertions(+), 12 deletions(-) diff --git a/bob/learn/tensorflow/loss/ContrastiveLoss.py b/bob/learn/tensorflow/loss/ContrastiveLoss.py index f72cbd8..edf6c09 100644 --- a/bob/learn/tensorflow/loss/ContrastiveLoss.py +++ b/bob/learn/tensorflow/loss/ContrastiveLoss.py @@ -52,19 +52,28 @@ def contrastive_loss(left_embedding, within_class = tf.multiply(one - labels, tf.square(d)) # (1-Y)*(d^2) within_class_loss = tf.reduce_mean( - within_class, name=tf.GraphKeys.LOSSES) + within_class, name="within_class") + tf.add_to_collection(tf.GraphKeys.LOSSES, within_class_loss) with tf.name_scope("between_class"): max_part = tf.square(tf.maximum(contrastive_margin - d, 0)) between_class = tf.multiply( labels, max_part) # (Y) * max((margin - d)^2, 0) between_class_loss = tf.reduce_mean( - between_class, name=tf.GraphKeys.LOSSES) + between_class, name="between_class") + tf.add_to_collection(tf.GraphKeys.LOSSES, between_class_loss) with tf.name_scope("total_loss"): loss = 0.5 * (within_class + between_class) - loss = tf.reduce_mean(loss, name=tf.GraphKeys.LOSSES) - + loss = tf.reduce_mean(loss, name="total_loss_raw") + tf.summary.scalar('loss_raw', loss) + tf.add_to_collection(tf.GraphKeys.LOSSES, loss) + + ## Appending the regularization loss + #regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES) + #loss = tf.add_n([loss] + regularization_losses, name="total_loss") + + tf.summary.scalar('loss', loss) tf.summary.scalar('between_class', between_class_loss) tf.summary.scalar('within_class', within_class_loss) diff --git a/bob/learn/tensorflow/loss/TripletLoss.py b/bob/learn/tensorflow/loss/TripletLoss.py index fe59cbb..8aa0792 100644 --- a/bob/learn/tensorflow/loss/TripletLoss.py +++ b/bob/learn/tensorflow/loss/TripletLoss.py @@ -53,14 +53,28 @@ def triplet_loss(anchor_embedding, tf.square(tf.subtract(anchor_embedding, negative_embedding)), 1) basic_loss = tf.add(tf.subtract(d_positive, d_negative), margin) - loss = tf.reduce_mean( - tf.maximum(basic_loss, 0.0), 0, name=tf.GraphKeys.LOSSES) - between_class_loss = tf.reduce_mean(d_negative) - within_class_loss = tf.reduce_mean(d_positive) - tf.summary.scalar('loss', loss) - tf.summary.scalar('between_class', between_class_loss) - tf.summary.scalar('within_class', within_class_loss) + with tf.name_scope("TripletLoss"): + # Between + between_class_loss = tf.reduce_mean(d_negative) + tf.summary.scalar('between_class', between_class_loss) + tf.add_to_collection(tf.GraphKeys.LOSSES, between_class_loss) + + # Within + within_class_loss = tf.reduce_mean(d_positive) + tf.summary.scalar('within_class', within_class_loss) + tf.add_to_collection(tf.GraphKeys.LOSSES, within_class_loss) + + # Total loss + loss = tf.reduce_mean( + tf.maximum(basic_loss, 0.0), 0, name="total_loss") + tf.add_to_collection(tf.GraphKeys.LOSSES, loss) + tf.summary.scalar('loss_raw', loss) + + # Appending the regularization loss + #regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES) + #loss = tf.add_n([loss] + regularization_losses, name="total_loss") + #tf.summary.scalar('loss', loss) return loss diff --git a/bob/learn/tensorflow/network/InceptionResnetV2.py b/bob/learn/tensorflow/network/InceptionResnetV2.py index b52fcc4..3bb0f9a 100644 --- a/bob/learn/tensorflow/network/InceptionResnetV2.py +++ b/bob/learn/tensorflow/network/InceptionResnetV2.py @@ -252,7 +252,7 @@ def inception_resnet_v2_batch_norm(inputs, # Moving averages ends up in the trainable variables collection 'variables_collections': [tf.GraphKeys.TRAINABLE_VARIABLES], } - + weight_decay = 5e-5 with slim.arg_scope( [slim.conv2d, slim.fully_connected], weights_initializer=tf.truncated_normal_initializer(stddev=0.1), -- GitLab