From 35180690b232bcb116db9ff949a70ce391b1e3ca Mon Sep 17 00:00:00 2001
From: Tiago Freitas Pereira <tiagofrepereira@gmail.com>
Date: Thu, 5 Oct 2017 15:18:13 +0200
Subject: [PATCH] Implemented center loss

---
 bob/learn/tensorflow/loss/BaseLoss.py         | 54 ++++++++-----------
 .../tensorflow/trainers/SiameseTrainer.py     |  1 -
 bob/learn/tensorflow/trainers/Trainer.py      | 49 ++++++++++-------
 3 files changed, 51 insertions(+), 53 deletions(-)

diff --git a/bob/learn/tensorflow/loss/BaseLoss.py b/bob/learn/tensorflow/loss/BaseLoss.py
index e71f0b73..8c27710d 100644
--- a/bob/learn/tensorflow/loss/BaseLoss.py
+++ b/bob/learn/tensorflow/loss/BaseLoss.py
@@ -61,7 +61,7 @@ class MeanSoftMaxLossCenterLoss(object):
     Mean softmax loss. It wraps tf.nn.sparse_softmax_cross_entropy_with_logits and adds a center-loss term on the prelogits.
     """
 
-    def __init__(self, name="loss", add_regularization_losses=True, alpha=0.9, factor=0.01, n_classes=10):
+    def __init__(self, name="loss", alpha=0.9, factor=0.01, n_classes=10):
         """
         Constructor
         
@@ -73,46 +73,36 @@ class MeanSoftMaxLossCenterLoss(object):
         """
     
         self.name = name
-        self.add_regularization_losses = add_regularization_losses
 
         self.n_classes = n_classes
         self.alpha = alpha
         self.factor = factor
 
 
-    def append_center_loss(self, features, label):
-        nrof_features = features.get_shape()[1]
-        
-        centers = tf.get_variable('centers', [self.n_classes, nrof_features], dtype=tf.float32,
-            initializer=tf.constant_initializer(0), trainable=False)
-            
-        label = tf.reshape(label, [-1])
-        centers_batch = tf.gather(centers, label)
-        diff = (1 - self.alpha) * (centers_batch - features)
-        centers = tf.scatter_sub(centers, label, diff)
-        loss = tf.reduce_mean(tf.square(features - centers_batch))
-        
-        return loss
-
-
-    def __call__(self, logits_prelogits, label):
-    
-        #TODO: Test the dictionary
-    
-        logits = logits_prelogits['logits']
-    
+    def __call__(self, logits, prelogits, label):
         # Cross entropy
-        loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
-                                          logits=logits, labels=label), name=self.name)
+        with tf.variable_scope('cross_entropy_loss'):
+            loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
+                                              logits=logits, labels=label), name=self.name)
 
-        # Appending center loss
-        prelogits = logits_prelogits['prelogits']
-        center_loss = self.append_center_loss(prelogits, label)
-        tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, center_loss * self.factor)
+        # Appending center loss
+        with tf.variable_scope('center_loss'):
+            n_features = prelogits.get_shape()[1]
+
+            centers = tf.get_variable('centers', [self.n_classes, n_features], dtype=tf.float32,
+                initializer=tf.constant_initializer(0), trainable=False)
+
+            label = tf.reshape(label, [-1])
+            centers_batch = tf.gather(centers, label)
+            diff = (1 - self.alpha) * (centers_batch - prelogits)
+            centers = tf.scatter_sub(centers, label, diff)
+            center_loss = tf.reduce_mean(tf.square(prelogits - centers_batch))
+            tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, center_loss * self.factor)
     
         # Adding the regularizers in the loss
-        if self.add_regularization_losses:
+        with tf.variable_scope('total_loss'):
             regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
-            loss =  tf.add_n([loss] + regularization_losses, name='total_loss')
+            total_loss = tf.add_n([loss] + regularization_losses, name='total_loss')
             
-        return loss            
+        return total_loss, centers
+
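
For reference, the block above is the standard center-loss formulation (Wen et al., ECCV 2016): each class y keeps a running, non-trainable center c_y, the loss is the mean squared distance between the prelogits and their class centers, and fetching the returned `centers` tensor applies the update c_y <- c_y - (1 - alpha) * (c_y - f). A minimal NumPy sketch of one step (a hypothetical helper, not part of this patch) that mirrors the TF ops line by line:

    import numpy as np

    def center_loss_step(features, labels, centers, alpha=0.9):
        # centers_batch mirrors tf.gather(centers, label)
        centers_batch = centers[labels]
        # mirrors tf.reduce_mean(tf.square(prelogits - centers_batch))
        loss = np.mean(np.square(features - centers_batch))
        # mirrors tf.scatter_sub(centers, label, diff); np.subtract.at
        # accumulates repeated labels within a batch, like scatter_sub
        diff = (1 - alpha) * (centers_batch - features)
        np.subtract.at(centers, labels, diff)
        return loss, centers

    features = np.array([[1.0, 1.0], [3.0, 3.0]])  # prelogits of two samples
    labels = np.array([0, 0])                      # both from class 0
    centers = np.zeros((10, 2))                    # n_classes=10, n_features=2
    loss, centers = center_loss_step(features, labels, centers)
    # loss == 5.0 and c_0 == [0.4, 0.4]: the center moved towards the class mean [2, 2]
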
diff --git a/bob/learn/tensorflow/trainers/SiameseTrainer.py b/bob/learn/tensorflow/trainers/SiameseTrainer.py
index be16b4fe..300b8a06 100644
--- a/bob/learn/tensorflow/trainers/SiameseTrainer.py
+++ b/bob/learn/tensorflow/trainers/SiameseTrainer.py
@@ -219,7 +219,6 @@ class SiameseTrainer(Trainer):
         return feed_dict
 
     def fit(self, step):
-
         feed_dict = self.get_feed_dict(self.train_data_shuffler)
         _, l, bt_class, wt_class, lr, summary = self.session.run([
                                                 self.optimizer,
diff --git a/bob/learn/tensorflow/trainers/Trainer.py b/bob/learn/tensorflow/trainers/Trainer.py
index 6631ca7b..8b7ebc2e 100644
--- a/bob/learn/tensorflow/trainers/Trainer.py
+++ b/bob/learn/tensorflow/trainers/Trainer.py
@@ -177,7 +177,7 @@ class Trainer(object):
                     self.compute_validation(step)
 
             # Taking snapshot
             if step % self.snapshot == 0:
                 logger.info("Taking snapshot")
                 path = os.path.join(self.temp_dir, 'model_snapshot{0}.ckp'.format(step))
                 self.saver.save(self.session, path, global_step=step)
@@ -214,6 +214,7 @@ class Trainer(object):
 
                                     # Learning rate
                                     learning_rate=None,
+                                    prelogits=None
                                     ):
 
         """
@@ -229,7 +230,6 @@ class Trainer(object):
 
             learning_rate: Learning rate
+
+            prelogits: Prelogits tensor (output of the layer just before the logits); only needed when the loss also returns a centers update op, as MeanSoftMaxLossCenterLoss does
         """
-
         # Getting the pointer to the placeholders
         self.data_ph = self.train_data_shuffler("data", from_queue=True)
         self.label_ph = self.train_data_shuffler("label", from_queue=True)
@@ -237,8 +237,13 @@ class Trainer(object):
         self.graph = graph
         self.loss = loss        
 
-        # Attaching the loss in the graph
-        self.predictor = self.loss(self.graph, self.label_ph)
+        # TODO: SPECIFIC HACK FOR THE CENTER LOSS. I NEED TO FIND A CLEAN SOLUTION FOR THAT
+        self.centers = None
+        if prelogits is not None:
+            tf.add_to_collection("prelogits", prelogits)
+            self.predictor, self.centers = self.loss(self.graph, prelogits, self.label_ph)
+        else:
+            self.predictor = self.loss(self.graph, self.label_ph)
         
         self.optimizer_class = optimizer
         self.learning_rate = learning_rate
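
For context, this is roughly how a training script would wire the new argument in. A hedged sketch only: `create_network_from_scratch` is assumed to be the method whose signature is extended above, the import path is inferred from the file modified in this patch, and the dense layers stand in for a real architecture:

    import tensorflow as tf
    from bob.learn.tensorflow.loss import MeanSoftMaxLossCenterLoss  # import path assumed

    n_classes = 10
    inputs = trainer.train_data_shuffler("data", from_queue=True)  # `trainer` assumed already built

    # Hypothetical embedding network (assumes 2-D input features; a real
    # setup would use a CNN): `prelogits` is the layer the center loss clusters.
    prelogits = tf.layers.dense(inputs, 128, activation=tf.nn.relu)
    logits = tf.layers.dense(prelogits, n_classes)

    loss = MeanSoftMaxLossCenterLoss(n_classes=n_classes, alpha=0.9, factor=0.01)

    # Passing prelogits selects the center-loss branch added above; the trainer
    # stores the returned centers op in self.centers and fetches it in fit().
    trainer.create_network_from_scratch(graph=logits,
                                        loss=loss,
                                        learning_rate=tf.constant(0.01, name="lr"),
                                        prelogits=prelogits)
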
@@ -257,11 +262,8 @@ class Trainer(object):
         # Saving some variables
         tf.add_to_collection("global_step", self.global_step)
 
-        if isinstance(self.graph, dict):
-            tf.add_to_collection("graph", self.graph['logits'])
-            tf.add_to_collection("prelogits", self.graph['prelogits'])
-        else:
-            tf.add_to_collection("graph", self.graph)
+
+        tf.add_to_collection("graph", self.graph)
         
         tf.add_to_collection("predictor", self.predictor)
 
@@ -273,6 +275,10 @@ class Trainer(object):
 
         tf.add_to_collection("summaries_train", self.summaries_train)
 
+        # Appending histograms for each trainable variables
+        for var in tf.trainable_variables():
+            tf.summary.histogram(var.op.name, var)
+
         # Same business with the validation
         if self.validation_data_shuffler is not None:
             self.validation_data_ph = self.validation_data_shuffler("data", from_queue=True)
@@ -280,9 +286,9 @@ class Trainer(object):
 
             self.validation_graph = validation_graph
 
             if self.validate_with_embeddings:
                 self.validation_predictor = self.validation_graph
             else:
                 self.validation_predictor = self.loss(self.validation_graph, self.validation_label_ph)
 
             self.summaries_validation = self.create_general_summary(self.validation_predictor, self.validation_graph, self.validation_label_ph)
@@ -318,13 +324,13 @@ class Trainer(object):
             self.saver = tf.train.import_meta_graph(file_name, clear_devices=clear_devices)
             self.saver.restore(self.session, tf.train.latest_checkpoint(os.path.dirname(file_name)))
             
-    def load_variables_from_external_model(self, file_name, var_list):
+    def load_variables_from_external_model(self, checkpoint_path, var_list):
         """
         Load a set of variables from a given model and update them in the current one
         
         ** Parameters **
         
-          file_name:
+          checkpoint_path:
-            Name of the tensorflow model to be loaded
+            Path to the checkpoint directory from which the model will be loaded
           var_list:
-            List of variables to be loaded. A tensorflow exception will be raised in case the variable does not exists
+            List of variable scopes to be loaded. A tensorflow exception will be raised if a variable does not exist
@@ -338,7 +344,7 @@ class Trainer(object):
             tf_varlist += tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=v)
 
         saver = tf.train.Saver(tf_varlist)
-        saver.restore(self.session, file_name)
+        saver.restore(self.session, tf.train.latest_checkpoint(checkpoint_path))
 
     def create_network_from_file(self, file_name, clear_devices=True):
         """
@@ -406,8 +412,14 @@ class Trainer(object):
         """
 
         if self.train_data_shuffler.prefetch:
-            _, l, lr, summary = self.session.run([self.optimizer, self.predictor,
-                                                  self.learning_rate, self.summaries_train])
+            # TODO: SPECIFIC HACK FOR THE CENTER LOSS. I NEED TO FIND A CLEAN SOLUTION FOR THAT
+            if self.centers is None:
+                _, l, lr, summary = self.session.run([self.optimizer, self.predictor,
+                                                      self.learning_rate, self.summaries_train])
+            else:
+                # Fetching self.centers forces the tf.scatter_sub op to run,
+                # so the class centers are updated together with this step
+                _, l, lr, summary, _ = self.session.run([self.optimizer, self.predictor,
+                                                         self.learning_rate, self.summaries_train,
+                                                         self.centers])
+
         else:
             feed_dict = self.get_feed_dict(self.train_data_shuffler)
             _, l, lr, summary = self.session.run([self.optimizer, self.predictor,
@@ -473,10 +485,7 @@ class Trainer(object):
         tf.summary.scalar('lr', self.learning_rate)        
 
         # Computing accuracy
-        if isinstance(output, dict):
-            correct_prediction = tf.equal(tf.argmax(output['logits'], 1), label)
-        else:
-            correct_prediction = tf.equal(tf.argmax(output, 1), label)
+        correct_prediction = tf.equal(tf.argmax(output, 1), label)
         
         accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
         tf.summary.scalar('accuracy', accuracy)        
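
As an aside on the fit() special case above: it is needed because tf.scatter_sub only executes when its output is fetched. A more conventional alternative (a sketch of a different design, not what this patch does) ties the centers update to the train op with a control dependency, so a single session.run covers both:

    import tensorflow as tf
    from bob.learn.tensorflow.loss import MeanSoftMaxLossCenterLoss  # import path assumed

    # Stand-ins for what the Trainer wires up: `total_loss` is what it stores
    # as self.predictor and `centers_update` as self.centers.
    prelogits = tf.placeholder(tf.float32, [None, 128])
    labels = tf.placeholder(tf.int64, [None])
    logits = tf.layers.dense(prelogits, 10)
    total_loss, centers_update = MeanSoftMaxLossCenterLoss(n_classes=10)(logits, prelogits, labels)

    # With the dependency in place, session.run(train_op) updates the weights
    # and the class centers together, removing the branch on self.centers.
    with tf.control_dependencies([centers_update]):
        train_op = tf.train.GradientDescentOptimizer(0.1).minimize(total_loss)
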
-- 
GitLab