BaseLoss.py 2.99 KB
Newer Older
Tiago de Freitas Pereira's avatar
Scratch  
Tiago de Freitas Pereira committed
1 2 3 4 5
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>

import logging
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
6
import tensorflow as tf
Tiago de Freitas Pereira's avatar
Scratch  
Tiago de Freitas Pereira committed
7 8
# Package-wide logger (shared name across bob.learn.tensorflow modules);
# currently not used by the loss functions in this module.
logger = logging.getLogger(name="bob.learn.tensorflow")

Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
9 10
# Short alias for TF-Slim. tf.contrib only exists in TensorFlow 1.x, and this
# attribute access runs at import time, so importing this module requires TF 1.x.
# NOTE(review): `slim` is not referenced by the loss functions in this file;
# confirm no other module imports it from here before removing it.
slim = tf.contrib.slim

Tiago de Freitas Pereira's avatar
Scratch  
Tiago de Freitas Pereira committed
11

12
def mean_cross_entropy_loss(logits, labels, add_regularization_losses=True):
    """
    Mean sparse softmax cross-entropy, optionally plus the graph's
    regularization losses.

    Thin wrapper around ``tf.nn.sparse_softmax_cross_entropy_with_logits``:
    the per-example losses it produces are averaged with ``tf.reduce_mean``.

    **Parameters**

      logits: Unscaled class scores, forwarded as ``logits`` to the sparse
        softmax cross-entropy op.
      labels: Class indices, forwarded as ``labels`` to the sparse softmax
        cross-entropy op (that op requires integer labels).
      add_regularization_losses: If True, every tensor currently in the
        ``tf.GraphKeys.REGULARIZATION_LOSSES`` collection is summed
        (``tf.add_n``) together with the mean cross-entropy. If False, only
        the mean cross-entropy is returned.

    **Returns**

      The loss tensor, created under the ``cross_entropy_loss`` variable
      scope with an op name taken from ``tf.GraphKeys.LOSSES``.
    """
    with tf.variable_scope('cross_entropy_loss'):
        per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=labels)
        mean_loss = tf.reduce_mean(per_example_loss, name=tf.GraphKeys.LOSSES)

        # Caller only wants the bare cross-entropy.
        if not add_regularization_losses:
            return mean_loss

        # Fold in whatever other code has registered in the regularization
        # collection up to this point.
        regularization_terms = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        return tf.add_n([mean_loss] + regularization_terms, name=tf.GraphKeys.LOSSES)
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
34
            
35

36
def mean_cross_entropy_center_loss(logits, prelogits, labels, n_classes, alpha=0.9, factor=0.01):
    """
    CrossEntropy + Center Loss from the paper
    "A Discriminative Feature Learning Approach for Deep Face Recognition"
    (http://ydwen.github.io/papers/WenECCV16.pdf)

    The returned total loss is::

        mean_cross_entropy + sum(tf.GraphKeys.REGULARIZATION_LOSSES)

    This function adds ``factor * center_loss`` to that collection, where
    ``center_loss = mean((prelogits - centers[labels]) ** 2)`` averaged over
    all elements. Any other regularization terms already in the collection
    are included in the total as well.

    The class centers are a non-trainable variable named ``centers`` in the
    ``center_loss`` variable scope, of shape ``[n_classes, n_features]`` and
    initialized to zero. The optimizer does not update them. The update
    ``centers[labels] -= (1 - alpha) * (centers[labels] - prelogits)``
    is applied only when the returned ``'centers'`` tensor is evaluated.
    If a class appears several times in the batch, ``tf.scatter_sub``
    accumulates all of its updates.

    **Parameters**

      logits: Unscaled class scores for the sparse softmax cross-entropy.
      prelogits: Embeddings the centers are computed on. Their second
        dimension must be statically known, since it sizes the centers
        variable. The centers variable takes the dtype of ``prelogits``.
      labels: Integer class indices, used both for the cross-entropy and to
        select each sample's center.
      n_classes: Number of classes of your task (rows of the centers variable).
      alpha: Center update damping. Each update moves the selected centers by
        ``(1 - alpha)`` times their difference to ``prelogits``, so a larger
        alpha means slower center updates.
      factor: Weight applied to the center loss before it is added to the
        regularization collection. The ``center_loss`` summary logs the
        unweighted value.

    **Returns**

      A dict with:
        ``'loss'``: the total loss tensor (see above).
        ``'centers'``: the output of ``tf.scatter_sub``. Run it (e.g. in the
        same ``session.run`` as the training op) to apply the center update.
    """
    # Cross entropy
    with tf.variable_scope('cross_entropy_loss'):
        cross_entropy = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=labels), name=tf.GraphKeys.LOSSES)
        tf.summary.scalar('cross_entropy_loss', cross_entropy)

    # Center loss
    with tf.variable_scope('center_loss'):
        n_features = prelogits.get_shape()[1]

        # Match the embedding dtype so the center arithmetic below also works
        # for non-float32 prelogits. This is the same as before for float32.
        # base_dtype strips a possible *_ref dtype (e.g. if a Variable is passed).
        centers = tf.get_variable('centers', [n_classes, n_features],
                                  dtype=prelogits.dtype.base_dtype,
                                  initializer=tf.constant_initializer(0),
                                  trainable=False)

        centers_batch = tf.gather(centers, labels)
        diff = (1 - alpha) * (centers_batch - prelogits)
        # Nothing in the loss depends on this op. The caller must evaluate the
        # returned 'centers' tensor for the update to take effect.
        centers_update = tf.scatter_sub(centers, labels, diff)

        center_loss = tf.reduce_mean(tf.square(prelogits - centers_batch))
        tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, center_loss * factor)
        tf.summary.scalar('center_loss', center_loss)

    # Add every registered regularizer, including the weighted center loss.
    with tf.variable_scope('total_loss'):
        regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        total_loss = tf.add_n([cross_entropy] + regularization_losses, name=tf.GraphKeys.LOSSES)
        tf.summary.scalar('total_loss', total_loss)

    return {'loss': total_loss, 'centers': centers_update}
83