diff --git a/bob/learn/tensorflow/dataset/__init__.py b/bob/learn/tensorflow/dataset/__init__.py
index 30d31c5630fcff0a6860e646f4989d201aa73719..b7d614b4e98a802166fa490ace1dd5818afb917d 100755
--- a/bob/learn/tensorflow/dataset/__init__.py
+++ b/bob/learn/tensorflow/dataset/__init__.py
@@ -70,6 +70,67 @@ def append_image_augmentation(image, gray_scale=False,
     return image
 
 
+def arrange_indexes_by_label(input_labels, possible_labels):
+    """
+    Groups the sample indexes by label and shuffles the indexes within each group.
+    """
+    indexes_per_labels = dict()
+    for l in possible_labels:
+        indexes_per_labels[l] = numpy.where(input_labels == l)[0]
+        numpy.random.shuffle(indexes_per_labels[l])
+    return indexes_per_labels
+
+
+def triplets_random_generator(input_data, input_labels):
+    """
+    Giving a list of samples and a list of labels, it dumps a series of
+    triplets for triple nets.
+
+    **Parameters**
+
+      input_data: List of whatever representing the data samples
+
+      input_labels: List of the labels (needs to be in EXACT same order as input_data)
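+
+    **Example (an illustrative sketch; the file names are hypothetical)**
+
+      >>> anchors, positives, negatives = triplets_random_generator(
+      ...     ['a1.png', 'a2.png', 'b1.png', 'b2.png'], [0, 0, 1, 1])
+      >>> len(anchors) == len(positives) == len(negatives) == 4
+      True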
+    """
+    anchor = []
+    positive = []
+    negative = []
+
+    def append(anchor_sample, positive_sample, negative_sample):
+        """
+        Appends one element to each list
+        """
+        anchor.append(anchor_sample)
+        positive.append(positive_sample)
+        negative.append(negative_sample)
+
+    possible_labels = list(set(input_labels))
+    input_data = numpy.array(input_data)
+    input_labels = numpy.array(input_labels)
+    total_samples = input_data.shape[0]
+
+    indexes_per_labels = arrange_indexes_by_label(input_labels, possible_labels)
+
+    # Searching for random triplets
+    def pick_random_sample(label):
+        # Draws one random sample from the given class
+        indexes = indexes_per_labels[label]
+        return input_data[indexes[numpy.random.randint(len(indexes))], ...]
+
+    offset_class = 0
+    for i in range(total_samples):
+
+        anchor_sample = pick_random_sample(possible_labels[offset_class])
+        positive_sample = pick_random_sample(possible_labels[offset_class])
+
+        # Changing the class to pick the negative sample
+        offset_class = (offset_class + 1) % len(possible_labels)
+
+        negative_sample = pick_random_sample(possible_labels[offset_class])
+
+        append(str(anchor_sample), str(positive_sample), str(negative_sample))
+    return anchor, positive, negative
+
+
 def siamease_pairs_generator(input_data, input_labels):
     """
     Giving a list of samples and a list of labels, it dumps a series of
@@ -101,10 +162,11 @@ def siamease_pairs_generator(input_data, input_labels):
     total_samples = input_data.shape[0]
 
     # Filtering the samples by label and shuffling all the indexes
-    indexes_per_labels = dict()
-    for l in possible_labels:
-        indexes_per_labels[l] = numpy.where(input_labels == l)[0]
-        numpy.random.shuffle(indexes_per_labels[l])
+    indexes_per_labels = arrange_indexes_by_label(input_labels, possible_labels)
 
     left_possible_indexes = numpy.random.choice(possible_labels, total_samples, replace=True)
     right_possible_indexes = numpy.random.choice(possible_labels, total_samples, replace=True)
diff --git a/bob/learn/tensorflow/dataset/image.py b/bob/learn/tensorflow/dataset/image.py
index 0c5930500b9d747ca5566fad9a93bd6fbd11f6b9..b35b0c805f65e832ab795278cb1007cadd7b6962 100644
--- a/bob/learn/tensorflow/dataset/image.py
+++ b/bob/learn/tensorflow/dataset/image.py
@@ -161,6 +161,9 @@ def image_augmentation_parser(filename, label, data_shape, data_type,
                                       per_image_normalization=per_image_normalization)
                                         
     label = tf.cast(label, tf.int64)
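+    # The estimators in this package expect the features wrapped in a
+    # dictionary with the `data` and `key` entries (see
+    # bob.learn.tensorflow.estimators.check_features)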
+    features = dict()
+    features['data'] = image
+    features['key'] = filename
 
-    return image, label
+    return features, label
 
diff --git a/bob/learn/tensorflow/dataset/siamese_image.py b/bob/learn/tensorflow/dataset/siamese_image.py
index d22e036051e99d00f06e309d7268050690f60d7c..bc1d48a2b631947ef386d326abb3b6d05f6659f4 100644
--- a/bob/learn/tensorflow/dataset/siamese_image.py
+++ b/bob/learn/tensorflow/dataset/siamese_image.py
@@ -181,7 +181,6 @@ def image_augmentation_parser(filename_left, filename_right, label, data_shape,
     image = dict()
     image['left']  = image_left
     image['right'] = image_right
-
     label = tf.cast(label, tf.int64)
 
     return image, label
diff --git a/bob/learn/tensorflow/dataset/tfrecords.py b/bob/learn/tensorflow/dataset/tfrecords.py
index f3458ba9271b57f2afbf5376acc61ec3e859d7af..84c21ab9f822bfe5311a7f1c2a44e7924daee048 100644
--- a/bob/learn/tensorflow/dataset/tfrecords.py
+++ b/bob/learn/tensorflow/dataset/tfrecords.py
@@ -35,7 +35,7 @@ def image_augmentation_parser(serialized_example, feature, data_shape, data_type
     # Decode the record read by the reader
     features = tf.parse_single_example(serialized_example, features=feature)
     # Convert the image data from string back to the numbers
-    image = tf.decode_raw(features['train/data'], data_type)
+    image = tf.decode_raw(features['data'], data_type)
 
     # Reshape image data into the original shape
     image = tf.reshape(image, data_shape)
@@ -50,8 +50,10 @@ def image_augmentation_parser(serialized_example, feature, data_shape, data_type
                                       per_image_normalization=per_image_normalization)
 
     # Cast label data into int64
-    label = tf.cast(features['train/label'], tf.int64)
-    return image, label
+    label = tf.cast(features['label'], tf.int64)
+    key = tf.cast(features['key'], tf.string)
+
+    return image, label, key
 
 
 def read_and_decode(filename_queue, data_shape, data_type=tf.float32,
@@ -208,8 +210,13 @@ def shuffle_data_and_labels_image_augmentation(tfrecord_filenames, data_shape, d
 
     dataset = dataset.shuffle(buffer_size).batch(batch_size).repeat(epochs)
 
-    data, labels = dataset.make_one_shot_iterator().get_next()
-    return data, labels
+    data, labels, key = dataset.make_one_shot_iterator().get_next()
+
+    features = dict()
+    features['data'] = data
+    features['key'] = key
+
+    return features, labels
 
 
 def shuffle_data_and_labels(tfrecord_filenames, data_shape, data_type,
@@ -243,8 +250,12 @@ def shuffle_data_and_labels(tfrecord_filenames, data_shape, data_type,
                                           data_type)
     dataset = dataset.shuffle(buffer_size).batch(batch_size).repeat(epochs)
 
-    data, labels = dataset.make_one_shot_iterator().get_next()
-    return data, labels
+    data, labels, key = dataset.make_one_shot_iterator().get_next()
+
+    features = dict()
+    features['data'] = data
+    features['key'] = key
+
+    return features, labels
 
 
 def batch_data_and_labels(tfrecord_filenames, data_shape, data_type,
@@ -274,5 +285,10 @@ def batch_data_and_labels(tfrecord_filenames, data_shape, data_type,
                                           data_type)
     dataset = dataset.batch(batch_size).repeat(epochs)
 
-    data, labels = dataset.make_one_shot_iterator().get_next()
-    return data, labels
+    data, labels, key = dataset.make_one_shot_iterator().get_next()
+
+    features = dict()
+    features['data'] = data
+    features['key'] = key
+
+    return features, labels
diff --git a/bob/learn/tensorflow/dataset/triplet_image.py b/bob/learn/tensorflow/dataset/triplet_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..842ecfd74e128fbb12734ff5acfe7c0044f18324
--- /dev/null
+++ b/bob/learn/tensorflow/dataset/triplet_image.py
@@ -0,0 +1,175 @@
+#!/usr/bin/env python
+# vim: set fileencoding=utf-8 :
+# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
+
+import tensorflow as tf
+from functools import partial
+from . import append_image_augmentation, triplets_random_generator
+
+
+def shuffle_data_and_labels_image_augmentation(filenames, labels, data_shape, data_type,
+                                               batch_size, epochs=None, buffer_size=10**3,
+                                               gray_scale=False,
+                                               output_shape=None,
+                                               random_flip=False,
+                                               random_brightness=False,
+                                               random_contrast=False,
+                                               random_saturation=False,
+                                               per_image_normalization=True):
+    """
+    Dumps random batches for triplet networks from a list of image paths and labels:
+
+    The list of files and labels should be in the same order e.g.
+    filenames = ['class_1_img1', 'class_1_img2', 'class_2_img1']
+    labels = [0, 0, 1]
+
+    The batches returned with tf.Session.run() will be in the following format:
+    **data**: a dictionary containing the keys ['anchor', 'positive', 'negative'].
+
+
+    **Parameters**
+
+       filenames:
+          List containing the path of the images
+       
+       labels:
+          List containing the labels (needs to be in EXACT same order as filenames)
+          
+       data_shape:
+          Shape of the loaded images, e.g. (250, 250, 3)
+          
+       data_type:
+          tf data type(https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types)
+     
+       batch_size:
+          Size of the batch
+          
+       epochs:
+           Number of epochs to be batched
+       
+       buffer_size:
+            Size of the shuffle buffer
+
+       gray_scale:
+          Convert to gray scale?
+          
+       output_shape:
+          If set, will randomly crop the image given the output shape
+
+       random_flip:
+          Randomly flip an image horizontally  (https://www.tensorflow.org/api_docs/python/tf/image/random_flip_left_right)
+
+       random_brightness:
+           Adjust the brightness of an RGB image by a random factor (https://www.tensorflow.org/api_docs/python/tf/image/random_brightness)
+
+       random_contrast:
+           Adjust the contrast of an RGB image by a random factor (https://www.tensorflow.org/api_docs/python/tf/image/random_contrast)
+
+       random_saturation:
+           Adjust the saturation of an RGB image by a random factor (https://www.tensorflow.org/api_docs/python/tf/image/random_saturation)
+
+       per_image_normalization:
+           Linearly scales image to have zero mean and unit norm.
+
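+    **Example (a minimal sketch; assumes `filenames` and `labels` lists as
+    described above)**
+
+      data = shuffle_data_and_labels_image_augmentation(filenames, labels,
+                                                        data_shape=(250, 250, 3),
+                                                        data_type=tf.float32,
+                                                        batch_size=2,
+                                                        output_shape=(50, 50))
+      with tf.Session() as session:
+          triplet = session.run(data)
+          # triplet['anchor'], triplet['positive'] and triplet['negative'] each
+          # hold a batch with shape (2, 50, 50, 3)
+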
+    """                            
+
+    dataset = create_dataset_from_path_augmentation(filenames, labels, data_shape,
+                                                    data_type,
+                                                    gray_scale=gray_scale,
+                                                    output_shape=output_shape,
+                                                    random_flip=random_flip,
+                                                    random_brightness=random_brightness,
+                                                    random_contrast=random_contrast,
+                                                    random_saturation=random_saturation,
+                                                    per_image_normalization=per_image_normalization)
+
+    dataset = dataset.shuffle(buffer_size).batch(batch_size).repeat(epochs)
+
+    data = dataset.make_one_shot_iterator().get_next()
+    return data
+
+
+def create_dataset_from_path_augmentation(filenames, labels,
+                                          data_shape, data_type=tf.float32,
+                                          gray_scale=False, 
+                                          output_shape=None,
+                                          random_flip=False,
+                                          random_brightness=False,
+                                          random_contrast=False,
+                                          random_saturation=False,
+                                          per_image_normalization=True):
+    """
+    Create dataset from a list of tf-record files
+    
+    **Parameters**
+    
+       filenames:
+          List containing the path of the images
+       
+       labels:
+          List containing the labels (needs to be in EXACT same order as filenames)
+          
+       data_shape:
+          Samples shape saved in the tf-record
+          
+       data_type:
+          tf data type(https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types)
+          
+       feature:
+    
+    """
+ 
+    parser = partial(image_augmentation_parser,
+                     data_shape=data_shape,
+                     data_type=data_type,
+                     gray_scale=gray_scale, 
+                     output_shape=output_shape,
+                     random_flip=random_flip,
+                     random_brightness=random_brightness,
+                     random_contrast=random_contrast,
+                     random_saturation=random_saturation,
+                     per_image_normalization=per_image_normalization) 
+
+    anchor_data, positive_data, negative_data = triplets_random_generator(filenames, labels)
+
+    dataset = tf.contrib.data.Dataset.from_tensor_slices((anchor_data, positive_data, negative_data))
+    dataset = dataset.map(parser)
+    return dataset
+
+
+def image_augmentation_parser(anchor, positive, negative, data_shape, data_type=tf.float32,
+                              gray_scale=False, 
+                              output_shape=None,
+                              random_flip=False,
+                              random_brightness=False,
+                              random_contrast=False,
+                              random_saturation=False,
+                              per_image_normalization=True):
+
+    """
+    Parses a single tf.Example into image and label tensors.
+    """
+
+    triplet = dict()
+    for n, v in zip(['anchor', 'positive', 'negative'], [anchor, positive, negative]):
+
+        # Convert the image data from string back to the numbers
+        image = tf.cast(tf.image.decode_image(tf.read_file(v)), data_type)
+
+        # Reshape image data into the original shape
+        image = tf.reshape(image, data_shape)
+
+        # Applying image augmentation
+        image = append_image_augmentation(image, gray_scale=gray_scale,
+                                          output_shape=output_shape,
+                                          random_flip=random_flip,
+                                          random_brightness=random_brightness,
+                                          random_contrast=random_contrast,
+                                          random_saturation=random_saturation,
+                                          per_image_normalization=per_image_normalization)
+
+        triplet[n] = image
+
+    return triplet
diff --git a/bob/learn/tensorflow/estimators/Logits.py b/bob/learn/tensorflow/estimators/Logits.py
index 5c582f82d57def58da29ae6a2300b362d1d78b70..d1b812cfd40fb00e417af62bfcd75e82e230283f 100755
--- a/bob/learn/tensorflow/estimators/Logits.py
+++ b/bob/learn/tensorflow/estimators/Logits.py
@@ -16,6 +16,7 @@ from bob.learn.tensorflow.network.utils import append_logits
 from tensorflow.python.estimator import estimator
 from bob.learn.tensorflow.utils import predict_using_tensors
 from bob.learn.tensorflow.loss import mean_cross_entropy_center_loss
+from . import check_features
 
 
 import logging
@@ -102,9 +103,13 @@ class Logits(estimator.Estimator):
             raise ValueError("Number of classes must be greated than 0")
 
         def _model_fn(features, labels, mode, params, config):
+
+            check_features(features)
+            data = features['data']
+            key = features['key']
             
             # Building one graph
-            prelogits = self.architecture(features)[0]
+            prelogits = self.architecture(data)[0]
             logits = append_logits(prelogits, n_classes)
 
             if self.embedding_validation:
@@ -228,8 +233,12 @@ class LogitsCenterLoss(estimator.Estimator):
 
         def _model_fn(features, labels, mode, params, config):
 
+            check_features(features)
+            data = features['data']
+            key = features['key']
+
             # Building one graph
-            prelogits = self.architecture(features)[0]
+            prelogits = self.architecture(data)[0]
             logits = append_logits(prelogits, n_classes)
 
             if self.embedding_validation:
@@ -284,3 +293,4 @@ class LogitsCenterLoss(estimator.Estimator):
         super(LogitsCenterLoss, self).__init__(model_fn=_model_fn,
                                                model_dir=model_dir,
                                                config=config)
diff --git a/bob/learn/tensorflow/estimators/Siamese.py b/bob/learn/tensorflow/estimators/Siamese.py
index f36b6b449fe1fb49d2a6d476008e7f38b7477416..ca13261bf1b59a2de06b250d9fdc99d8b9ca8fcc 100755
--- a/bob/learn/tensorflow/estimators/Siamese.py
+++ b/bob/learn/tensorflow/estimators/Siamese.py
@@ -13,7 +13,7 @@ import time
 from tensorflow.python.estimator import estimator
 from bob.learn.tensorflow.utils import predict_using_tensors
 #from bob.learn.tensorflow.loss import mean_cross_entropy_center_loss
-
+from . import check_features
 
 import logging
 logger = logging.getLogger("bob.learn")
@@ -116,8 +116,11 @@ class Siamese(estimator.Estimator):
                 return tf.estimator.EstimatorSpec(mode=mode, loss=self.loss,
                                                   train_op=train_op)
 
+            check_features(features)
+            data = features['data']
+
             # Compute the embeddings
-            prelogits = self.architecture(features)[0]
+            prelogits = self.architecture(data)[0]
             embeddings = tf.nn.l2_normalize(prelogits, 1)
             predictions = {"embeddings": embeddings}
 
diff --git a/bob/learn/tensorflow/estimators/Triplet.py b/bob/learn/tensorflow/estimators/Triplet.py
new file mode 100755
index 0000000000000000000000000000000000000000..021ce9fe1ff87453f4b8d557ecfc2db09df0c324
--- /dev/null
+++ b/bob/learn/tensorflow/estimators/Triplet.py
@@ -0,0 +1,146 @@
+#!/usr/bin/env python
+# vim: set fileencoding=utf-8 :
+# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
+
+import tensorflow as tf
+import os
+import bob.io.base
+import bob.core
+from tensorflow.core.framework import summary_pb2
+import time
+
+from tensorflow.python.estimator import estimator
+from bob.learn.tensorflow.utils import predict_using_tensors
+from bob.learn.tensorflow.loss import triplet_loss
+from . import check_features
+
+
+import logging
+logger = logging.getLogger("bob.learn")
+
+
+class Triplet(estimator.Estimator):
+    """
+    NN estimator for Triplet networks
+
+    Schroff, Florian, Dmitry Kalenichenko, and James Philbin.
+    "Facenet: A unified embedding for face recognition and clustering." Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2015.
+
+    The **architecture** function should follow the following pattern:
+
+      def my_beautiful_function(placeholder):
+
+          end_points = dict()
+          graph = convXX(placeholder)
+          end_points['conv'] = graph
+          ....
+          return graph, end_points
+
+    The **loss** function should follow the following pattern:
+
+    def my_beautiful_loss(anchor_embedding, positive_embedding, negative_embedding):
+       return loss_set_of_ops(anchor_embedding, positive_embedding, negative_embedding)
+
+
+    **Parameters**
+      architecture:
+         Pointer to a function that builds the graph.
+
+      optimizer:
+         One of the tensorflow solvers (https://www.tensorflow.org/api_guides/python/train)
+         - tf.train.GradientDescentOptimizer
+         - tf.train.AdagradOptimizer
+         - ....
+         
+      config:
+         A `tf.estimator.RunConfig` object passed to the base estimator
+
+      n_classes:
+         Number of classes of your problem
+
+      loss_op:
+         Pointer to a function that computes the triplet loss
+
+      model_dir:
+        Model path
+
+      validation_batch_size:
+        Size of the batch for validation. This value is used when the
+        validation with embeddings is used. This is a hack.
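+
+      **Example (a minimal sketch, mirroring this package's tests; `dummy` is the
+      example architecture from bob.learn.tensorflow.network)**
+
+        trainer = Triplet(model_dir="./temp",
+                          architecture=dummy,
+                          optimizer=tf.train.GradientDescentOptimizer(0.001),
+                          n_classes=10,
+                          loss_op=triplet_loss,
+                          validation_batch_size=2)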
+    """
+
+    def __init__(self,
+                 architecture=None,
+                 optimizer=None,
+                 config=None,
+                 n_classes=0,
+                 loss_op=triplet_loss,
+                 model_dir="",
+                 validation_batch_size=None,
+              ):
+
+        self.architecture = architecture
+        self.optimizer = optimizer
+        self.n_classes = n_classes
+        self.loss_op = loss_op
+        self.loss = None
+
+        if self.architecture is None:
+            raise ValueError("Please specify a function to build the architecture !!")
+            
+        if self.optimizer is None:
+            raise ValueError("Please specify a optimizer (https://www.tensorflow.org/api_guides/python/train) !!")
+
+        if self.loss_op is None:
+            raise ValueError("Please specify a function to build the loss !!")
+
+        if self.n_classes <= 0:
+            raise ValueError("Number of classes must be greated than 0")
+
+        def _model_fn(features, labels, mode, params, config):
+
+            if mode == tf.estimator.ModeKeys.TRAIN:
+
+                # The input function needs to return a dictionary with the
+                # `anchor`, `positive` and `negative` keys
+                if 'anchor' not in features or 'positive' not in features or \
+                        'negative' not in features:
+                    raise ValueError("The input function needs to contain a dictionary with the "
+                                     "keys `anchor`, `positive` and `negative` ")
+            
+                # Building one graph
+                prelogits_anchor = self.architecture(features['anchor'])[0]
+                prelogits_positive = self.architecture(features['positive'], reuse=True)[0]
+                prelogits_negative = self.architecture(features['negative'], reuse=True)[0]
+
+                # Compute Loss (for both TRAIN and EVAL modes)
+                self.loss = self.loss_op(prelogits_anchor, prelogits_positive, prelogits_negative)
+                # Configure the Training Op (for TRAIN mode)
+                global_step = tf.contrib.framework.get_or_create_global_step()
+                train_op = self.optimizer.minimize(self.loss, global_step=global_step)
+                return tf.estimator.EstimatorSpec(mode=mode, loss=self.loss,
+                                                  train_op=train_op)
+
+            check_features(features)
+            data = features['data']
+
+            # Compute the embeddings
+            prelogits = self.architecture(data)[0]
+            embeddings = tf.nn.l2_normalize(prelogits, 1)
+            predictions = {"embeddings": embeddings}
+
+            if mode == tf.estimator.ModeKeys.PREDICT:
+                return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
+
+            predictions_op = predict_using_tensors(predictions["embeddings"], labels, num=validation_batch_size)
+            eval_metric_ops = {"accuracy": tf.metrics.accuracy(labels=labels, predictions=predictions_op)}
+
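+            # A constant dummy loss is reported in EVAL mode; the relevant
+            # metric here is the accuracy computed from the embeddings above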
+            return tf.estimator.EstimatorSpec(mode=mode, loss=tf.reduce_mean(1), eval_metric_ops=eval_metric_ops)
+
+        super(Triplet, self).__init__(model_fn=_model_fn,
+                                      model_dir=model_dir,
+                                      config=config)
+
diff --git a/bob/learn/tensorflow/estimators/__init__.py b/bob/learn/tensorflow/estimators/__init__.py
index e63d63299fa4d33280d63d33d493034d8e0cfb9e..c1036835557a3555adbb8511bb33cbea106302d3 100755
--- a/bob/learn/tensorflow/estimators/__init__.py
+++ b/bob/learn/tensorflow/estimators/__init__.py
@@ -1,5 +1,17 @@
+#!/usr/bin/env python
+# vim: set fileencoding=utf-8 :
+# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
+
+def check_features(features):
+    """
+    Checks that the input function returned the features as a dictionary
+    with the `data` and `key` entries.
+    """
+    if 'data' not in features or 'key' not in features:
+        raise ValueError("The input function needs to contain a dictionary "
+                         "with the keys `data` and `key`")
+    return True
+
+
 from .Logits import Logits, LogitsCenterLoss
 from .Siamese import Siamese
+from .Triplet import Triplet
+
 
 # gets sphinx autodoc done right - don't remove it
 def __appropriate__(*args):
@@ -17,7 +29,9 @@ def __appropriate__(*args):
 
 __appropriate__(
     Logits,
-    LogitsCenterLoss
+    LogitsCenterLoss,
+    Siamese,
+    Triplet
     )
 __all__ = [_ for _ in dir() if not _.startswith('_')]
 
diff --git a/bob/learn/tensorflow/loss/TripletLoss.py b/bob/learn/tensorflow/loss/TripletLoss.py
index c642507cb8dd1f6fc9ee9cab62c4a95c1166de88..1c788cbc837be4ad46db7d81aaef6540b46ebe86 100755
--- a/bob/learn/tensorflow/loss/TripletLoss.py
+++ b/bob/learn/tensorflow/loss/TripletLoss.py
@@ -35,6 +35,53 @@ def triplet_loss(anchor_embedding, positive_embedding, negative_embedding, margi
 
     """
 
+    with tf.name_scope("triplet_loss"):
+        # Normalize
+        anchor_embedding = tf.nn.l2_normalize(anchor_embedding, 1, 1e-10, name="anchor")
+        positive_embedding = tf.nn.l2_normalize(positive_embedding, 1, 1e-10, name="positive")
+        negative_embedding = tf.nn.l2_normalize(negative_embedding, 1, 1e-10, name="negative")
+
+        d_positive = tf.reduce_sum(tf.square(tf.subtract(anchor_embedding, positive_embedding)), 1)
+        d_negative = tf.reduce_sum(tf.square(tf.subtract(anchor_embedding, negative_embedding)), 1)
+
+        basic_loss = tf.add(tf.subtract(d_positive, d_negative), margin)
+        loss = tf.reduce_mean(tf.maximum(basic_loss, 0.0), 0, name=tf.GraphKeys.LOSSES)
+        between_class_loss = tf.reduce_mean(d_negative)
+        within_class_loss = tf.reduce_mean(d_positive)
+
+        tf.summary.scalar('loss', loss)
+        tf.summary.scalar('between_class', between_class_loss)
+        tf.summary.scalar('within_class', within_class_loss)
+
+        return loss
+
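+# A minimal usage sketch for `triplet_loss` (illustrative; `anchor`, `positive`
+# and `negative` are assumed to be embedding tensors of shape (batch_size, dim);
+# the embeddings are l2-normalized inside the loss):
+#
+#   loss = triplet_loss(anchor, positive, negative, margin=0.2)
+#   # loss = mean(max(||f_a - f_p||^2 - ||f_a - f_n||^2 + margin, 0))
+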
+
+def triplet_loss_deprecated(anchor_embedding, positive_embedding, negative_embedding, margin=5.0):
+    """
+    Compute the triplet loss as in
+
+    Schroff, Florian, Dmitry Kalenichenko, and James Philbin.
+    "Facenet: A unified embedding for face recognition and clustering."
+    Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2015.
+
+    :math:`L = sum(|f_a - f_p|^2 - |f_a - f_n|^2 + \lambda)`
+
+    Kept for the deprecated trainers; returns a dictionary of losses.
+
+    **Parameters**
+
+    anchor_embedding:
+      Embedding of the anchor sample
+
+    positive_embedding:
+      Embedding of the positive sample (same class as the anchor)
+
+    negative_embedding:
+      Embedding of the negative sample (different class from the anchor)
+
+    margin:
+      Triplet margin :math:`\lambda`
+
+    """
+
     with tf.name_scope("triplet_loss"):
         # Normalize
         anchor_embedding = tf.nn.l2_normalize(anchor_embedding, 1, 1e-10, name="anchor")
@@ -53,7 +100,7 @@ def triplet_loss(anchor_embedding, positive_embedding, negative_embedding, margi
         loss_dict['within_class'] = tf.reduce_mean(d_positive)
 
         return loss_dict
-        
+
 
 def triplet_fisher_loss(anchor_embedding, positive_embedding, negative_embedding):
 
@@ -141,4 +188,47 @@ def triplet_average_loss(anchor_embedding, positive_embedding, negative_embeddin
 
         return loss, tf.reduce_mean(d_negative), tf.reduce_mean(d_positive)        
 
-        
+
diff --git a/bob/learn/tensorflow/loss/__init__.py b/bob/learn/tensorflow/loss/__init__.py
index c1e327758a2f9e30d1bbd365caf7f748270db3a3..94b4f281328d39549cf06b42839fb82f09c115b0 100755
--- a/bob/learn/tensorflow/loss/__init__.py
+++ b/bob/learn/tensorflow/loss/__init__.py
@@ -1,6 +1,6 @@
 from .BaseLoss import mean_cross_entropy_loss, mean_cross_entropy_center_loss
 from .ContrastiveLoss import contrastive_loss, contrastive_loss_deprecated
-from .TripletLoss import triplet_loss, triplet_average_loss, triplet_fisher_loss
+from .TripletLoss import triplet_loss, triplet_average_loss, triplet_fisher_loss, triplet_loss_deprecated
 #from .NegLogLoss import NegLogLoss
 
 
diff --git a/bob/learn/tensorflow/test/test_cnn.py b/bob/learn/tensorflow/test/test_cnn.py
index 21bf3049f1d4e1ecee7f7dd40cbb021127f217b5..aef4b92cd0464f4e59b28309d35eac7f31cf2be6 100755
--- a/bob/learn/tensorflow/test/test_cnn.py
+++ b/bob/learn/tensorflow/test/test_cnn.py
@@ -6,7 +6,7 @@
 import numpy
 from bob.learn.tensorflow.datashuffler import Memory, SiameseMemory, TripletMemory, scale_factor
 from bob.learn.tensorflow.network import dummy
-from bob.learn.tensorflow.loss import mean_cross_entropy_loss, contrastive_loss_deprecated, triplet_loss
+from bob.learn.tensorflow.loss import mean_cross_entropy_loss, contrastive_loss_deprecated, triplet_loss_deprecated
 from bob.learn.tensorflow.trainers import Trainer, SiameseTrainer, TripletTrainer, constant
 from bob.learn.tensorflow.test.test_cnn_scratch import validate_network
 from bob.learn.tensorflow.network import Embedding, light_cnn9
@@ -256,7 +256,7 @@ def test_tripletcnn_trainer():
     graph['positive'] = dummy(inputs['positive'], reuse=True)[0]
     graph['negative'] = dummy(inputs['negative'], reuse=True)[0]
 
-    loss = triplet_loss(graph['anchor'], graph['positive'], graph['negative'])
+    loss = triplet_loss_deprecated(graph['anchor'], graph['positive'], graph['negative'])
 
     # One graph trainer
     trainer = TripletTrainer(train_data_shuffler,
diff --git a/bob/learn/tensorflow/test/test_cnn_pretrained_model.py b/bob/learn/tensorflow/test/test_cnn_pretrained_model.py
index 77b9de560c4d87e0bd19f49a46867e30bb77962d..1124ceb1bffa1c69eb7409b3caa64e041b886e8c 100755
--- a/bob/learn/tensorflow/test/test_cnn_pretrained_model.py
+++ b/bob/learn/tensorflow/test/test_cnn_pretrained_model.py
@@ -7,7 +7,7 @@ import numpy
 import bob.io.base
 import os
 from bob.learn.tensorflow.datashuffler import Memory, TripletMemory, SiameseMemory, scale_factor
-from bob.learn.tensorflow.loss import mean_cross_entropy_loss, contrastive_loss_deprecated, triplet_loss
+from bob.learn.tensorflow.loss import mean_cross_entropy_loss, contrastive_loss_deprecated, triplet_loss_deprecated
 from bob.learn.tensorflow.trainers import Trainer, constant, TripletTrainer, SiameseTrainer
 from bob.learn.tensorflow.utils import load_mnist
 from bob.learn.tensorflow.network import Embedding
@@ -137,7 +137,7 @@ def test_triplet_cnn_pretrained():
     graph['negative'] = scratch_network(inputs['negative'], reuse=True)
 
     # Loss for the softmax
-    loss = triplet_loss(graph['anchor'], graph['positive'], graph['negative'], margin=4.)
+    loss = triplet_loss_deprecated(graph['anchor'], graph['positive'], graph['negative'], margin=4.)
 
     # One graph trainer
     trainer = TripletTrainer(train_data_shuffler,
diff --git a/bob/learn/tensorflow/test/test_cnn_scratch.py b/bob/learn/tensorflow/test/test_cnn_scratch.py
index 5e97d8835eefc4339de627de2593d1d40156a80d..edc60426e34198f9a6e2a48f2e59a0c7aec12d65 100755
--- a/bob/learn/tensorflow/test/test_cnn_scratch.py
+++ b/bob/learn/tensorflow/test/test_cnn_scratch.py
@@ -6,7 +6,7 @@
 import numpy
 from bob.learn.tensorflow.datashuffler import Memory, scale_factor, TFRecord
 from bob.learn.tensorflow.network import Embedding
-from bob.learn.tensorflow.loss import mean_cross_entropy_loss, contrastive_loss_deprecated, triplet_loss
+from bob.learn.tensorflow.loss import mean_cross_entropy_loss, contrastive_loss_deprecated
 from bob.learn.tensorflow.trainers import Trainer, constant
 from bob.learn.tensorflow.utils import load_mnist
 import tensorflow as tf
diff --git a/bob/learn/tensorflow/test/test_dataset.py b/bob/learn/tensorflow/test/test_dataset.py
index 4a782bb72bab98a914b95c6a4f2e513a67332d37..31a0ceae4e8bde01b5bdc1b442bc5be774ae2751 100755
--- a/bob/learn/tensorflow/test/test_dataset.py
+++ b/bob/learn/tensorflow/test/test_dataset.py
@@ -5,6 +5,7 @@
 import pkg_resources
 import tensorflow as tf
 from bob.learn.tensorflow.dataset.siamese_image import shuffle_data_and_labels_image_augmentation as siamese_batch
+from bob.learn.tensorflow.dataset.triplet_image import shuffle_data_and_labels_image_augmentation as triplet_batch
 
 data_shape = (250, 250, 3)
 output_shape = (50, 50)
@@ -14,24 +15,25 @@ validation_batch_size = 250
 epochs = 1
 
 
-def test_siamese_dataset():
+# Trainer logits
+filenames = [pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p01_i0_0.png'),
+             pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p02_i0_0.png'),
+             pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p01_i0_0.png'),
+             pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p02_i0_0.png'),
+             pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p01_i0_0.png'),
+             pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p02_i0_0.png'),
+
+             pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_01_p01_i0_0.png'),
+             pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_02_f12_i0_0.png'),
+             pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_01_p01_i0_0.png'),
+             pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_02_f12_i0_0.png'),
+             pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_01_p01_i0_0.png'),
+             pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_02_f12_i0_0.png')]
+labels = [0, 0, 0, 0, 0, 0,
+          1, 1, 1, 1, 1, 1]
 
-    # Trainer logits
-    filenames = [pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p01_i0_0.png'),
-                 pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p02_i0_0.png'),
-                 pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p01_i0_0.png'),
-                 pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p02_i0_0.png'),
-                 pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p01_i0_0.png'),
-                 pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p02_i0_0.png'),                 
-                                  
-                 pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_01_p01_i0_0.png'),                 
-                 pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_02_f12_i0_0.png'),
-                 pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_01_p01_i0_0.png'),                 
-                 pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_02_f12_i0_0.png'),
-                 pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_01_p01_i0_0.png'),                 
-                 pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_02_f12_i0_0.png')]
-    labels = [0, 0, 0, 0, 0, 0,
-              1, 1, 1, 1, 1, 1]
+
+def test_siamese_dataset():
 
     data, label = siamese_batch(filenames, labels, data_shape, data_type, 2, per_image_normalization=False, output_shape=output_shape)
 
@@ -41,3 +43,13 @@ def test_siamese_dataset():
         assert d['left'].shape == (2, 50, 50, 3)
         assert d['right'].shape == (2, 50, 50, 3)
 
+
+def test_triplet_dataset():
+
+    data = triplet_batch(filenames, labels, data_shape, data_type, 2, per_image_normalization=False, output_shape=output_shape)
+    with tf.Session() as session:
+        d = session.run([data])[0]
+        assert len(d.keys()) == 3
+        assert d['anchor'].shape == (2, 50, 50, 3)
+        assert d['positive'].shape == (2, 50, 50, 3)
+        assert d['negative'].shape == (2, 50, 50, 3)
diff --git a/bob/learn/tensorflow/test/test_estimator_onegraph.py b/bob/learn/tensorflow/test/test_estimator_onegraph.py
index 9a252d8d7fd15ac6fbf1dad31f527dc7076a4478..a81cf145454b70fc2cca9abf4a51cd489c7ea037 100755
--- a/bob/learn/tensorflow/test/test_estimator_onegraph.py
+++ b/bob/learn/tensorflow/test/test_estimator_onegraph.py
@@ -65,7 +65,8 @@ def test_logitstrainer_embedding():
                                 n_classes=10,
                                 loss_op=mean_cross_entropy_loss,
                                 embedding_validation=embedding_validation,
-                                validation_batch_size=validation_batch_size)    
+                                validation_batch_size=validation_batch_size)
+
         run_logitstrainer_mnist(trainer)
     finally:
         try:
@@ -153,7 +154,6 @@ def run_logitstrainer_mnist(trainer, augmentation=False):
     create_mnist_tfrecord(tfrecord_validation, validation_data, validation_labels, n_samples=validation_batch_size)
 
     def input_fn():
-    
         if augmentation:
             return shuffle_data_and_labels_image_augmentation(tfrecord_train, data_shape, data_type, batch_size, epochs=epochs)
         else:
diff --git a/bob/learn/tensorflow/test/test_estimator_triplet.py b/bob/learn/tensorflow/test/test_estimator_triplet.py
new file mode 100755
index 0000000000000000000000000000000000000000..223fa13ef3f68b54444527132f3467262600b32d
--- /dev/null
+++ b/bob/learn/tensorflow/test/test_estimator_triplet.py
@@ -0,0 +1,108 @@
+#!/usr/bin/env python
+# vim: set fileencoding=utf-8 :
+# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
+
+import tensorflow as tf
+
+from bob.learn.tensorflow.network import dummy
+from bob.learn.tensorflow.estimators import Triplet
+from bob.learn.tensorflow.dataset.triplet_image import shuffle_data_and_labels_image_augmentation as triplet_batch
+from bob.learn.tensorflow.dataset.image import shuffle_data_and_labels_image_augmentation as single_batch
+
+from bob.learn.tensorflow.loss import triplet_loss
+from bob.learn.tensorflow.utils.hooks import LoggerHookEstimator
+from bob.learn.tensorflow.utils import reproducible
+import pkg_resources
+
+import numpy
+import shutil
+import os
+
+
+tfrecord_train = "./train_mnist.tfrecord"
+tfrecord_validation = "./validation_mnist.tfrecord"    
+model_dir = "./temp"
+
+learning_rate = 0.001
+data_shape = (250, 250, 3)  # size of the images in the dummy database
+output_shape = (50, 50)
+data_type = tf.float32
+batch_size = 4
+validation_batch_size = 2
+epochs = 1
+steps = 5000
+
+
+# Data
+filenames = [pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p01_i0_0.png'),
+             pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p02_i0_0.png'),
+             pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p01_i0_0.png'),
+             pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p02_i0_0.png'),
+             pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p01_i0_0.png'),
+             pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p02_i0_0.png'),
+             pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p02_i0_0.png'),
+             pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p01_i0_0.png'),
+             pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m301_01_p02_i0_0.png'),
+
+             pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_01_p01_i0_0.png'),
+             pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_02_f12_i0_0.png'),
+             pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_01_p01_i0_0.png'),
+             pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_02_f12_i0_0.png'),
+             pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_01_p01_i0_0.png'),
+             pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_02_f12_i0_0.png'),
+             pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_01_p01_i0_0.png'),
+             pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_02_f12_i0_0.png'),
+             pkg_resources.resource_filename(__name__, 'data/dummy_image_database/m304_02_f12_i0_0.png')]
+labels = [0, 0, 0, 0, 0, 0, 0, 0, 0,
+          1, 1, 1, 1, 1, 1, 1, 1, 1]
+
+
+def test_triplet_estimator():
+    # Trainer logits
+    try:
+        trainer = Triplet(model_dir=model_dir,
+                          architecture=dummy,
+                          optimizer=tf.train.GradientDescentOptimizer(learning_rate),
+                          n_classes=10,
+                          loss_op=triplet_loss,
+                          validation_batch_size=validation_batch_size)
+        run_triplet_estimator(trainer)
+    finally:
+        try:
+            shutil.rmtree(model_dir, ignore_errors=True)
+        except Exception:
+            pass
+
+
+def run_triplet_estimator(trainer):
+
+    # Cleaning up
+    tf.reset_default_graph()
+    assert len(tf.global_variables()) == 0
+
+    def input_fn():
+        return triplet_batch(filenames, labels, data_shape, data_type, batch_size, epochs=epochs, output_shape=output_shape,
+                             random_flip=True, random_brightness=True, random_contrast=True, random_saturation=True)
+
+    def input_validation_fn():
+        return single_batch(filenames, labels, data_shape, data_type, validation_batch_size, epochs=10, output_shape=output_shape)
+
+    hooks = [LoggerHookEstimator(trainer, batch_size, 300),
+             tf.train.SummarySaverHook(save_steps=1000,
+                                       output_dir=model_dir,
+                                       scaffold=tf.train.Scaffold(),
+                                       summary_writer=tf.summary.FileWriter(model_dir))]
+
+    trainer.train(input_fn, steps=steps, hooks=hooks)
+
+    acc = trainer.evaluate(input_validation_fn)
+    assert acc['accuracy'] > 0.5
+
+    # Cleaning up
+    tf.reset_default_graph()
+    assert len(tf.global_variables()) == 0
+
diff --git a/bob/learn/tensorflow/utils/util.py b/bob/learn/tensorflow/utils/util.py
index 6fdaeda95fb143844202bf21c5efc5808b250f1f..d809c2ee079599996455b405cc2143147bdbcd2e 100755
--- a/bob/learn/tensorflow/utils/util.py
+++ b/bob/learn/tensorflow/utils/util.py
@@ -60,9 +60,9 @@ def create_mnist_tfrecord(tfrecords_filename, data, labels, n_samples=6000):
     for i in range(n_samples):
         img = data[i]
         img_raw = img.tostring()
-
-        feature = {'train/data': _bytes_feature(img_raw),
-                   'train/label': _int64_feature(labels[i])
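+        # The `key` entry is filled with a placeholder, since the MNIST samples
+        # have no file name attached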
+        feature = {'data': _bytes_feature(img_raw),
+                   'label': _int64_feature(labels[i]),
+                   'key': _bytes_feature(b'-')
                    }
 
         example = tf.train.Example(features=tf.train.Features(feature=feature))