From ad6a9bba6c6243c7d5a88e0574a65b458ff06717 Mon Sep 17 00:00:00 2001
From: Tiago Freitas Pereira <tiagofrepereira@gmail.com>
Date: Wed, 11 Oct 2017 17:58:04 +0200
Subject: [PATCH] Replaced some classes with functions (#37). Still need to
 update some tests

---
 bob/learn/tensorflow/loss/BaseLoss.py         | 101 ++++++--------
 bob/learn/tensorflow/loss/ContrastiveLoss.py  |  45 +++----
 .../tensorflow/loss/TripletAverageLoss.py     |  60 ---------
 .../tensorflow/loss/TripletFisherLoss.py      |  62 ---------
 bob/learn/tensorflow/loss/TripletLoss.py      | 122 ++++++++++++++---
 bob/learn/tensorflow/loss/__init__.py         |  20 +--
 bob/learn/tensorflow/network/Chopra.py        | 123 +++++++-----------
 bob/learn/tensorflow/network/__init__.py      |   2 +-
 bob/learn/tensorflow/network/utils.py         |  11 +-
 .../tensorflow/script/lfw_db_to_tfrecords.py  |  10 +-
 bob/learn/tensorflow/script/train.py          |   8 +-
 bob/learn/tensorflow/test/test_cnn.py         |  79 ++++++-----
 .../tensorflow/test/test_cnn_other_losses.py  |  50 ++-----
 bob/learn/tensorflow/test/test_cnn_scratch.py |   2 +-
 .../tensorflow/trainers/SiameseTrainer.py     |  33 ++---
 bob/learn/tensorflow/trainers/Trainer.py      |  26 ++--
 .../tensorflow/trainers/TripletTrainer.py     |  30 ++---
 17 files changed, 315 insertions(+), 469 deletions(-)
 delete mode 100755 bob/learn/tensorflow/loss/TripletAverageLoss.py
 delete mode 100755 bob/learn/tensorflow/loss/TripletFisherLoss.py

diff --git a/bob/learn/tensorflow/loss/BaseLoss.py b/bob/learn/tensorflow/loss/BaseLoss.py
index 679dd12b..cbfeadfa 100755
--- a/bob/learn/tensorflow/loss/BaseLoss.py
+++ b/bob/learn/tensorflow/loss/BaseLoss.py
@@ -1,7 +1,6 @@
 #!/usr/bin/env python
 # vim: set fileencoding=utf-8 :
 # @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
-# @date: Tue 09 Aug 2016 16:38 CEST
 
 import logging
 import tensorflow as tf
@@ -10,92 +9,70 @@ logger = logging.getLogger("bob.learn.tensorflow")
 slim = tf.contrib.slim
 
 
-class BaseLoss(object):
-    """
-    Base loss function.
-    Stupid class. Don't know why I did that.
-    """
-
-    def __init__(self, loss, operation, name="loss"):
-        self.loss = loss
-        self.operation = operation
-        self.name = name
-
-    def __call__(self, graph, label):
-        return self.operation(self.loss(logits=graph, labels=label), name=self.name)
-        
-        
-class MeanSoftMaxLoss(object):
+def mean_cross_entropy_loss(logits, labels, add_regularization_losses=True):
     """
     Simple CrossEntropy loss.
     Basically it wraps the function tf.nn.sparse_softmax_cross_entropy_with_logits.
     
     **Parameters**
-    
-      name: Scope name
+      logits: Output of the network (pre-softmax activations)
+      labels: Ground-truth class labels
       add_regularization_losses: Whether to add the collected regularization losses to the total loss
     
     """
 
-    def __init__(self, name="loss", add_regularization_losses=True):
-        self.name = name
-        self.add_regularization_losses = add_regularization_losses
+    with tf.variable_scope('cross_entropy_loss'):
 
-    def __call__(self, graph, label):
-    
         loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
-                                          logits=graph, labels=label), name=self.name)
-    
-        if self.add_regularization_losses:
+                                          logits=logits, labels=labels), name=tf.GraphKeys.LOSSES)
+        
+        if add_regularization_losses:
             regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
-            return tf.add_n([loss] + regularization_losses, name='total_loss')
+            return tf.add_n([loss] + regularization_losses, name=tf.GraphKeys.LOSSES)
         else:
             return loss
             
-class MeanSoftMaxLossCenterLoss(object):
+def mean_cross_entropy_center_loss(logits, prelogits, labels, n_classes, alpha=0.9, factor=0.01):
     """
     Implementation of the CrossEntropy + Center Loss from the paper
     "A Discriminative Feature Learning Approach for Deep Face Recognition"(http://ydwen.github.io/papers/WenECCV16.pdf)
     
     **Parameters**
-
-      name: Scope name
+      logits: Output of the last (logits) layer, used in the cross-entropy term
+      prelogits: Output of the layer before the logits, used in the center-loss term
+      labels: Ground-truth class labels
+      n_classes: Number of classes of your task
       alpha: Alpha factor ((1-alpha)*centers-prelogits)
       factor: Weight factor of the center loss
-      n_classes: Number of classes of your task
+
     """
-    def __init__(self, name="loss", alpha=0.9, factor=0.01, n_classes=10):
-        self.name = name
+    # Cross entropy
+    with tf.variable_scope('cross_entropy_loss'):
+        loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
+                                          logits=logits, labels=labels), name=tf.GraphKeys.LOSSES)
 
-        self.n_classes = n_classes
-        self.alpha = alpha
-        self.factor = factor
+    # Appending center loss        
+    with tf.variable_scope('center_loss'):
+        n_features = prelogits.get_shape()[1]
+        
+        centers = tf.get_variable('centers', [n_classes, n_features], dtype=tf.float32,
+            initializer=tf.constant_initializer(0), trainable=False)
+            
+        labels = tf.reshape(labels, [-1])
+        centers_batch = tf.gather(centers, labels)
+        diff = (1 - alpha) * (centers_batch - prelogits)
+        centers = tf.scatter_sub(centers, labels, diff)
+        center_loss = tf.reduce_mean(tf.square(prelogits - centers_batch))       
+        tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, center_loss * factor)
 
+    # Adding the regularizers in the loss
+    with tf.variable_scope('total_loss'):
+        regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
+        total_loss =  tf.add_n([loss] + regularization_losses, name=tf.GraphKeys.LOSSES)
 
-    def __call__(self, logits, prelogits, label):           
-        # Cross entropy
-        with tf.variable_scope('cross_entropy_loss'):
-            loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
-                                              logits=logits, labels=label), name=self.name)
+    loss = dict()
+    loss['loss'] = total_loss
+    loss['centers'] = centers
 
-        # Appending center loss        
-        with tf.variable_scope('center_loss'):
-            n_features = prelogits.get_shape()[1]
-            
-            centers = tf.get_variable('centers', [self.n_classes, n_features], dtype=tf.float32,
-                initializer=tf.constant_initializer(0), trainable=False)
-                
-            label = tf.reshape(label, [-1])
-            centers_batch = tf.gather(centers, label)
-            diff = (1 - self.alpha) * (centers_batch - prelogits)
-            centers = tf.scatter_sub(centers, label, diff)
-            center_loss = tf.reduce_mean(tf.square(prelogits - centers_batch))       
-            tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, center_loss * self.factor)
-    
-        # Adding the regularizers in the loss
-        with tf.variable_scope('total_loss'):
-            regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
-            total_loss =  tf.add_n([loss] + regularization_losses, name='total_loss')
-            
-        return total_loss, centers
+    return loss
 
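For reference, a minimal sketch of how the two new functional losses are meant to be called; the placeholder shapes, variable names and the 10-class setup below are illustrative assumptions, not part of the patch:

    import tensorflow as tf
    from bob.learn.tensorflow.loss import mean_cross_entropy_loss, mean_cross_entropy_center_loss

    # Illustrative inputs (shapes are assumptions)
    logits = tf.placeholder(tf.float32, shape=(None, 10))      # output of the logits layer
    prelogits = tf.placeholder(tf.float32, shape=(None, 30))   # output of the layer before the logits
    labels = tf.placeholder(tf.int64, shape=(None,))           # ground-truth classes

    # Plain cross-entropy, optionally adding the collected regularization losses
    loss = mean_cross_entropy_loss(logits, labels, add_regularization_losses=True)

    # Cross-entropy + center loss: returns a dict with the total loss and the centers update op
    center = mean_cross_entropy_center_loss(logits, prelogits, labels,
                                            n_classes=10, alpha=0.9, factor=0.01)
    total_loss, centers = center['loss'], center['centers']
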
diff --git a/bob/learn/tensorflow/loss/ContrastiveLoss.py b/bob/learn/tensorflow/loss/ContrastiveLoss.py
index 4c25a981..1ec9ace5 100755
--- a/bob/learn/tensorflow/loss/ContrastiveLoss.py
+++ b/bob/learn/tensorflow/loss/ContrastiveLoss.py
@@ -1,17 +1,15 @@
 #!/usr/bin/env python
 # vim: set fileencoding=utf-8 :
 # @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
-# @date: Wed 10 Aug 2016 16:38 CEST
 
 import logging
 logger = logging.getLogger("bob.learn.tensorflow")
 import tensorflow as tf
 
-from .BaseLoss import BaseLoss
 from bob.learn.tensorflow.utils import compute_euclidean_distance
 
 
-class ContrastiveLoss(BaseLoss):
+def contrastive_loss(left_embedding, right_embedding, labels, contrastive_margin=1.0):
     """
     Compute the contrastive loss as in
 
@@ -27,7 +25,7 @@ class ContrastiveLoss(BaseLoss):
     right_feature:
       Second element of the pair
 
-    label:
+    labels:
       Label of the pair (0 or 1)
 
     margin:
@@ -35,30 +33,25 @@ class ContrastiveLoss(BaseLoss):
 
     """
 
-    def __init__(self, contrastive_margin=1.0):
-        self.contrastive_margin = contrastive_margin
+    with tf.name_scope("contrastive_loss"):
+        labels = tf.to_float(labels)
+        
+        left_embedding = tf.nn.l2_normalize(left_embedding, 1)
+        right_embedding  = tf.nn.l2_normalize(right_embedding, 1)
 
-    def __call__(self, label, left_feature, right_feature):
-        with tf.name_scope("contrastive_loss"):
-            label = tf.to_float(label)
-            
-            left_feature = tf.nn.l2_normalize(left_feature, 1)
-            right_feature  = tf.nn.l2_normalize(right_feature, 1)
+        one = tf.constant(1.0)
 
-            one = tf.constant(1.0)
+        d = compute_euclidean_distance(left_embedding, right_embedding)
+        within_class = tf.multiply(one - labels, tf.square(d))  # (1-Y)*(d^2)
+        
+        max_part = tf.square(tf.maximum(contrastive_margin - d, 0))
+        between_class = tf.multiply(labels, max_part)  # (Y) * max((margin - d)^2, 0)
 
-            d = compute_euclidean_distance(left_feature, right_feature)
-            within_class = tf.multiply(one - label, tf.square(d))  # (1-Y)*(d^2)
-            
-            
-            max_part = tf.square(tf.maximum(self.contrastive_margin - d, 0))
-            between_class = tf.multiply(label, max_part)  # (Y) * max((margin - d)^2, 0)
+        loss =  0.5 * (within_class + between_class)
 
-            loss = 0.5 * (within_class + between_class)
+        loss_dict = dict()
+        loss_dict['loss'] = tf.reduce_mean(loss, name=tf.GraphKeys.LOSSES)
+        loss_dict['between_class'] = tf.reduce_mean(between_class, name=tf.GraphKeys.LOSSES)
+        loss_dict['within_class'] = tf.reduce_mean(within_class, name=tf.GraphKeys.LOSSES)
 
-            loss_dict = dict()
-            loss_dict['loss'] = tf.reduce_mean(loss)
-            loss_dict['between_class'] = tf.reduce_mean(between_class)
-            loss_dict['within_class'] = tf.reduce_mean(within_class)
-
-            return loss_dict
+        return loss_dict
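
A minimal sketch of the new call pattern for contrastive_loss; the embedding size and margin are illustrative assumptions, and the updated Siamese test in this patch follows the same pattern:

    import tensorflow as tf
    from bob.learn.tensorflow.loss import contrastive_loss

    # Illustrative Siamese embeddings (the 250-d size is an assumption)
    left = tf.placeholder(tf.float32, shape=(None, 250))
    right = tf.placeholder(tf.float32, shape=(None, 250))
    labels = tf.placeholder(tf.int64, shape=(None,))  # 0 = genuine pair, 1 = impostor pair

    losses = contrastive_loss(left, right, labels, contrastive_margin=4.)
    # losses['loss'] is the value to minimize; 'between_class' and 'within_class' are for monitoring
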
diff --git a/bob/learn/tensorflow/loss/TripletAverageLoss.py b/bob/learn/tensorflow/loss/TripletAverageLoss.py
deleted file mode 100755
index bcb7bea8..00000000
--- a/bob/learn/tensorflow/loss/TripletAverageLoss.py
+++ /dev/null
@@ -1,60 +0,0 @@
-#!/usr/bin/env python
-# vim: set fileencoding=utf-8 :
-# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
-# @date: Wed 10 Aug 2016 16:38 CEST
-
-import logging
-logger = logging.getLogger("bob.learn.tensorflow")
-import tensorflow as tf
-
-from .BaseLoss import BaseLoss
-from bob.learn.tensorflow.utils import compute_euclidean_distance
-
-
-class TripletAverageLoss(BaseLoss):
-    """
-    Compute the triplet loss as in
-
-    Schroff, Florian, Dmitry Kalenichenko, and James Philbin.
-    "Facenet: A unified embedding for face recognition and clustering."
-    Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2015.
-
-    :math:`L  = sum(  |f_a - f_p|^2 - |f_a - f_n|^2  + \lambda)`
-
-    **Parameters**
-
-    left_feature:
-      First element of the pair
-
-    right_feature:
-      Second element of the pair
-
-    label:
-      Label of the pair (0 or 1)
-
-    margin:
-      Contrastive margin
-
-    """
-
-    def __init__(self, margin=0.1):
-        self.margin = margin
-
-    def __call__(self, anchor_embedding, positive_embedding, negative_embedding):
-
-        with tf.name_scope("triplet_loss"):
-            # Normalize
-            anchor_embedding = tf.nn.l2_normalize(anchor_embedding, 1, 1e-10, name="anchor")
-            positive_embedding = tf.nn.l2_normalize(positive_embedding, 1, 1e-10, name="positive")
-            negative_embedding = tf.nn.l2_normalize(negative_embedding, 1, 1e-10, name="negative")
-
-            anchor_mean = tf.reduce_mean(anchor_embedding, 0)
-
-            d_positive = tf.reduce_sum(tf.square(tf.subtract(anchor_mean, positive_embedding)), 1)
-            d_negative = tf.reduce_sum(tf.square(tf.subtract(anchor_mean, negative_embedding)), 1)
-
-            basic_loss = tf.add(tf.subtract(d_positive, d_negative), self.margin)
-            loss = tf.reduce_mean(tf.maximum(basic_loss, 0.0), 0)
-
-            return loss, tf.reduce_mean(d_negative), tf.reduce_mean(d_positive)
-
diff --git a/bob/learn/tensorflow/loss/TripletFisherLoss.py b/bob/learn/tensorflow/loss/TripletFisherLoss.py
deleted file mode 100755
index 54c0ad02..00000000
--- a/bob/learn/tensorflow/loss/TripletFisherLoss.py
+++ /dev/null
@@ -1,62 +0,0 @@
-#!/usr/bin/env python
-# vim: set fileencoding=utf-8 :
-# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
-# @date: Wed 10 Aug 2016 16:38 CEST
-
-import logging
-logger = logging.getLogger("bob.learn.tensorflow")
-import tensorflow as tf
-
-from .BaseLoss import BaseLoss
-from bob.learn.tensorflow.utils import compute_euclidean_distance
-
-
-class TripletFisherLoss(BaseLoss):
-    """
-    """
-
-    def __init__(self):
-        pass
-
-    def __call__(self, anchor_embedding, positive_embedding, negative_embedding):
-
-        with tf.name_scope("triplet_loss"):
-            # Normalize
-            anchor_embedding = tf.nn.l2_normalize(anchor_embedding, 1, 1e-10, name="anchor")
-            positive_embedding = tf.nn.l2_normalize(positive_embedding, 1, 1e-10, name="positive")
-            negative_embedding = tf.nn.l2_normalize(negative_embedding, 1, 1e-10, name="negative")
-
-            average_class = tf.reduce_mean(anchor_embedding, 0)
-            average_total = tf.div(tf.add(tf.reduce_mean(anchor_embedding, axis=0),\
-                            tf.reduce_mean(negative_embedding, axis=0)), 2)
-
-            length = anchor_embedding.get_shape().as_list()[0]
-            dim = anchor_embedding.get_shape().as_list()[1]
-            split_positive = tf.unstack(positive_embedding, num=length, axis=0)
-            split_negative = tf.unstack(negative_embedding, num=length, axis=0)
-
-            Sw = None
-            Sb = None
-            for s in zip(split_positive, split_negative):
-                positive = s[0]
-                negative = s[1]
-
-                buffer_sw = tf.reshape(tf.subtract(positive, average_class), shape=(dim, 1))
-                buffer_sw = tf.matmul(buffer_sw, tf.reshape(buffer_sw, shape=(1, dim)))
-
-                buffer_sb = tf.reshape(tf.subtract(negative, average_total), shape=(dim, 1))
-                buffer_sb = tf.matmul(buffer_sb, tf.reshape(buffer_sb, shape=(1, dim)))
-
-                if Sw is None:
-                    Sw = buffer_sw
-                    Sb = buffer_sb
-                else:
-                    Sw = tf.add(Sw, buffer_sw)
-                    Sb = tf.add(Sb, buffer_sb)
-
-            # Sw = tf.trace(Sw)
-            # Sb = tf.trace(Sb)
-            #loss = tf.trace(tf.div(Sb, Sw))
-            loss = tf.trace(tf.div(Sw, Sb))
-
-            return loss, tf.trace(Sb), tf.trace(Sw)
diff --git a/bob/learn/tensorflow/loss/TripletLoss.py b/bob/learn/tensorflow/loss/TripletLoss.py
index 4478a12d..c642507c 100755
--- a/bob/learn/tensorflow/loss/TripletLoss.py
+++ b/bob/learn/tensorflow/loss/TripletLoss.py
@@ -1,17 +1,15 @@
 #!/usr/bin/env python
 # vim: set fileencoding=utf-8 :
 # @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
-# @date: Wed 10 Aug 2016 16:38 CEST
 
 import logging
 logger = logging.getLogger("bob.learn.tensorflow")
 import tensorflow as tf
 
-from .BaseLoss import BaseLoss
 from bob.learn.tensorflow.utils import compute_euclidean_distance
 
 
-class TripletLoss(BaseLoss):
+def triplet_loss(anchor_embedding, positive_embedding, negative_embedding, margin=5.0):
     """
     Compute the triplet loss as in
 
@@ -37,26 +35,110 @@ class TripletLoss(BaseLoss):
 
     """
 
-    def __init__(self, margin=5.0):
-        self.margin = margin
+    with tf.name_scope("triplet_loss"):
+        # Normalize
+        anchor_embedding = tf.nn.l2_normalize(anchor_embedding, 1, 1e-10, name="anchor")
+        positive_embedding = tf.nn.l2_normalize(positive_embedding, 1, 1e-10, name="positive")
+        negative_embedding = tf.nn.l2_normalize(negative_embedding, 1, 1e-10, name="negative")
+
+        d_positive = tf.reduce_sum(tf.square(tf.subtract(anchor_embedding, positive_embedding)), 1)
+        d_negative = tf.reduce_sum(tf.square(tf.subtract(anchor_embedding, negative_embedding)), 1)
+
+        basic_loss = tf.add(tf.subtract(d_positive, d_negative), margin)
+        loss = tf.reduce_mean(tf.maximum(basic_loss, 0.0), 0, name=tf.GraphKeys.LOSSES)
+
+        loss_dict = dict()
+        loss_dict['loss'] = loss
+        loss_dict['between_class'] = tf.reduce_mean(d_negative)
+        loss_dict['within_class'] = tf.reduce_mean(d_positive)
+
+        return loss_dict
+        
+
+def triplet_fisher_loss(anchor_embedding, positive_embedding, negative_embedding):
+
+    with tf.name_scope("triplet_loss"):
+        # Normalize
+        anchor_embedding = tf.nn.l2_normalize(anchor_embedding, 1, 1e-10, name="anchor")
+        positive_embedding = tf.nn.l2_normalize(positive_embedding, 1, 1e-10, name="positive")
+        negative_embedding = tf.nn.l2_normalize(negative_embedding, 1, 1e-10, name="negative")
+
+        average_class = tf.reduce_mean(anchor_embedding, 0)
+        average_total = tf.div(tf.add(tf.reduce_mean(anchor_embedding, axis=0),\
+                        tf.reduce_mean(negative_embedding, axis=0)), 2)
+
+        length = anchor_embedding.get_shape().as_list()[0]
+        dim = anchor_embedding.get_shape().as_list()[1]
+        split_positive = tf.unstack(positive_embedding, num=length, axis=0)
+        split_negative = tf.unstack(negative_embedding, num=length, axis=0)
+
+        Sw = None
+        Sb = None
+        for s in zip(split_positive, split_negative):
+            positive = s[0]
+            negative = s[1]
+
+            buffer_sw = tf.reshape(tf.subtract(positive, average_class), shape=(dim, 1))
+            buffer_sw = tf.matmul(buffer_sw, tf.reshape(buffer_sw, shape=(1, dim)))
+
+            buffer_sb = tf.reshape(tf.subtract(negative, average_total), shape=(dim, 1))
+            buffer_sb = tf.matmul(buffer_sb, tf.reshape(buffer_sb, shape=(1, dim)))
+
+            if Sw is None:
+                Sw = buffer_sw
+                Sb = buffer_sb
+            else:
+                Sw = tf.add(Sw, buffer_sw)
+                Sb = tf.add(Sb, buffer_sb)
+
+        # Sw = tf.trace(Sw)
+        # Sb = tf.trace(Sb)
+        #loss = tf.trace(tf.div(Sb, Sw))
+        loss = tf.trace(tf.div(Sw, Sb), name=tf.GraphKeys.LOSSES)
+
+        return loss, tf.trace(Sb), tf.trace(Sw)
+        
+        
+def triplet_average_loss(anchor_embedding, positive_embedding, negative_embedding, margin=5.0):
+    """
+    Compute the triplet loss as in
+
+    Schroff, Florian, Dmitry Kalenichenko, and James Philbin.
+    "Facenet: A unified embedding for face recognition and clustering."
+    Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2015.
+
+    :math:`L  = sum(  |f_a - f_p|^2 - |f_a - f_n|^2  + \lambda)`
+
+    **Parameters**
+
+    anchor_embedding:
+      Embeddings of the anchor samples
+
+    positive_embedding:
+      Embeddings of the positive samples (same identity as the anchor)
+
+    negative_embedding:
+      Embeddings of the negative samples (different identity from the anchor)
+
+    margin:
+      Triplet margin
+
+    """
 
-    def __call__(self, anchor_embedding, positive_embedding, negative_embedding):
+    with tf.name_scope("triplet_loss"):
+        # Normalize
+        anchor_embedding = tf.nn.l2_normalize(anchor_embedding, 1, 1e-10, name="anchor")
+        positive_embedding = tf.nn.l2_normalize(positive_embedding, 1, 1e-10, name="positive")
+        negative_embedding = tf.nn.l2_normalize(negative_embedding, 1, 1e-10, name="negative")
 
-        with tf.name_scope("triplet_loss"):
-            # Normalize
-            anchor_embedding = tf.nn.l2_normalize(anchor_embedding, 1, 1e-10, name="anchor")
-            positive_embedding = tf.nn.l2_normalize(positive_embedding, 1, 1e-10, name="positive")
-            negative_embedding = tf.nn.l2_normalize(negative_embedding, 1, 1e-10, name="negative")
+        anchor_mean = tf.reduce_mean(anchor_embedding, 0)
 
-            d_positive = tf.reduce_sum(tf.square(tf.subtract(anchor_embedding, positive_embedding)), 1)
-            d_negative = tf.reduce_sum(tf.square(tf.subtract(anchor_embedding, negative_embedding)), 1)
+        d_positive = tf.reduce_sum(tf.square(tf.subtract(anchor_mean, positive_embedding)), 1)
+        d_negative = tf.reduce_sum(tf.square(tf.subtract(anchor_mean, negative_embedding)), 1)
 
-            basic_loss = tf.add(tf.subtract(d_positive, d_negative), self.margin)
-            loss = tf.reduce_mean(tf.maximum(basic_loss, 0.0), 0)
+        basic_loss = tf.add(tf.subtract(d_positive, d_negative), margin)
+        loss = tf.reduce_mean(tf.maximum(basic_loss, 0.0), 0, name=tf.GraphKeys.LOSSES)
 
-            loss_dict = dict()
-            loss_dict['loss'] = loss
-            loss_dict['between_class'] = tf.reduce_mean(d_negative)
-            loss_dict['within_class'] = tf.reduce_mean(d_positive)
+        return loss, tf.reduce_mean(d_negative), tf.reduce_mean(d_positive)        
 
-            return loss_dict
+        
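A minimal sketch of the functional triplet losses; the embedding size and margin values are illustrative assumptions:

    import tensorflow as tf
    from bob.learn.tensorflow.loss import triplet_loss, triplet_average_loss

    # Illustrative triplet embeddings (the 250-d size is an assumption)
    anchor = tf.placeholder(tf.float32, shape=(None, 250))
    positive = tf.placeholder(tf.float32, shape=(None, 250))
    negative = tf.placeholder(tf.float32, shape=(None, 250))

    losses = triplet_loss(anchor, positive, negative, margin=4.)
    # losses['loss'], losses['between_class'], losses['within_class']

    # triplet_average_loss returns a tuple instead of a dict
    avg_loss, d_negative, d_positive = triplet_average_loss(anchor, positive, negative, margin=0.1)
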
diff --git a/bob/learn/tensorflow/loss/__init__.py b/bob/learn/tensorflow/loss/__init__.py
index 19e58a34..1cf46711 100755
--- a/bob/learn/tensorflow/loss/__init__.py
+++ b/bob/learn/tensorflow/loss/__init__.py
@@ -1,9 +1,7 @@
-from .BaseLoss import BaseLoss, MeanSoftMaxLoss, MeanSoftMaxLossCenterLoss
-from .ContrastiveLoss import ContrastiveLoss
-from .TripletLoss import TripletLoss
-from .TripletAverageLoss import TripletAverageLoss
-from .TripletFisherLoss import TripletFisherLoss
-from .NegLogLoss import NegLogLoss
+from .BaseLoss import mean_cross_entropy_loss, mean_cross_entropy_center_loss
+from .ContrastiveLoss import contrastive_loss
+from .TripletLoss import triplet_loss, triplet_average_loss, triplet_fisher_loss
+#from .NegLogLoss import NegLogLoss
 
 
 # gets sphinx autodoc done right - don't remove it
@@ -21,13 +19,9 @@ def __appropriate__(*args):
   for obj in args: obj.__module__ = __name__
 
 __appropriate__(
-    BaseLoss,
-    ContrastiveLoss,
-    TripletLoss,
-    TripletFisherLoss,
-    TripletAverageLoss,
-    NegLogLoss,
-    MeanSoftMaxLoss
+    mean_cross_entropy_loss, mean_cross_entropy_center_loss,
+    contrastive_loss,
+    triplet_loss, triplet_average_loss, triplet_fisher_loss
     )
 __all__ = [_ for _ in dir() if not _.startswith('_')]
 
diff --git a/bob/learn/tensorflow/network/Chopra.py b/bob/learn/tensorflow/network/Chopra.py
index e8933ad0..2a4328f9 100755
--- a/bob/learn/tensorflow/network/Chopra.py
+++ b/bob/learn/tensorflow/network/Chopra.py
@@ -1,13 +1,22 @@
 #!/usr/bin/env python
 # vim: set fileencoding=utf-8 :
 # @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
-# @date: Wed 11 May 2016 09:39:36 CEST 
 
 import tensorflow as tf
-from .utils import append_logits
 
+def chopra(inputs, conv1_kernel_size=[7, 7],
+           conv1_output=15,
 
-class Chopra(object):
+           pooling1_size=[2, 2],
+
+
+           conv2_kernel_size=[6, 6],
+           conv2_output=45,
+
+           pooling2_size=[4, 3],
+           fc1_output=250,
+           seed=10,
+           reuse=False,):
     """Class that creates the architecture presented in the paper:
 
     Chopra, Sumit, Raia Hadsell, and Yann LeCun. "Learning a similarity metric discriminatively, with application to
@@ -49,79 +58,41 @@ class Chopra(object):
 
         fc1_output:
         
-        n_classes: If None, no Fully COnnected layer with class output will be created
-
         seed:
     """
-    def __init__(self,
-                 conv1_kernel_size=[7, 7],
-                 conv1_output=15,
-
-                 pooling1_size=[2, 2],
-
-
-                 conv2_kernel_size=[6, 6],
-                 conv2_output=45,
-
-                 pooling2_size=[4, 3],
-
-                 fc1_output=250,
-                 n_classes=None,
-                 seed=10):
-
-            self.conv1_kernel_size = conv1_kernel_size
-            self.conv1_output = conv1_output
-            self.pooling1_size = pooling1_size
-
-            self.conv2_output = conv2_output
-            self.conv2_kernel_size = conv2_kernel_size
-            self.pooling2_size = pooling2_size
-
-            self.fc1_output = fc1_output
-
-            self.seed = seed
-            self.n_classes = n_classes
-
-
-    def __call__(self, inputs, reuse=False, end_point='logits'):
-        slim = tf.contrib.slim
-
-        end_points = dict()
-        
-        initializer = tf.contrib.layers.xavier_initializer(uniform=False, dtype=tf.float32, seed=self.seed)
-
-        graph = slim.conv2d(inputs, self.conv1_output, self.conv1_kernel_size, activation_fn=tf.nn.relu,
-                            stride=1,
-                            weights_initializer=initializer,
-                            scope='conv1',
-                            reuse=reuse)
-        end_points['conv1'] = graph
-        
-        graph = slim.max_pool2d(graph, self.pooling1_size, scope='pool1')
-        end_points['pool1'] = graph
-
-        graph = slim.conv2d(graph, self.conv2_output, self.conv2_kernel_size, activation_fn=tf.nn.relu,
-                            stride=1,
-                            weights_initializer=initializer,
-                            scope='conv2', reuse=reuse)
-        end_points['conv2'] = graph
-        graph = slim.max_pool2d(graph, self.pooling2_size, scope='pool2')
-        end_points['pool2'] = graph        
-
-        graph = slim.flatten(graph, scope='flatten1')
-        end_points['flatten1'] = graph        
-
-        graph = slim.fully_connected(graph, self.fc1_output,
-                                     weights_initializer=initializer,
-                                     activation_fn=None,
-                                     scope='fc1',
-                                     reuse=reuse)
-        end_points['fc1'] = graph                                     
-                                     
-        if self.n_classes is not None:
-            # Appending the logits layer
-            graph = append_logits(graph, self.n_classes, reuse)
-            end_points['logits'] = graph
-        
-        return end_points[end_point]
+    slim = tf.contrib.slim
+
+    end_points = dict()
+    
+    initializer = tf.contrib.layers.xavier_initializer(uniform=False, dtype=tf.float32, seed=seed)
+
+    graph = slim.conv2d(inputs, conv1_output, conv1_kernel_size, activation_fn=tf.nn.relu,
+                        stride=1,
+                        weights_initializer=initializer,
+                        scope='conv1',
+                        reuse=reuse)
+    end_points['conv1'] = graph
+    
+    graph = slim.max_pool2d(graph, pooling1_size, scope='pool1')
+    end_points['pool1'] = graph
+
+    graph = slim.conv2d(graph, conv2_output, conv2_kernel_size, activation_fn=tf.nn.relu,
+                        stride=1,
+                        weights_initializer=initializer,
+                        scope='conv2', reuse=reuse)
+    end_points['conv2'] = graph
+    graph = slim.max_pool2d(graph, pooling2_size, scope='pool2')
+    end_points['pool2'] = graph        
+
+    graph = slim.flatten(graph, scope='flatten1')
+    end_points['flatten1'] = graph        
+
+    graph = slim.fully_connected(graph, fc1_output,
+                                 weights_initializer=initializer,
+                                 activation_fn=None,
+                                 scope='fc1',
+                                 reuse=reuse)
+    end_points['fc1'] = graph                                     
+    
+    return graph, end_points
 
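A minimal sketch of the new functional chopra interface, which now returns both the fc1 embedding and a dictionary of intermediate end points; the 28x28 grayscale input shape is an illustrative assumption:

    import tensorflow as tf
    from bob.learn.tensorflow.network import chopra

    inputs = tf.placeholder(tf.float32, shape=(None, 28, 28, 1))  # assumed input shape

    embedding, end_points = chopra(inputs, fc1_output=250, seed=10)
    fc1 = end_points['fc1']   # same tensor as `embedding`; conv/pool layers are also exposed
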
diff --git a/bob/learn/tensorflow/network/__init__.py b/bob/learn/tensorflow/network/__init__.py
index f0997937..68ed993e 100755
--- a/bob/learn/tensorflow/network/__init__.py
+++ b/bob/learn/tensorflow/network/__init__.py
@@ -1,4 +1,4 @@
-from .Chopra import Chopra
+from .Chopra import chopra
 from .LightCNN9 import LightCNN9
 from .LightCNN29 import LightCNN29
 from .Dummy import Dummy
diff --git a/bob/learn/tensorflow/network/utils.py b/bob/learn/tensorflow/network/utils.py
index 8ce0ed8b..ceb35a54 100755
--- a/bob/learn/tensorflow/network/utils.py
+++ b/bob/learn/tensorflow/network/utils.py
@@ -6,12 +6,9 @@ import tensorflow as tf
 slim = tf.contrib.slim
 
 
-def append_logits(graph, n_classes, reuse):
-    graph = slim.fully_connected(graph, n_classes, activation_fn=None, 
-               weights_initializer=tf.truncated_normal_initializer(stddev=0.1), 
-               weights_regularizer=slim.l2_regularizer(0.1),
+def append_logits(graph, n_classes, reuse=False, l2_regularizer=0.001, weights_std=0.1):
+    return slim.fully_connected(graph, n_classes, activation_fn=None, 
+               weights_initializer=tf.truncated_normal_initializer(stddev=weights_std), 
+               weights_regularizer=slim.l2_regularizer(l2_regularizer),
                scope='Logits', reuse=reuse)
 
-    return graph
-
-
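
Since chopra no longer takes n_classes, the classification head is now appended explicitly; a minimal sketch using the new append_logits defaults, with shapes and class count as illustrative assumptions:

    import tensorflow as tf
    from bob.learn.tensorflow.network import chopra
    from bob.learn.tensorflow.network.utils import append_logits

    inputs = tf.placeholder(tf.float32, shape=(None, 28, 28, 1))  # assumed input shape
    embedding, _ = chopra(inputs)
    logits = append_logits(embedding, n_classes=10, l2_regularizer=0.001, weights_std=0.1)
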
diff --git a/bob/learn/tensorflow/script/lfw_db_to_tfrecords.py b/bob/learn/tensorflow/script/lfw_db_to_tfrecords.py
index 9447bf62..7999b635 100755
--- a/bob/learn/tensorflow/script/lfw_db_to_tfrecords.py
+++ b/bob/learn/tensorflow/script/lfw_db_to_tfrecords.py
@@ -84,18 +84,14 @@ def main(argv=None):
 
     create_directories_safe(os.path.dirname(output_file))
 
-    import ipdb; ipdb.set_trace()
-
     n_files = len(enroll)
     with tf.python_io.TFRecordWriter(output_file) as writer:
       for e, p, i in zip(enroll, probe, range(len(enroll)) ):
         logger.info('Processing pair %d out of %d', i + 1, n_files)
-
-        e_path = e.make_path(data_path, extension)
-        p_path = p.make_path(data_path, extension)
         
-        if os.path.exists(p_path) and os.path.exists(e_path):        
-            for path in [e_path, p_path]:
+        if os.path.exists(e.make_path(data_path, extension)) and os.path.exists(p.make_path(data_path, extension)):
+            for f in [e, p]:
+                path = f.make_path(data_path, extension)
                 data = bob.io.image.to_matplotlib(bob.io.base.load(path)).astype(data_type)
                 data = data.tostring()
 
diff --git a/bob/learn/tensorflow/script/train.py b/bob/learn/tensorflow/script/train.py
index d177aca7..0c20262e 100755
--- a/bob/learn/tensorflow/script/train.py
+++ b/bob/learn/tensorflow/script/train.py
@@ -73,7 +73,7 @@ def main():
         return True
 
     config = imp.load_source('config', args['<configuration>'])
-
+    
     # Cleaning all variables in case you are loading the checkpoint
     tf.reset_default_graph() if os.path.exists(output_dir) else None
 
@@ -107,9 +107,9 @@ def main():
         train_graph = None
         validation_graph = None
         validate_with_embeddings = False
-        
-        if hasattr(config, 'train_graph'):
-            train_graph = config.train_graph
+
+        if hasattr(config, 'logits'):
+            train_graph = config.logits
             if hasattr(config, 'validation_graph'):
                 validation_graph = config.validation_graph
             
diff --git a/bob/learn/tensorflow/test/test_cnn.py b/bob/learn/tensorflow/test/test_cnn.py
index e2b15776..86dfdfd0 100755
--- a/bob/learn/tensorflow/test/test_cnn.py
+++ b/bob/learn/tensorflow/test/test_cnn.py
@@ -5,11 +5,13 @@
 
 import numpy
 from bob.learn.tensorflow.datashuffler import Memory, SiameseMemory, TripletMemory, ImageAugmentation, ScaleFactor, Linear
-from bob.learn.tensorflow.network import Chopra
-from bob.learn.tensorflow.loss import MeanSoftMaxLoss, ContrastiveLoss, TripletLoss
+from bob.learn.tensorflow.network import chopra
+from bob.learn.tensorflow.loss import mean_cross_entropy_loss, contrastive_loss, triplet_loss
 from bob.learn.tensorflow.trainers import Trainer, SiameseTrainer, TripletTrainer, constant
-from .test_cnn_scratch import validate_network
+from bob.learn.tensorflow.test.test_cnn_scratch import validate_network
 from bob.learn.tensorflow.network import Embedding, LightCNN9
+from bob.learn.tensorflow.network.utils import append_logits
+
 
 from bob.learn.tensorflow.utils import load_mnist
 import tensorflow as tf
@@ -92,15 +94,15 @@ def test_cnn_trainer():
 
     directory = "./temp/cnn"
 
+    # Preparing the graph
+    inputs = train_data_shuffler("data", from_queue=True)
+    labels = train_data_shuffler("label", from_queue=True)
+    logits = append_logits(chopra(inputs, seed=seed)[0], n_classes=10)
+    
     # Loss for the softmax
-    loss = MeanSoftMaxLoss()
-
-    # Preparing the architecture
-    architecture = Chopra(seed=seed, n_classes=10)
-    input_pl = train_data_shuffler("data", from_queue=True)
+    loss = mean_cross_entropy_loss(logits, labels)
     
-    graph = architecture(input_pl)
-    embedding = Embedding(train_data_shuffler("data", from_queue=False), graph)
+    embedding = Embedding(train_data_shuffler("data", from_queue=False), logits)
 
     # One graph trainer
     trainer = Trainer(train_data_shuffler,
@@ -108,7 +110,7 @@ def test_cnn_trainer():
                       analizer=None,
                       temp_dir=directory
                       )
-    trainer.create_network_from_scratch(graph=graph,
+    trainer.create_network_from_scratch(graph=logits,
                                         loss=loss,
                                         learning_rate=constant(0.01, name="regular_lr"),
                                         optimizer=tf.train.GradientDescentOptimizer(0.01),
@@ -122,7 +124,7 @@ def test_cnn_trainer():
     assert accuracy > 20.
     shutil.rmtree(directory)
     del trainer
-    del graph
+    del logits
     tf.reset_default_graph()
     assert len(tf.global_variables())==0
 
@@ -139,7 +141,6 @@ def test_lightcnn_trainer():
     validation_data = numpy.vstack((validation_data, numpy.random.normal(2, 0.2, size=(100, 128, 128, 1))))
     validation_labels = numpy.hstack((numpy.zeros(100), numpy.ones(100))).astype("uint64")
 
-
     # Creating datashufflers
     data_augmentation = ImageAugmentation()
     train_data_shuffler = Memory(train_data, train_labels,
@@ -150,15 +151,17 @@ def test_lightcnn_trainer():
 
     directory = "./temp/cnn"
 
-    # Loss for the softmax
-    loss = MeanSoftMaxLoss()
-
     # Preparing the architecture
     architecture = LightCNN9(seed=seed,
                              n_classes=2)
-    input_pl = train_data_shuffler("data", from_queue=True)
-    graph = architecture(input_pl, end_point="logits")
-    embedding = Embedding(train_data_shuffler("data", from_queue=False), graph)
+    inputs = train_data_shuffler("data", from_queue=True)
+    labels = train_data_shuffler("label", from_queue=True)
+    logits = architecture(inputs, end_point="logits")
+    embedding = Embedding(train_data_shuffler("data", from_queue=False), logits)
+    
+    # Loss for the softmax
+    loss = mean_cross_entropy_loss(logits, labels)
+
 
     # One graph trainer
     trainer = Trainer(train_data_shuffler,
@@ -166,7 +169,7 @@ def test_lightcnn_trainer():
                       analizer=None,
                       temp_dir=directory
                       )
-    trainer.create_network_from_scratch(graph=graph,
+    trainer.create_network_from_scratch(graph=logits,
                                         loss=loss,
                                         learning_rate=constant(0.001, name="regular_lr"),
                                         optimizer=tf.train.GradientDescentOptimizer(0.001),
@@ -179,7 +182,7 @@ def test_lightcnn_trainer():
     assert True
     shutil.rmtree(directory)
     del trainer
-    del graph
+    del logits
     tf.reset_default_graph()
     assert len(tf.global_variables())==0    
 
@@ -202,16 +205,15 @@ def test_siamesecnn_trainer():
                                              normalizer=ScaleFactor())
     directory = "./temp/siamesecnn"
 
-    # Preparing the architecture
-    architecture = Chopra(seed=seed)
+    # Building the graph
+    inputs = train_data_shuffler("data")
+    labels = train_data_shuffler("label")
+    graph = dict()
+    graph['left'] = chopra(inputs['left'])[0]
+    graph['right'] = chopra(inputs['right'], reuse=True)[0]
 
     # Loss for the Siamese
-    loss = ContrastiveLoss(contrastive_margin=4.)
-
-    input_pl = train_data_shuffler("data")
-    graph = dict()
-    graph['left'] = architecture(input_pl['left'], end_point="fc1")
-    graph['right'] = architecture(input_pl['right'], reuse=True, end_point="fc1")
+    loss = contrastive_loss(graph['left'], graph['right'], labels, contrastive_margin=4.)
 
     trainer = SiameseTrainer(train_data_shuffler,
                              iterations=iterations,
@@ -229,7 +231,6 @@ def test_siamesecnn_trainer():
     assert eer < 0.15
     shutil.rmtree(directory)
 
-    del architecture
     del trainer  # Just to clean tf.variables
     tf.reset_default_graph()
     assert len(tf.global_variables())==0    
@@ -254,17 +255,14 @@ def test_tripletcnn_trainer():
 
     directory = "./temp/tripletcnn"
 
-    # Preparing the architecture
-    architecture = Chopra(seed=seed, fc1_output=10)
-
-    # Loss for the Siamese
-    loss = TripletLoss(margin=4.)
-
-    input_pl = train_data_shuffler("data")
+    inputs = train_data_shuffler("data")
+    labels = train_data_shuffler("label")
     graph = dict()
-    graph['anchor'] = architecture(input_pl['anchor'], end_point="fc1")
-    graph['positive'] = architecture(input_pl['positive'], reuse=True, end_point="fc1")
-    graph['negative'] = architecture(input_pl['negative'], reuse=True, end_point="fc1")
+    graph['anchor'] = chopra(inputs['anchor'])[0]
+    graph['positive'] = chopra(inputs['positive'], reuse=True)[0]
+    graph['negative'] = chopra(inputs['negative'], reuse=True)[0]
+
+    loss = triplet_loss(graph['anchor'], graph['positive'], graph['negative'])
 
     # One graph trainer
     trainer = TripletTrainer(train_data_shuffler,
@@ -283,7 +281,6 @@ def test_tripletcnn_trainer():
     assert eer < 0.15
     shutil.rmtree(directory)
 
-    del architecture
     del trainer  # Just to clean tf.variables
     tf.reset_default_graph()
     assert len(tf.global_variables())==0    
diff --git a/bob/learn/tensorflow/test/test_cnn_other_losses.py b/bob/learn/tensorflow/test/test_cnn_other_losses.py
index bc5f3ae7..dfcbc34e 100755
--- a/bob/learn/tensorflow/test/test_cnn_other_losses.py
+++ b/bob/learn/tensorflow/test/test_cnn_other_losses.py
@@ -5,9 +5,10 @@
 
 import numpy
 from bob.learn.tensorflow.datashuffler import TFRecord
-from bob.learn.tensorflow.loss import MeanSoftMaxLossCenterLoss, MeanSoftMaxLoss
+from bob.learn.tensorflow.loss import mean_cross_entropy_loss, mean_cross_entropy_center_loss
 from bob.learn.tensorflow.trainers import Trainer, constant
 from bob.learn.tensorflow.utils import load_mnist
+from bob.learn.tensorflow.network.utils import append_logits
 import tensorflow as tf
 import shutil
 import os
@@ -25,7 +26,7 @@ directory = "./temp/cnn_scratch"
 slim = tf.contrib.slim
 
 
-def scratch_network_embeding_example(train_data_shuffler, reuse=False, get_embedding=False):
+def scratch_network_embeding_example(train_data_shuffler, reuse=False):
 
     if isinstance(train_data_shuffler, tf.Tensor):
         inputs = train_data_shuffler
@@ -41,19 +42,7 @@ def scratch_network_embeding_example(train_data_shuffler, reuse=False, get_embed
     prelogits = slim.fully_connected(graph, 30, activation_fn=None, scope='fc1',
                                  weights_initializer=initializer, reuse=reuse)
 
-    if get_embedding:
-        embedding = tf.nn.l2_normalize(prelogits, dim=1, name="embedding")
-        return embedding, None
-    else:
-        logits = slim.fully_connected(prelogits, 10, activation_fn=None, scope='logits',
-                                     weights_initializer=initializer, reuse=reuse)
-    
-    #logits_prelogits = dict()
-    #logits_prelogits['logits'] = logits
-    #logits_prelogits['prelogits'] = prelogits
-   
-    return logits, prelogits
-
+    return prelogits
 
 def test_center_loss_tfrecord_embedding_validation():
     tf.reset_default_graph()
@@ -95,6 +84,7 @@ def test_center_loss_tfrecord_embedding_validation():
     create_tf_record(tfrecords_filename_val, validation_data, validation_labels)   
     filename_queue_val = tf.train.string_input_producer([tfrecords_filename_val], num_epochs=55, name="input_validation")
 
+
     # Creating the CNN using the TFRecord as input
     train_data_shuffler  = TFRecord(filename_queue=filename_queue,
                                     batch_size=batch_size)
@@ -102,12 +92,15 @@ def test_center_loss_tfrecord_embedding_validation():
     validation_data_shuffler  = TFRecord(filename_queue=filename_queue_val,
                                          batch_size=2000)
                                          
-    graph, prelogits = scratch_network_embeding_example(train_data_shuffler)
-    validation_graph,_ = scratch_network_embeding_example(validation_data_shuffler, reuse=True, get_embedding=True)
+    prelogits = scratch_network_embeding_example(train_data_shuffler)
+    logits = append_logits(prelogits, n_classes=10)
+    validation_graph = tf.nn.l2_normalize(scratch_network_embeding_example(validation_data_shuffler, reuse=True), 1)
+
+    labels = train_data_shuffler("label", from_queue=False)
     
     # Setting the placeholders
     # Loss for the softmax
-    loss = MeanSoftMaxLossCenterLoss(n_classes=10, factor=0.1)
+    loss =  mean_cross_entropy_center_loss(logits, prelogits, labels, n_classes=10, factor=0.1)
 
     # One graph trainer
     trainer = Trainer(train_data_shuffler,
@@ -119,14 +112,13 @@ def test_center_loss_tfrecord_embedding_validation():
 
     learning_rate = constant(0.01, name="regular_lr")
 
-    trainer.create_network_from_scratch(graph=graph,
+    trainer.create_network_from_scratch(graph=logits,
                                         validation_graph=validation_graph,
                                         loss=loss,
                                         learning_rate=learning_rate,
                                         optimizer=tf.train.GradientDescentOptimizer(learning_rate),
                                         prelogits=prelogits
                                         )
-
     trainer.train()
 
     assert True
@@ -155,26 +147,8 @@ def test_center_loss_tfrecord_embedding_validation():
                       temp_dir=directory)
 
     trainer.create_network_from_file(directory)
-    
-    import ipdb; ipdb.set_trace();
-
     trainer.train()
     
-    """
-    
-    # Inference. TODO: Wrap this in a package
-    file_name = os.path.join(directory, "model.ckp.meta")
-    images = tf.placeholder(tf.float32, shape=(None, 28, 28, 1))
-    graph ,_ = scratch_network_embeding_example(images, reuse=False)
-
-    session = tf.Session()
-    session.run(tf.global_variables_initializer())
-    saver = tf.train.import_meta_graph(file_name, clear_devices=True)
-    saver.restore(session, tf.train.latest_checkpoint(os.path.dirname("./temp/cnn_scratch/")))
-    data = numpy.random.rand(2, 28, 28, 1).astype("float32")
-    assert session.run(graph, feed_dict={images: data}).shape == (2, 10)
-    """
-
     os.remove(tfrecords_filename)
     os.remove(tfrecords_filename_val)    
 
diff --git a/bob/learn/tensorflow/test/test_cnn_scratch.py b/bob/learn/tensorflow/test/test_cnn_scratch.py
index 836689f7..04b3d6f3 100755
--- a/bob/learn/tensorflow/test/test_cnn_scratch.py
+++ b/bob/learn/tensorflow/test/test_cnn_scratch.py
@@ -6,7 +6,7 @@
 import numpy
 from bob.learn.tensorflow.datashuffler import Memory, ImageAugmentation, ScaleFactor, Linear, TFRecord
 from bob.learn.tensorflow.network import Embedding
-from bob.learn.tensorflow.loss import BaseLoss, MeanSoftMaxLoss
+from bob.learn.tensorflow.loss import mean_cross_entropy_loss, contrastive_loss, triplet_loss
 from bob.learn.tensorflow.trainers import Trainer, constant
 from bob.learn.tensorflow.utils import load_mnist
 import tensorflow as tf
diff --git a/bob/learn/tensorflow/trainers/SiameseTrainer.py b/bob/learn/tensorflow/trainers/SiameseTrainer.py
index cde9bb5a..9beba407 100755
--- a/bob/learn/tensorflow/trainers/SiameseTrainer.py
+++ b/bob/learn/tensorflow/trainers/SiameseTrainer.py
@@ -97,8 +97,6 @@ class SiameseTrainer(Trainer):
         self.validation_graph = None
                 
         self.loss = None
-        
-        self.predictor = None
         self.validation_predictor = None        
         
         self.optimizer_class = None
@@ -139,9 +137,6 @@ class SiameseTrainer(Trainer):
             raise ValueError("`graph` should be a dictionary with two elements (`left`and `right`)")
 
         self.loss = loss
-        self.predictor = self.loss(self.label_ph,
-                                   self.graph["left"],
-                                   self.graph["right"])
         self.optimizer_class = optimizer
         self.learning_rate = learning_rate
 
@@ -156,9 +151,9 @@ class SiameseTrainer(Trainer):
         tf.add_to_collection("graph_right", self.graph['right'])
 
         # Saving pointers to the loss
-        tf.add_to_collection("predictor_loss", self.predictor['loss'])
-        tf.add_to_collection("predictor_between_class_loss", self.predictor['between_class'])
-        tf.add_to_collection("predictor_within_class_loss", self.predictor['within_class'])
+        tf.add_to_collection("loss", self.loss['loss'])
+        tf.add_to_collection("between_class_loss", self.loss['between_class'])
+        tf.add_to_collection("within_class_loss", self.loss['within_class'])
 
         # Saving the pointers to the placeholders
         tf.add_to_collection("data_ph_left", self.data_ph['left'])
@@ -167,7 +162,7 @@ class SiameseTrainer(Trainer):
 
         # Preparing the optimizer
         self.optimizer_class._learning_rate = self.learning_rate
-        self.optimizer = self.optimizer_class.minimize(self.predictor['loss'], global_step=self.global_step)
+        self.optimizer = self.optimizer_class.minimize(self.loss['loss'], global_step=self.global_step)
         tf.add_to_collection("optimizer", self.optimizer)
         tf.add_to_collection("learning_rate", self.learning_rate)
 
@@ -196,10 +191,10 @@ class SiameseTrainer(Trainer):
         self.label_ph = tf.get_collection("label_ph")[0]
 
         # Loading loss from the pointers
-        self.predictor = dict()
-        self.predictor['loss'] = tf.get_collection("predictor_loss")[0]
-        self.predictor['between_class'] = tf.get_collection("predictor_between_class_loss")[0]
-        self.predictor['within_class'] = tf.get_collection("predictor_within_class_loss")[0]
+        self.loss = dict()
+        self.loss['loss'] = tf.get_collection("loss")[0]
+        self.loss['between_class'] = tf.get_collection("between_class_loss")[0]
+        self.loss['within_class'] = tf.get_collection("within_class_loss")[0]
 
         # Loading other elements
         self.optimizer = tf.get_collection("optimizer")[0]
@@ -223,8 +218,8 @@ class SiameseTrainer(Trainer):
                 
         _, l, bt_class, wt_class, lr, summary = self.session.run([
                                                 self.optimizer,
-                                                self.predictor['loss'], self.predictor['between_class'],
-                                                self.predictor['within_class'],
+                                                self.loss['loss'], self.loss['between_class'],
+                                                self.loss['within_class'],
                                                 self.learning_rate, self.summaries_train], feed_dict=feed_dict)
 
         logger.info("Loss training set step={0} = {1}".format(step, l))
@@ -238,9 +233,9 @@ class SiameseTrainer(Trainer):
             tf.summary.histogram(var.op.name, var)
 
         # Train summary
-        tf.summary.scalar('loss', self.predictor['loss'])
-        tf.summary.scalar('between_class_loss', self.predictor['between_class'])
-        tf.summary.scalar('within_class_loss', self.predictor['within_class'])
+        tf.summary.scalar('loss', self.loss['loss'])
+        tf.summary.scalar('between_class_loss', self.loss['between_class'])
+        tf.summary.scalar('within_class_loss', self.loss['within_class'])
         tf.summary.scalar('lr', self.learning_rate)
         return tf.summary.merge_all()
 
@@ -257,7 +252,7 @@ class SiameseTrainer(Trainer):
         # Opening a new session for validation
         feed_dict = self.get_feed_dict(data_shuffler)
 
-        l, summary = self.session.run([self.predictor, self.summaries_validation], feed_dict=feed_dict)
+        l, summary = self.session.run([self.loss, self.summaries_validation], feed_dict=feed_dict)
         self.validation_summary_writter.add_summary(summary, step)
 
         #summaries = [summary_pb2.Summary.Value(tag="loss", simple_value=float(l))]
diff --git a/bob/learn/tensorflow/trainers/Trainer.py b/bob/learn/tensorflow/trainers/Trainer.py
index 03ac1670..25c660e8 100755
--- a/bob/learn/tensorflow/trainers/Trainer.py
+++ b/bob/learn/tensorflow/trainers/Trainer.py
@@ -103,7 +103,6 @@ class Trainer(object):
                 
         self.loss = None
         
-        self.predictor = None
         self.validation_predictor = None  
         self.validate_with_embeddings = validate_with_embeddings      
         
@@ -242,35 +241,32 @@ class Trainer(object):
         # TODO: SPECIFIC HACK FOR THE CENTER LOSS. I NEED TO FIND A CLEAN SOLUTION FOR THAT
         self.centers = None
         if prelogits is not None:
-            self.predictor, self.centers = self.loss(self.graph, prelogits, self.label_ph)
+            self.loss = loss['loss']
+            self.centers = loss['centers']
             tf.add_to_collection("centers", self.centers)
+            tf.add_to_collection("loss", self.loss)
             tf.add_to_collection("prelogits", prelogits)
             self.prelogits = prelogits
-        else:
-            self.predictor = self.loss(self.graph, self.label_ph)
-        
+
         self.optimizer_class = optimizer
         self.learning_rate = learning_rate
         self.global_step = tf.contrib.framework.get_or_create_global_step()
 
         # Preparing the optimizer
         self.optimizer_class._learning_rate = self.learning_rate
-        self.optimizer = self.optimizer_class.minimize(self.predictor, global_step=self.global_step)
+        self.optimizer = self.optimizer_class.minimize(self.loss, global_step=self.global_step)
 
         # Saving all the variables
         self.saver = tf.train.Saver(var_list=tf.global_variables() + tf.local_variables(), 
                                     keep_checkpoint_every_n_hours=self.keep_checkpoint_every_n_hours)
 
-        self.summaries_train = self.create_general_summary(self.predictor, self.graph, self.label_ph)
+        self.summaries_train = self.create_general_summary(self.loss, self.graph, self.label_ph)
 
         # SAving some variables
         tf.add_to_collection("global_step", self.global_step)
 
-            
+        tf.add_to_collection("loss", self.loss)
         tf.add_to_collection("graph", self.graph)
-        
-        tf.add_to_collection("predictor", self.predictor)
-
         tf.add_to_collection("data_ph", self.data_ph)
         tf.add_to_collection("label_ph", self.label_ph)
 
@@ -363,7 +359,7 @@ class Trainer(object):
         self.label_ph = tf.get_collection("label_ph")[0]
 
         self.graph = tf.get_collection("graph")[0]
-        self.predictor = tf.get_collection("predictor")[0]
+        self.loss = tf.get_collection("loss")[0]
 
         # Loding other elements
         self.optimizer = tf.get_collection("optimizer")[0]
@@ -418,15 +414,15 @@ class Trainer(object):
         if self.train_data_shuffler.prefetch:
             # TODO: SPECIFIC HACK FOR THE CENTER LOSS. I NEED TO FIND A CLEAN SOLUTION FOR THAT        
             if self.centers is None:            
-                _, l, lr, summary = self.session.run([self.optimizer, self.predictor,
+                _, l, lr, summary = self.session.run([self.optimizer, self.loss,
                                                       self.learning_rate, self.summaries_train])
             else:
-                _, l, lr, summary, _ = self.session.run([self.optimizer, self.predictor,
+                _, l, lr, summary, _ = self.session.run([self.optimizer, self.loss,
                                                       self.learning_rate, self.summaries_train, self.centers])
             
         else:
             feed_dict = self.get_feed_dict(self.train_data_shuffler)
-            _, l, lr, summary = self.session.run([self.optimizer, self.predictor,
+            _, l, lr, summary = self.session.run([self.optimizer, self.loss,
                                                   self.learning_rate, self.summaries_train], feed_dict=feed_dict)
 
         logger.info("Loss training set step={0} = {1}".format(step, l))
diff --git a/bob/learn/tensorflow/trainers/TripletTrainer.py b/bob/learn/tensorflow/trainers/TripletTrainer.py
index a8a782ba..d52d3802 100755
--- a/bob/learn/tensorflow/trainers/TripletTrainer.py
+++ b/bob/learn/tensorflow/trainers/TripletTrainer.py
@@ -99,7 +99,6 @@ class TripletTrainer(Trainer):
                 
         self.loss = None
         
-        self.predictor = None
         self.validation_predictor = None        
         
         self.optimizer_class = None
@@ -139,9 +138,6 @@ class TripletTrainer(Trainer):
             raise ValueError("`graph` should be a dictionary with two elements (`anchor`, `positive` and `negative`)")
 
         self.loss = loss
-        self.predictor = self.loss(self.graph["anchor"],
-                                   self.graph["positive"],
-                                   self.graph["negative"])
         self.optimizer_class = optimizer
         self.learning_rate = learning_rate
 
@@ -158,9 +154,9 @@ class TripletTrainer(Trainer):
         tf.add_to_collection("graph_negative", self.graph['negative'])
 
         # Saving pointers to the loss
-        tf.add_to_collection("predictor_loss", self.predictor['loss'])
-        tf.add_to_collection("predictor_between_class_loss", self.predictor['between_class'])
-        tf.add_to_collection("predictor_within_class_loss", self.predictor['within_class'])
+        tf.add_to_collection("loss", self.loss['loss'])
+        tf.add_to_collection("between_class_loss", self.loss['between_class'])
+        tf.add_to_collection("within_class_loss", self.loss['within_class'])
 
         # Saving the pointers to the placeholders
         tf.add_to_collection("data_ph_anchor", self.data_ph['anchor'])
@@ -169,7 +165,7 @@ class TripletTrainer(Trainer):
 
         # Preparing the optimizer
         self.optimizer_class._learning_rate = self.learning_rate
-        self.optimizer = self.optimizer_class.minimize(self.predictor['loss'], global_step=self.global_step)
+        self.optimizer = self.optimizer_class.minimize(self.loss['loss'], global_step=self.global_step)
         tf.add_to_collection("optimizer", self.optimizer)
         tf.add_to_collection("learning_rate", self.learning_rate)
 
@@ -196,10 +192,10 @@ class TripletTrainer(Trainer):
         self.data_ph['negative'] = tf.get_collection("data_ph_negative")[0]
 
         # Loading loss from the pointers
-        self.predictor = dict()
-        self.predictor['loss'] = tf.get_collection("predictor_loss")[0]
-        self.predictor['between_class'] = tf.get_collection("predictor_between_class_loss")[0]
-        self.predictor['within_class'] = tf.get_collection("predictor_within_class_loss")[0]
+        self.loss = dict()
+        self.loss['loss'] = tf.get_collection("loss")[0]
+        self.loss['between_class'] = tf.get_collection("between_class_loss")[0]
+        self.loss['within_class'] = tf.get_collection("within_class_loss")[0]
 
         # Loading other elements
         self.optimizer = tf.get_collection("optimizer")[0]
@@ -221,8 +217,8 @@ class TripletTrainer(Trainer):
         feed_dict = self.get_feed_dict(self.train_data_shuffler)
         _, l, bt_class, wt_class, lr, summary = self.session.run([
                                                 self.optimizer,
-                                                self.predictor['loss'], self.predictor['between_class'],
-                                                self.predictor['within_class'],
+                                                self.loss['loss'], self.loss['between_class'],
+                                                self.loss['within_class'],
                                                 self.learning_rate, self.summaries_train], feed_dict=feed_dict)
 
         logger.info("Loss training set step={0} = {1}".format(step, l))
@@ -231,9 +227,9 @@ class TripletTrainer(Trainer):
     def create_general_summary(self):
 
         # Train summary
-        tf.summary.scalar('loss', self.predictor['loss'])
-        tf.summary.scalar('between_class_loss', self.predictor['between_class'])
-        tf.summary.scalar('within_class_loss', self.predictor['within_class'])
+        tf.summary.scalar('loss', self.loss['loss'])
+        tf.summary.scalar('between_class_loss', self.loss['between_class'])
+        tf.summary.scalar('within_class_loss', self.loss['within_class'])
         tf.summary.scalar('lr', self.learning_rate)
         return tf.summary.merge_all()
 
-- 
GitLab