diff --git a/bob/learn/tensorflow/loss/BaseLoss.py b/bob/learn/tensorflow/loss/BaseLoss.py
index cbfeadfa6802b26d26c6fb0289b459b0949006c5..c088fa77a55ad16b8bc1328bb2e813d00a48936c 100755
--- a/bob/learn/tensorflow/loss/BaseLoss.py
+++ b/bob/learn/tensorflow/loss/BaseLoss.py
@@ -32,6 +32,7 @@ def mean_cross_entropy_loss(logits, labels, add_regularization_losses=True):
         else:
             return loss
             
+
 def mean_cross_entropy_center_loss(logits, prelogits, labels, n_classes, alpha=0.9, factor=0.01):
     """
     Implementation of the CrossEntropy + Center Loss from the paper
@@ -58,7 +59,7 @@ def mean_cross_entropy_center_loss(logits, prelogits, labels, n_classes, alpha=0
         centers = tf.get_variable('centers', [n_classes, n_features], dtype=tf.float32,
             initializer=tf.constant_initializer(0), trainable=False)
             
-        label = tf.reshape(labels, [-1])
+        #label = tf.reshape(labels, [-1])
         centers_batch = tf.gather(centers, labels)
         diff = (1 - alpha) * (centers_batch - prelogits)
         centers = tf.scatter_sub(centers, labels, diff)
@@ -68,7 +69,7 @@ def mean_cross_entropy_center_loss(logits, prelogits, labels, n_classes, alpha=0
     # Adding the regularizers in the loss
     with tf.variable_scope('total_loss'):
         regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
-        total_loss =  tf.add_n([loss] + regularization_losses, name=tf.GraphKeys.LOSSES)
+        total_loss = tf.add_n([loss] + regularization_losses, name=tf.GraphKeys.LOSSES)
 
     loss = dict()
     loss['loss'] = total_loss
diff --git a/bob/learn/tensorflow/test/test_onegraph_model_fn.py b/bob/learn/tensorflow/test/test_onegraph_model_fn.py
index fadf49dc9725ce24956f8e4d31c2577922b3ec2e..2b73e88a5035bc784777ac7787b2240ddf4387e0 100755
--- a/bob/learn/tensorflow/test/test_onegraph_model_fn.py
+++ b/bob/learn/tensorflow/test/test_onegraph_model_fn.py
@@ -2,23 +2,15 @@
 # vim: set fileencoding=utf-8 :
 # @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
 
-import numpy
-from bob.learn.tensorflow.datashuffler import Memory, SiameseMemory, TripletMemory, scale_factor
-from bob.learn.tensorflow.network import chopra
-from bob.learn.tensorflow.loss import mean_cross_entropy_loss, contrastive_loss, triplet_loss
-from bob.learn.tensorflow.test.test_cnn_scratch import validate_network
-from bob.learn.tensorflow.network import dummy
-from bob.learn.tensorflow.network.utils import append_logits
-
-
 import tensorflow as tf
 
-
-from bob.learn.tensorflow.trainers import LogitsTrainer
+from bob.learn.tensorflow.network import dummy
+from bob.learn.tensorflow.trainers import LogitsTrainer, LogitsCenterLossTrainer
 from bob.learn.tensorflow.utils.tfrecords import shuffle_data_and_labels, batch_data_and_labels
 from bob.learn.tensorflow.utils import load_mnist, create_mnist_tfrecord
 from bob.learn.tensorflow.utils.hooks import LoggerHookEstimator
-
+from bob.learn.tensorflow.loss import mean_cross_entropy_loss
+import numpy
 
 import shutil
 import os
@@ -37,16 +29,25 @@ epochs = 1
 steps = 2000
 
 
-def test_cnn_trainer():
-    run_cnn(False)
+def test_logitstrainer():
+    run_logitstrainer(False)
+
+
+def test_logitstrainer_embedding():
+    run_logitstrainer(True)
+
 
+def test_logitstrainer_centerloss():
+    run_logitstrainer_centerloss(False)
 
-def test_cnn_trainer_embedding():
-    run_cnn(True)
 
+def test_logitstrainer_centerloss_embedding():
+    run_logitstrainer_centerloss(True)
 
-def run_cnn(embedding_validation):
 
+def run_logitstrainer(embedding_validation):
+
+    # Cleaning up
     tf.reset_default_graph()
     assert len(tf.global_variables()) == 0
 
@@ -79,6 +80,7 @@ def run_cnn(embedding_validation):
         trainer.train(input_fn, steps=steps, hooks=hooks)
         
         if not embedding_validation:
+
             acc = trainer.evaluate(input_fn_validation)
             assert acc['accuracy'] > 0.80
         else:
@@ -91,10 +93,70 @@ def run_cnn(embedding_validation):
             os.unlink(tfrecord_validation)            
             shutil.rmtree(model_dir)
         except Exception:
-            pass        
+            pass
+
+    # Cleaning up
+    tf.reset_default_graph()
+    assert len(tf.global_variables()) == 0
+
+
+def run_logitstrainer_centerloss(embedding_validation):
+
+    # Cleaning up
+    tf.reset_default_graph()
+    assert len(tf.global_variables()) == 0
+
+    # Creating tf records for mnist
+    train_data, train_labels, validation_data, validation_labels = load_mnist()
+    create_mnist_tfrecord(tfrecord_train, train_data, train_labels, n_samples=6000)
+    create_mnist_tfrecord(tfrecord_validation, validation_data, validation_labels, n_samples=1000)
+
+    try:
 
+        # Trainer logits
+        trainer = LogitsCenterLossTrainer(
+                                model_dir=model_dir,
+                                architecture=dummy,
+                                optimizer=tf.train.GradientDescentOptimizer(learning_rate),
+                                n_classes=10,
+                                embedding_validation=embedding_validation,
+                                validation_batch_size=validation_batch_size,
+                                factor=0.01
+                                )
+
+        def input_fn():
+            return shuffle_data_and_labels(tfrecord_train, data_shape, data_type,
+                                           batch_size, epochs=epochs)
 
+        def input_fn_validation():
+            return batch_data_and_labels(tfrecord_validation, data_shape, data_type,
+                                         validation_batch_size, epochs=epochs)
+
+        hooks = [LoggerHookEstimator(trainer, 16, 100)]
+        trainer.train(input_fn, steps=steps, hooks=hooks)
 
+        if not embedding_validation:
+            acc = trainer.evaluate(input_fn_validation)
+            assert acc['accuracy'] > 0.80
+        else:
+            acc = trainer.evaluate(input_fn_validation)
+            assert acc['accuracy'] > 0.80
 
+        sess = tf.Session()
+        checkpoint_path = tf.train.get_checkpoint_state(model_dir).model_checkpoint_path
+        saver = tf.train.import_meta_graph(checkpoint_path + ".meta", clear_devices=True)
+        saver.restore(sess, tf.train.latest_checkpoint(model_dir))
+        centers = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="center_loss/centers:0")[0]
+        assert numpy.sum(numpy.abs(centers.eval(sess))) > 0.0
 
+    finally:
+        try:
+            os.unlink(tfrecord_train)
+            os.unlink(tfrecord_validation)
+            shutil.rmtree(model_dir)
+        except Exception:
+            pass
 
+    # Cleaning up
+    tf.reset_default_graph()
+    assert len(tf.global_variables()) == 0
diff --git a/bob/learn/tensorflow/trainers/LogitsTrainer.py b/bob/learn/tensorflow/trainers/LogitsTrainer.py
index 7204396701590e98779e61d0f88b217641c8cf62..a7d12b0145add029fdd2836fd98998a8cbea784a 100755
--- a/bob/learn/tensorflow/trainers/LogitsTrainer.py
+++ b/bob/learn/tensorflow/trainers/LogitsTrainer.py
@@ -14,8 +14,8 @@ import time
 #logger = bob.core.log.setup("bob.learn.tensorflow")
 from bob.learn.tensorflow.network.utils import append_logits
 from tensorflow.python.estimator import estimator
-from bob.learn.tensorflow.utils import reproducible
 from bob.learn.tensorflow.utils import predict_using_tensors
+from bob.learn.tensorflow.loss import mean_cross_entropy_center_loss
 
 
 import logging
@@ -24,24 +24,41 @@ logger = logging.getLogger("bob.learn")
 
 class LogitsTrainer(estimator.Estimator):
     """
-    Logits .
-     
+    NN Trainer whose with logits as last layer
+
+    The **architecture** function should follow the following pattern:
+
+      def my_beautiful_function(placeholder):
+
+          end_points = dict()
+          graph = convXX(placeholder)
+          end_points['conv'] = graph
+          ....
+          return graph, end_points
+
+    The **loss** function should follow the following pattern:
+
+    def my_beautiful_loss(logits, labels):
+       return loss_set_of_ops(logits, labels)
+
+
     **Parameters**
       architecture:
          Pointer to a function that builds the graph.
-         The signature should be something like `my_beautiful_function(input)`
 
       optimizer:
          One of the tensorflow solvers (https://www.tensorflow.org/api_guides/python/train)
+         - tf.train.GradientDescentOptimizer
+         - tf.train.AdagradOptimizer
+         - ....
          
       config:
          
       n_classes:
-         Number of classes of your problem
+         Number of classes of your problem. The logits will be appended in this class
          
       loss_op:
          Pointer to a function that computes the loss.
-         The signature should be something like `my_beautiful_loss(logits, labels)`
       
       embedding_validation:
          Run the validation using embeddings?? [default: False]
@@ -81,7 +98,7 @@ class LogitsTrainer(estimator.Estimator):
         if self.loss_op is None:
             raise ValueError("Please specify a function to build the loss !!")
 
-        if self.n_classes <=0:
+        if self.n_classes <= 0:
             raise ValueError("Number of classes must be greated than 0")
 
         def _model_fn(features, labels, mode, params, config):
@@ -109,7 +126,7 @@ class LogitsTrainer(estimator.Estimator):
             if mode == tf.estimator.ModeKeys.PREDICT:
                 return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
 
-            # Calculate Loss (for both TRAIN and EVAL modes)
+            # Compute Loss (for both TRAIN and EVAL modes)
             self.loss = self.loss_op(logits, labels)
 
             # Configure the Training Op (for TRAIN mode)
@@ -136,3 +153,134 @@ class LogitsTrainer(estimator.Estimator):
                                             model_dir=model_dir,
                                             config=config)
 
+
+class LogitsCenterLossTrainer(estimator.Estimator):
+    """
+    NN Trainer whose with logits as last layer
+
+    The **architecture** function should follow the following pattern:
+
+      def my_beautiful_function(placeholder):
+
+          end_points = dict()
+          graph = convXX(placeholder)
+          end_points['conv'] = graph
+          ....
+          return graph, end_points
+
+    **Parameters**
+      architecture:
+         Pointer to a function that builds the graph.
+
+      optimizer:
+         One of the tensorflow solvers (https://www.tensorflow.org/api_guides/python/train)
+         - tf.train.GradientDescentOptimizer
+         - tf.train.AdagradOptimizer
+         - ....
+
+      config:
+
+      n_classes:
+         Number of classes of your problem. The logits will be appended in this class
+
+      loss_op:
+         Pointer to a function that computes the loss.
+
+      embedding_validation:
+         Run the validation using embeddings?? [default: False]
+
+      model_dir:
+        Model path
+
+      validation_batch_size:
+        Size of the batch for validation. This value is used when the
+        validation with embeddings is used. This is a hack.
+    """
+
+    def __init__(self,
+                 architecture=None,
+                 optimizer=None,
+                 config=None,
+                 n_classes=0,
+                 embedding_validation=False,
+                 model_dir="",
+                 alpha=0.9,
+                 factor=0.01,
+                 validation_batch_size=None,
+              ):
+
+        self.architecture = architecture
+        self.optimizer = optimizer
+        self.n_classes = n_classes
+        self.alpha = alpha
+        self.factor = factor
+        self.loss = None
+        self.embedding_validation = embedding_validation
+
+        if self.architecture is None:
+            raise ValueError("Please specify a function to build the architecture !!")
+
+        if self.optimizer is None:
+            raise ValueError("Please specify a optimizer (https://www.tensorflow.org/api_guides/python/train) !!")
+
+        if self.n_classes <= 0:
+            raise ValueError("Number of classes must be greated than 0")
+
+        def _model_fn(features, labels, mode, params, config):
+
+            # Building one graph
+            prelogits = self.architecture(features)[0]
+            logits = append_logits(prelogits, n_classes)
+
+            if self.embedding_validation:
+                # Compute the embeddings
+                embeddings = tf.nn.l2_normalize(prelogits, 1)
+                predictions = {
+                    "embeddings": embeddings
+                }
+
+            else:
+                predictions = {
+                    # Generate predictions (for PREDICT and EVAL mode)
+                    "classes": tf.argmax(input=logits, axis=1),
+                    # Add `softmax_tensor` to the graph. It is used for PREDICT and by the
+                    # `logging_hook`.
+                    "probabilities": tf.nn.softmax(logits, name="softmax_tensor")
+
+                }
+
+            if mode == tf.estimator.ModeKeys.PREDICT:
+                return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
+
+            # Compute Loss (for both TRAIN and EVAL modes)
+            loss_dict = mean_cross_entropy_center_loss(logits, prelogits, labels, self.n_classes,
+                                                       alpha=self.alpha, factor=self.factor)
+            self.loss = loss_dict['loss']
+            centers = loss_dict['centers']
+
+            # Configure the Training Op (for TRAIN mode)
+            if mode == tf.estimator.ModeKeys.TRAIN:
+                global_step = tf.contrib.framework.get_or_create_global_step()
+                # backprop and updating the centers
+                train_op = tf.group(self.optimizer.minimize(self.loss, global_step=global_step),
+                                    centers)
+
+                return tf.estimator.EstimatorSpec(mode=mode, loss=self.loss,
+                                                  train_op=train_op)
+
+            if self.embedding_validation:
+                predictions_op = predict_using_tensors(predictions["embeddings"], labels, num=validation_batch_size)
+                eval_metric_ops = {"accuracy": tf.metrics.accuracy(labels=labels, predictions=predictions_op)}
+                return tf.estimator.EstimatorSpec(mode=mode, loss=self.loss, eval_metric_ops=eval_metric_ops)
+
+            else:
+                # Add evaluation metrics (for EVAL mode)
+                eval_metric_ops = {
+                    "accuracy": tf.metrics.accuracy(
+                        labels=labels, predictions=predictions["classes"])}
+                return tf.estimator.EstimatorSpec(
+                    mode=mode, loss=self.loss, eval_metric_ops=eval_metric_ops)
+
+        super(LogitsCenterLossTrainer, self).__init__(model_fn=_model_fn,
+                                                      model_dir=model_dir,
+                                                      config=config)
diff --git a/bob/learn/tensorflow/trainers/__init__.py b/bob/learn/tensorflow/trainers/__init__.py
index 6f22f211e205f0f4a7a5605db261382673ef4025..7812e271417bf040e33716472f564ee9f187b349 100755
--- a/bob/learn/tensorflow/trainers/__init__.py
+++ b/bob/learn/tensorflow/trainers/__init__.py
@@ -3,7 +3,7 @@ from .Trainer import Trainer
 from .SiameseTrainer import SiameseTrainer
 from .TripletTrainer import TripletTrainer
 from .learning_rate import exponential_decay, constant
-from .LogitsTrainer import LogitsTrainer
+from .LogitsTrainer import LogitsTrainer, LogitsCenterLossTrainer
 import numpy