diff --git a/bob/learn/tensorflow/dataset/__init__.py b/bob/learn/tensorflow/dataset/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e0221e60a7a94e340d318532aee28e0fa9b469a7
--- /dev/null
+++ b/bob/learn/tensorflow/dataset/__init__.py
@@ -0,0 +1,71 @@
+import tensorflow as tf
+
+
+DEFAULT_FEATURE = {'train/data': tf.FixedLenFeature([], tf.string),
+                   'train/label': tf.FixedLenFeature([], tf.int64)}
+
+
+
+def append_image_augmentation(image, gray_scale=False, 
+                              output_shape=None,
+                              random_flip=False,
+                              random_brightness=False,
+                              random_contrast=False,
+                              random_saturation=False,
+                              per_image_normalization=True):
+    """
+    Append to the current tensor some random image augmentation operation
+    
+    **Parameters**
+       gray_scale:
+          Convert to gray scale?
+          
+       output_shape:
+          If set, will randomly crop the image given the output shape
+
+       random_flip:
+          Randomly flip an image horizontally  (https://www.tensorflow.org/api_docs/python/tf/image/random_flip_left_right)
+
+       random_brightness:
+           Adjust the brightness of an RGB image by a random factor (https://www.tensorflow.org/api_docs/python/tf/image/random_brightness)
+
+       random_contrast:
+           Adjust the contrast of an RGB image by a random factor (https://www.tensorflow.org/api_docs/python/tf/image/random_contrast)
+
+       random_saturation:
+           Adjust the saturation of an RGB image by a random factor (https://www.tensorflow.org/api_docs/python/tf/image/random_saturation)
+
+       per_image_normalization:
+           Linearly scales image to have zero mean and unit norm.
+       
+    """
+
+    # Casting to float32
+    image = tf.cast(image, tf.float32)
+
+    if output_shape is not None:
+        assert len(output_shape) == 2  # (height, width)
+        image = tf.image.resize_image_with_crop_or_pad(image, output_shape[0], output_shape[1])
+        
+    if random_flip:
+        image = tf.image.random_flip_left_right(image)
+
+    if random_brightness:
+        # the max_delta value is an arbitrary default; tune it for your data
+        image = tf.image.random_brightness(image, max_delta=0.15)
+
+    if random_contrast:
+        # the lower/upper bounds are arbitrary defaults; tune them for your data
+        image = tf.image.random_contrast(image, lower=0.85, upper=1.15)
+
+    if random_saturation:
+        # the lower/upper bounds are arbitrary defaults; tune them for your data
+        image = tf.image.random_saturation(image, lower=0.85, upper=1.15)
+
+    if gray_scale:
+        image = tf.image.rgb_to_grayscale(image, name="rgb_to_gray")
+        # note: the image now has a single channel
+
+    # normalizing data
+    if per_image_normalization:
+        image = tf.image.per_image_standardization(image)
+
+    return image
+
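
For reference, a minimal usage sketch of the new append_image_augmentation helper. The file name, the 112x112x3 geometry and the Session driver below are illustrative only, not part of this patch:

    import tensorflow as tf
    from bob.learn.tensorflow.dataset import append_image_augmentation

    filename = "images/sample.png"                 # hypothetical image path
    image = tf.cast(tf.image.decode_image(tf.read_file(filename)), tf.float32)
    image = tf.reshape(image, (112, 112, 3))       # give the tensor a static shape first
    image = append_image_augmentation(image,
                                      output_shape=[112, 112],   # center crop/pad to 112x112
                                      random_flip=True,
                                      random_brightness=True,
                                      per_image_normalization=True)

    with tf.Session() as session:
        augmented = session.run(image)             # float32 array of shape (112, 112, 3)
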
diff --git a/bob/learn/tensorflow/dataset/image.py b/bob/learn/tensorflow/dataset/image.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c5930500b9d747ca5566fad9a93bd6fbd11f6b9
--- /dev/null
+++ b/bob/learn/tensorflow/dataset/image.py
@@ -0,0 +1,166 @@
+#!/usr/bin/env python
+# vim: set fileencoding=utf-8 :
+# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
+
+import tensorflow as tf
+from functools import partial
+from . import append_image_augmentation
+
+
+def shuffle_data_and_labels_image_augmentation(filenames, labels, data_shape, data_type,
+                                              batch_size, epochs=None, buffer_size=10**3,
+                                              gray_scale=False, 
+                                              output_shape=None,
+                                              random_flip=False,
+                                              random_brightness=False,
+                                              random_contrast=False,
+                                              random_saturation=False,
+                                              per_image_normalization=True):
+    """
+    Dump random batches from a list of image paths and labels.
+
+    The list of files and the list of labels must be in the same order, e.g.:
+
+    filenames = ['class_1_img1', 'class_1_img2', 'class_2_img1']
+    labels = [0, 0, 1]
+    
+
+    **Parameters**
+
+       filenames:
+          List containing the path of the images
+       
+       labels:
+          List containing the labels (needs to be in EXACT same order as filenames)
+          
+       data_shape:
+          Shape to which each image is reshaped after decoding, e.g. (height, width, channels)
+
+       data_type:
+          tf data type (https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types)
+     
+       batch_size:
+          Size of the batch
+          
+       epochs:
+           Number of epochs to be batched
+       
+       buffer_size:
+            Size of the shuffle bucket
+
+       gray_scale:
+          Convert to gray scale?
+          
+       output_shape:
+          If set, the image will be center-cropped (or padded) to this (height, width)
+
+       random_flip:
+          Randomly flip an image horizontally  (https://www.tensorflow.org/api_docs/python/tf/image/random_flip_left_right)
+
+       random_brightness:
+           Adjust the brightness of an RGB image by a random factor (https://www.tensorflow.org/api_docs/python/tf/image/random_brightness)
+
+       random_contrast:
+           Adjust the contrast of an RGB image by a random factor (https://www.tensorflow.org/api_docs/python/tf/image/random_contrast)
+
+       random_saturation:
+           Adjust the saturation of an RGB image by a random factor (https://www.tensorflow.org/api_docs/python/tf/image/random_saturation)
+
+       per_image_normalization:
+           Linearly scales image to have zero mean and unit norm.            
+     
+    """                            
+
+    dataset = create_dataset_from_path_augmentation(filenames, labels, data_shape,
+                                          data_type,
+                                          gray_scale=gray_scale, 
+                                          output_shape=output_shape,
+                                          random_flip=random_flip,
+                                          random_brightness=random_brightness,
+                                          random_contrast=random_contrast,
+                                          random_saturation=random_saturation,
+                                          per_image_normalization=per_image_normalization)
+                                          
+    dataset = dataset.shuffle(buffer_size).batch(batch_size).repeat(epochs)
+
+    data, labels = dataset.make_one_shot_iterator().get_next()
+    return data, labels
+
+
+def create_dataset_from_path_augmentation(filenames, labels,
+                                          data_shape, data_type,
+                                          gray_scale=False, 
+                                          output_shape=None,
+                                          random_flip=False,
+                                          random_brightness=False,
+                                          random_contrast=False,
+                                          random_saturation=False,
+                                          per_image_normalization=True):
+    """
+    Create a dataset from a list of image file paths and their labels
+    
+    **Parameters**
+    
+       filenames:
+          List containing the path of the images
+       
+       labels:
+          List containing the labels (needs to be in EXACT same order as filenames)
+          
+       data_shape:
+          Shape to which each image is reshaped after decoding, e.g. (height, width, channels)
+
+       data_type:
+          tf data type (https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types)
+
+       The remaining keyword arguments are forwarded to append_image_augmentation.
+    
+    """
+ 
+    parser = partial(image_augmentation_parser,
+                     data_shape=data_shape,
+                     data_type=data_type,
+                     gray_scale=gray_scale, 
+                     output_shape=output_shape,
+                     random_flip=random_flip,
+                     random_brightness=random_brightness,
+                     random_contrast=random_contrast,
+                     random_saturation=random_saturation,
+                     per_image_normalization=per_image_normalization) 
+
+    dataset = tf.contrib.data.Dataset.from_tensor_slices((filenames, labels))
+    dataset = dataset.map(parser)
+    return dataset
+
+
+def image_augmentation_parser(filename, label, data_shape, data_type,
+                              gray_scale=False, 
+                              output_shape=None,
+                              random_flip=False,
+                              random_brightness=False,
+                              random_contrast=False,
+                              random_saturation=False,
+                              per_image_normalization=True):
+
+    """
+    Parses a single tf.Example into image and label tensors.
+    """
+        
+    # Convert the image data from string back to the numbers
+    image = tf.cast(tf.image.decode_image(tf.read_file(filename)), tf.float32)
+
+    # Reshape image data into the original shape
+    image = tf.reshape(image, data_shape)
+    
+    #Applying image augmentation
+    image = append_image_augmentation(image, gray_scale=gray_scale,
+                                      output_shape=output_shape,
+                                      random_flip=random_flip,
+                                      random_brightness=random_brightness,
+                                      random_contrast=random_contrast,
+                                      random_saturation=random_saturation,
+                                      per_image_normalization=per_image_normalization)
+                                        
+    label = tf.cast(label, tf.int64)
+
+    return image, label
+
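
A minimal sketch of how the new image pipeline is meant to be driven; the file names, labels and sample geometry below are placeholders:

    import tensorflow as tf
    from bob.learn.tensorflow.dataset.image import shuffle_data_and_labels_image_augmentation

    filenames = ['class_1_img1.png', 'class_1_img2.png', 'class_2_img1.png']
    labels = [0, 0, 1]                 # same order as filenames
    data_shape = (125, 125, 3)
    data_type = tf.float32

    def input_fn():
        # yields shuffled, augmented (images, labels) batches
        return shuffle_data_and_labels_image_augmentation(
            filenames, labels, data_shape, data_type,
            batch_size=2, epochs=10,
            random_flip=True,
            per_image_normalization=True)

The returned tensors can be fed to a tf.estimator-style trainer as its input_fn, as the tests below do for the tfrecord variant.
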
diff --git a/bob/learn/tensorflow/dataset/tfrecords.py b/bob/learn/tensorflow/dataset/tfrecords.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e4713110e265dbe6969899d680c6d668301ea8f
--- /dev/null
+++ b/bob/learn/tensorflow/dataset/tfrecords.py
@@ -0,0 +1,281 @@
+from functools import partial
+import tensorflow as tf
+from . import append_image_augmentation, DEFAULT_FEATURE
+
+
+def example_parser(serialized_example, feature, data_shape, data_type):
+    """
+    Parses a single tf.Example into image and label tensors.
+    
+    """
+    # Decode the record read by the reader
+    features = tf.parse_single_example(serialized_example, features=feature)
+    # Convert the image data from string back to the numbers
+    image = tf.decode_raw(features['train/data'], data_type)
+    # Cast label data into int64
+    label = tf.cast(features['train/label'], tf.int64)
+    # Reshape image data into the original shape
+    image = tf.reshape(image, data_shape)
+    return image, label
+
+
+def image_augmentation_parser(serialized_example, feature, data_shape, data_type,
+                              gray_scale=False, 
+                              output_shape=None,
+                              random_flip=False,
+                              random_brightness=False,
+                              random_contrast=False,
+                              random_saturation=False,
+                              per_image_normalization=True):
+
+    """
+    Parses a single tf.Example into image and label tensors.
+    
+    """
+    # Decode the record read by the reader
+    features = tf.parse_single_example(serialized_example, features=feature)
+    # Convert the image data from string back to the numbers
+    image = tf.decode_raw(features['train/data'], data_type)
+
+    # Reshape image data into the original shape
+    image = tf.reshape(image, data_shape)
+    
+    #Applying image augmentation
+    image = append_image_augmentation(image, gray_scale=gray_scale,
+                                      output_shape=output_shape,
+                                      random_flip=random_flip,
+                                      random_brightness=random_brightness,
+                                      random_contrast=random_contrast,
+                                      random_saturation=random_saturation,
+                                      per_image_normalization=per_image_normalization)
+    
+    # Cast label data into int64
+    label = tf.cast(features['train/label'], tf.int64)
+    return image, label
+
+
+def read_and_decode(filename_queue, data_shape, data_type=tf.float32,
+                    feature=None):
+                    
+    """
+    Simplest possible parsing of a tfrecord.
+    It assumes that the record holds the pair **train/data** and **train/label**
+    """
+                    
+    if feature is None:
+        feature = DEFAULT_FEATURE
+    # Define a reader and read the next record
+    reader = tf.TFRecordReader()
+    _, serialized_example = reader.read(filename_queue)
+    return example_parser(serialized_example, feature, data_shape, data_type)
+
+
+def create_dataset_from_records(tfrecord_filenames, data_shape, data_type,
+                                feature=None):
+    """
+    Create dataset from a list of tf-record files
+    
+    **Parameters**
+    
+       tfrecord_filenames: 
+          List containing the tf-record paths
+
+       data_shape:
+          Samples shape saved in the tf-record
+          
+       data_type:
+          tf data type (https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types)
+
+       feature:
+          Dict mapping tf-record keys to tf.FixedLenFeature entries; defaults to DEFAULT_FEATURE
+    """
+                                
+    if feature is None:
+        feature = DEFAULT_FEATURE
+    dataset = tf.contrib.data.TFRecordDataset(tfrecord_filenames)
+    parser = partial(example_parser, feature=feature, data_shape=data_shape,
+                     data_type=data_type)
+    dataset = dataset.map(parser)
+    return dataset
+
+
+def create_dataset_from_records_with_augmentation(tfrecord_filenames, data_shape, data_type,
+                                feature=None,
+                                gray_scale=False, 
+                                output_shape=None,
+                                random_flip=False,
+                                random_brightness=False,
+                                random_contrast=False,
+                                random_saturation=False,
+                                per_image_normalization=True):
+    """
+    Create a dataset from a list of tf-record files, applying image augmentation on the fly
+    
+    **Parameters**
+    
+       tfrecord_filenames: 
+          List containing the tf-record paths
+
+       data_shape:
+          Samples shape saved in the tf-record
+          
+       data_type:
+          tf data type (https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types)
+
+       feature:
+          Dict mapping tf-record keys to tf.FixedLenFeature entries; defaults to DEFAULT_FEATURE
+    
+    """
+
+    if feature is None:
+        feature = DEFAULT_FEATURE
+    dataset = tf.contrib.data.TFRecordDataset(tfrecord_filenames)
+    parser = partial(image_augmentation_parser, feature=feature, data_shape=data_shape,
+                     data_type=data_type,
+                     gray_scale=gray_scale, 
+                     output_shape=output_shape,
+                     random_flip=random_flip,
+                     random_brightness=random_brightness,
+                     random_contrast=random_contrast,
+                     random_saturation=random_saturation,
+                     per_image_normalization=per_image_normalization)
+    dataset = dataset.map(parser)
+    return dataset
+
+
+def shuffle_data_and_labels_image_augmentation(tfrecord_filenames, data_shape, data_type,
+                                              batch_size, epochs=None, buffer_size=10**3,
+                                              gray_scale=False, 
+                                              output_shape=None,
+                                              random_flip=False,
+                                              random_brightness=False,
+                                              random_contrast=False,
+                                              random_saturation=False,
+                                              per_image_normalization=True):
+    """
+    Dump random batches from a list of tf-record files, applying image augmentation on the fly
+
+    **Parameters**
+
+       tfrecord_filenames: 
+          List containing the tf-record paths
+
+       data_shape:
+          Samples shape saved in the tf-record
+          
+       data_type:
+          tf data type (https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types)
+     
+       batch_size:
+          Size of the batch
+          
+       epochs:
+           Number of epochs to be batched
+       
+       buffer_size:
+            Size of the shuffle bucket
+
+       gray_scale:
+          Convert to gray scale?
+          
+       output_shape:
+          If set, the image will be center-cropped (or padded) to this (height, width)
+
+       random_flip:
+          Randomly flip an image horizontally  (https://www.tensorflow.org/api_docs/python/tf/image/random_flip_left_right)
+
+       random_brightness:
+           Adjust the brightness of an RGB image by a random factor (https://www.tensorflow.org/api_docs/python/tf/image/random_brightness)
+
+       random_contrast:
+           Adjust the contrast of an RGB image by a random factor (https://www.tensorflow.org/api_docs/python/tf/image/random_contrast)
+
+       random_saturation:
+           Adjust the saturation of an RGB image by a random factor (https://www.tensorflow.org/api_docs/python/tf/image/random_saturation)
+
+       per_image_normalization:
+           Linearly scales image to have zero mean and unit norm.            
+     
+    """                            
+
+    dataset = create_dataset_from_records_with_augmentation(tfrecord_filenames, data_shape,
+                                          data_type,
+                                          gray_scale=gray_scale, 
+                                          output_shape=output_shape,
+                                          random_flip=random_flip,
+                                          random_brightness=random_brightness,
+                                          random_contrast=random_contrast,
+                                          random_saturation=random_saturation,
+                                          per_image_normalization=per_image_normalization)
+                                          
+    dataset = dataset.shuffle(buffer_size).batch(batch_size).repeat(epochs)
+
+    data, labels = dataset.make_one_shot_iterator().get_next()
+    return data, labels
+
+
+def shuffle_data_and_labels(tfrecord_filenames, data_shape, data_type,
+                            batch_size, epochs=None, buffer_size=10**3):
+    """
+    Dump random batches from a list of tf-record files
+
+    **Parameters**
+
+       tfrecord_filenames: 
+          List containing the tf-record paths
+
+       data_shape:
+          Samples shape saved in the tf-record
+          
+       data_type:
+          tf data type (https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types)
+     
+       batch_size:
+          Size of the batch
+          
+       epochs:
+           Number of epochs to be batched
+       
+       buffer_size:
+            Size of the shuffle bucket
+     
+    """                            
+
+    dataset = create_dataset_from_records(tfrecord_filenames, data_shape,
+                                          data_type)
+    dataset = dataset.shuffle(buffer_size).batch(batch_size).repeat(epochs)
+
+    data, labels = dataset.make_one_shot_iterator().get_next()
+    return data, labels
+
+
+def batch_data_and_labels(tfrecord_filenames, data_shape, data_type,
+                          batch_size, epochs=1):
+    """
+    Dump batches, in their original order, from a list of tf-record files
+
+    **Parameters**
+
+       tfrecord_filenames: 
+          List containing the tf-record paths
+
+       data_shape:
+          Samples shape saved in the tf-record
+          
+       data_type:
+          tf data type (https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types)
+     
+       batch_size:
+          Size of the batch
+          
+       epochs:
+           Number of epochs to be batched
+     
+    """                             
+    dataset = create_dataset_from_records(tfrecord_filenames, data_shape,
+                                          data_type)
+    dataset = dataset.batch(batch_size).repeat(epochs)
+
+    data, labels = dataset.make_one_shot_iterator().get_next()
+    return data, labels
+
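
A minimal sketch of the intended train/eval wiring with these tfrecord helpers; record names, shapes, dtypes and batch sizes below are placeholders:

    import tensorflow as tf
    from bob.learn.tensorflow.dataset.tfrecords import (
        shuffle_data_and_labels_image_augmentation, batch_data_and_labels)

    train_records = ['train_mnist.tfrecord']            # hypothetical record files
    validation_records = ['validation_mnist.tfrecord']
    data_shape = (28, 28, 1)
    data_type = tf.uint8                                # dtype the samples were serialized with

    def train_input_fn():
        # shuffled, augmented batches; epochs=None repeats indefinitely
        return shuffle_data_and_labels_image_augmentation(
            train_records, data_shape, data_type,
            batch_size=32, epochs=None,
            random_flip=True)

    def eval_input_fn():
        # deterministic, in-order batches for evaluation
        return batch_data_and_labels(validation_records, data_shape, data_type,
                                     batch_size=250, epochs=1)
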
diff --git a/bob/learn/tensorflow/test/test_estimator_scripts.py b/bob/learn/tensorflow/test/test_estimator_scripts.py
index 38a8d89031f2da2eef92f3cbc73c835c52f2c083..d45c847fc9a92d5b6c41a11d1ca28d680ead09c3 100644
--- a/bob/learn/tensorflow/test/test_estimator_scripts.py
+++ b/bob/learn/tensorflow/test/test_estimator_scripts.py
@@ -14,7 +14,7 @@ from bob.learn.tensorflow.script.eval_generic import main as eval_generic
 dummy_tfrecord_config = datafile('dummy_verify_config.py', __name__)
 CONFIG = '''
 import tensorflow as tf
-from bob.learn.tensorflow.utils.tfrecords import shuffle_data_and_labels, \
+from bob.learn.tensorflow.dataset.tfrecords import shuffle_data_and_labels, \
     batch_data_and_labels
 
 model_dir = "%(model_dir)s"
diff --git a/bob/learn/tensorflow/test/test_onegraph_estimator.py b/bob/learn/tensorflow/test/test_onegraph_estimator.py
new file mode 100755
index 0000000000000000000000000000000000000000..33cadd805f3b55f9e8866302f74d561913e1b465
--- /dev/null
+++ b/bob/learn/tensorflow/test/test_onegraph_estimator.py
@@ -0,0 +1,187 @@
+#!/usr/bin/env python
+# vim: set fileencoding=utf-8 :
+# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
+
+import tensorflow as tf
+
+from bob.learn.tensorflow.network import dummy
+from bob.learn.tensorflow.trainers import LogitsTrainer, LogitsCenterLossTrainer
+
+from bob.learn.tensorflow.dataset.tfrecords import shuffle_data_and_labels, batch_data_and_labels, shuffle_data_and_labels_image_augmentation
+
+
+from bob.learn.tensorflow.dataset import append_image_augmentation
+from bob.learn.tensorflow.utils import load_mnist, create_mnist_tfrecord
+from bob.learn.tensorflow.utils.hooks import LoggerHookEstimator
+from bob.learn.tensorflow.utils import reproducible
+from bob.learn.tensorflow.loss import mean_cross_entropy_loss
+
+import numpy
+
+import shutil
+import os
+
+
+tfrecord_train = "./train_mnist.tfrecord"
+tfrecord_validation = "./validation_mnist.tfrecord"    
+model_dir = "./temp"
+
+learning_rate = 0.1
+data_shape = (28, 28, 1)  # size of the MNIST images
+data_type = tf.float32
+batch_size = 16
+validation_batch_size = 250
+epochs = 1
+steps = 5000
+
+
+def test_logitstrainer():
+    # Trainer logits
+    try:
+        embedding_validation = False
+        trainer = LogitsTrainer(model_dir=model_dir,
+                                architecture=dummy,
+                                optimizer=tf.train.GradientDescentOptimizer(learning_rate),
+                                n_classes=10,
+                                loss_op=mean_cross_entropy_loss,
+                                embedding_validation=embedding_validation,
+                                validation_batch_size=validation_batch_size)
+        run_logitstrainer_mnist(trainer, augmentation=True)
+    finally:
+        try:
+            os.unlink(tfrecord_train)
+            os.unlink(tfrecord_validation)
+            shutil.rmtree(model_dir, ignore_errors=True)
+        except Exception:
+            pass        
+
+
+def test_logitstrainer_embedding():
+    try:
+        embedding_validation = True
+        trainer = LogitsTrainer(model_dir=model_dir,
+                                architecture=dummy,
+                                optimizer=tf.train.GradientDescentOptimizer(learning_rate),
+                                n_classes=10,
+                                loss_op=mean_cross_entropy_loss,
+                                embedding_validation=embedding_validation,
+                                validation_batch_size=validation_batch_size)    
+        run_logitstrainer_mnist(trainer)
+    finally:
+        try:
+            os.unlink(tfrecord_train)
+            os.unlink(tfrecord_validation)
+            shutil.rmtree(model_dir, ignore_errors=True)
+        except Exception:
+            pass        
+
+
+def test_logitstrainer_centerloss():
+
+    try:
+        embedding_validation = False
+        run_config = tf.estimator.RunConfig()
+        run_config = run_config.replace(save_checkpoints_steps=1000)
+        trainer = LogitsCenterLossTrainer(
+                                model_dir=model_dir,
+                                architecture=dummy,
+                                optimizer=tf.train.GradientDescentOptimizer(learning_rate),
+                                n_classes=10,
+                                embedding_validation=embedding_validation,
+                                validation_batch_size=validation_batch_size,
+                                factor=0.01,
+                                config=run_config)
+                                
+        run_logitstrainer_mnist(trainer)
+
+        # Checking if the centers were updated
+        sess = tf.Session()
+        checkpoint_path = tf.train.get_checkpoint_state(model_dir).model_checkpoint_path
+        saver = tf.train.import_meta_graph(checkpoint_path + ".meta", clear_devices=True)
+        saver.restore(sess, tf.train.latest_checkpoint(model_dir))
+        centers = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="center_loss/centers:0")[0]
+        assert numpy.sum(numpy.abs(centers.eval(sess))) > 0.0    
+
+        
+    finally:
+        try:
+            os.unlink(tfrecord_train)
+            os.unlink(tfrecord_validation)
+            shutil.rmtree(model_dir, ignore_errors=True)
+        except Exception:
+            pass
+
+
+def test_logitstrainer_centerloss_embedding():
+    try:
+        embedding_validation = True
+        trainer = LogitsCenterLossTrainer(
+                                model_dir=model_dir,
+                                architecture=dummy,
+                                optimizer=tf.train.GradientDescentOptimizer(learning_rate),
+                                n_classes=10,
+                                embedding_validation=embedding_validation,
+                                validation_batch_size=validation_batch_size,
+                                factor=0.01)
+        run_logitstrainer_mnist(trainer)
+        
+        # Checking if the centers were updated
+        sess = tf.Session()
+        checkpoint_path = tf.train.get_checkpoint_state(model_dir).model_checkpoint_path
+        saver = tf.train.import_meta_graph(checkpoint_path + ".meta", clear_devices=True)
+        saver.restore(sess, tf.train.latest_checkpoint(model_dir))
+        centers = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="center_loss/centers:0")[0]
+        assert numpy.sum(numpy.abs(centers.eval(sess))) > 0.0    
+    finally:
+        try:
+            os.unlink(tfrecord_train)
+            os.unlink(tfrecord_validation)
+            shutil.rmtree(model_dir, ignore_errors=True)
+        except Exception:
+            pass        
+
+
+def run_logitstrainer_mnist(trainer, augmentation=False):
+
+    # Cleaning up
+    tf.reset_default_graph()
+    assert len(tf.global_variables()) == 0
+
+    # Creating tf records for mnist
+    train_data, train_labels, validation_data, validation_labels = load_mnist()
+    create_mnist_tfrecord(tfrecord_train, train_data, train_labels, n_samples=6000)
+    create_mnist_tfrecord(tfrecord_validation, validation_data, validation_labels, n_samples=validation_batch_size)
+
+    def input_fn():
+    
+        if augmentation:
+            return shuffle_data_and_labels_image_augmentation(tfrecord_train, data_shape, data_type, batch_size, epochs=epochs)
+        else:
+            return shuffle_data_and_labels(tfrecord_train, data_shape, data_type,
+                                           batch_size, epochs=epochs)
+        
+
+    def input_fn_validation():
+        return batch_data_and_labels(tfrecord_validation, data_shape, data_type,
+                                     validation_batch_size, epochs=1000)
+    
+    hooks = [LoggerHookEstimator(trainer, 16, 300),
+             tf.train.SummarySaverHook(save_steps=1000,
+                                       output_dir=model_dir,
+                                       scaffold=tf.train.Scaffold(),
+                                       summary_writer=tf.summary.FileWriter(model_dir))]
+
+    trainer.train(input_fn, steps=steps, hooks=hooks)
+
+    # the accuracy check is currently the same for logits- and embedding-based validation
+    acc = trainer.evaluate(input_fn_validation)
+    assert acc['accuracy'] > 0.80
+
+    # Cleaning up
+    tf.reset_default_graph()
+    assert len(tf.global_variables()) == 0
+
diff --git a/bob/learn/tensorflow/test/test_onegraph_model_fn.py b/bob/learn/tensorflow/test/test_onegraph_model_fn.py
deleted file mode 100755
index 2b73e88a5035bc784777ac7787b2240ddf4387e0..0000000000000000000000000000000000000000
--- a/bob/learn/tensorflow/test/test_onegraph_model_fn.py
+++ /dev/null
@@ -1,162 +0,0 @@
-#!/usr/bin/env python
-# vim: set fileencoding=utf-8 :
-# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
-
-import tensorflow as tf
-
-from bob.learn.tensorflow.network import dummy
-from bob.learn.tensorflow.trainers import LogitsTrainer, LogitsCenterLossTrainer
-from bob.learn.tensorflow.utils.tfrecords import shuffle_data_and_labels, batch_data_and_labels
-from bob.learn.tensorflow.utils import load_mnist, create_mnist_tfrecord
-from bob.learn.tensorflow.utils.hooks import LoggerHookEstimator
-from bob.learn.tensorflow.loss import mean_cross_entropy_loss
-import numpy
-
-import shutil
-import os
-
-
-tfrecord_train = "./train_mnist.tfrecord"
-tfrecord_validation = "./validation_mnist.tfrecord"    
-model_dir = "./temp"
-
-learning_rate = 0.1
-data_shape = (28, 28, 1)  # size of atnt images
-data_type = tf.float32
-batch_size = 16
-validation_batch_size = 1000
-epochs = 1
-steps = 2000
-
-
-def test_logitstrainer():
-    run_logitstrainer(False)
-
-
-def test_logitstrainer_embedding():
-    run_logitstrainer(True)
-
-
-def test_logitstrainer_centerloss():
-    run_logitstrainer_centerloss(False)
-
-
-def test_logitstrainer_centerloss_embedding():
-    run_logitstrainer_centerloss(True)
-
-
-def run_logitstrainer(embedding_validation):
-
-    # Cleaning up
-    tf.reset_default_graph()
-    assert len(tf.global_variables()) == 0
-
-    # Creating tf records for mnist
-    train_data, train_labels, validation_data, validation_labels = load_mnist()
-    create_mnist_tfrecord(tfrecord_train, train_data, train_labels, n_samples=6000)
-    create_mnist_tfrecord(tfrecord_validation, validation_data, validation_labels, n_samples=1000)
-
-    try:
-        
-        # Trainer logits
-        trainer = LogitsTrainer(model_dir=model_dir,
-                                architecture=dummy,
-                                optimizer=tf.train.GradientDescentOptimizer(learning_rate),
-                                n_classes=10,
-                                loss_op=mean_cross_entropy_loss,
-                                embedding_validation=embedding_validation,
-                                validation_batch_size=validation_batch_size
-                                )
-
-        def input_fn():
-            return shuffle_data_and_labels(tfrecord_train, data_shape, data_type,
-                                           batch_size, epochs=epochs)
-                                       
-        def input_fn_validation():
-            return batch_data_and_labels(tfrecord_validation, data_shape, data_type,
-                                         validation_batch_size, epochs=epochs)                                       
-
-        hooks = [LoggerHookEstimator(trainer, 16, 100)]
-        trainer.train(input_fn, steps=steps, hooks=hooks)
-        
-        if not embedding_validation:
-
-            acc = trainer.evaluate(input_fn_validation)
-            assert acc['accuracy'] > 0.80
-        else:
-            acc = trainer.evaluate(input_fn_validation)
-            assert acc['accuracy'] > 0.80
-
-    finally:
-        try:
-            os.unlink(tfrecord_train)
-            os.unlink(tfrecord_validation)            
-            shutil.rmtree(model_dir)
-        except Exception:
-            pass
-
-    # Cleaning up
-    tf.reset_default_graph()
-    assert len(tf.global_variables()) == 0
-
-
-def run_logitstrainer_centerloss(embedding_validation):
-
-    # Cleaning up
-    tf.reset_default_graph()
-    assert len(tf.global_variables()) == 0
-
-    # Creating tf records for mnist
-    train_data, train_labels, validation_data, validation_labels = load_mnist()
-    create_mnist_tfrecord(tfrecord_train, train_data, train_labels, n_samples=6000)
-    create_mnist_tfrecord(tfrecord_validation, validation_data, validation_labels, n_samples=1000)
-
-    try:
-
-        # Trainer logits
-        trainer = LogitsCenterLossTrainer(
-                                model_dir=model_dir,
-                                architecture=dummy,
-                                optimizer=tf.train.GradientDescentOptimizer(learning_rate),
-                                n_classes=10,
-                                embedding_validation=embedding_validation,
-                                validation_batch_size=validation_batch_size,
-                                factor=0.01
-                                )
-
-        def input_fn():
-            return shuffle_data_and_labels(tfrecord_train, data_shape, data_type,
-                                           batch_size, epochs=epochs)
-
-        def input_fn_validation():
-            return batch_data_and_labels(tfrecord_validation, data_shape, data_type,
-                                         validation_batch_size, epochs=epochs)
-
-        hooks = [LoggerHookEstimator(trainer, 16, 100)]
-        trainer.train(input_fn, steps=steps, hooks=hooks)
-
-        if not embedding_validation:
-            acc = trainer.evaluate(input_fn_validation)
-            assert acc['accuracy'] > 0.80
-        else:
-            acc = trainer.evaluate(input_fn_validation)
-            assert acc['accuracy'] > 0.80
-
-        sess = tf.Session()
-        checkpoint_path = tf.train.get_checkpoint_state(model_dir).model_checkpoint_path
-        saver = tf.train.import_meta_graph(checkpoint_path + ".meta", clear_devices=True)
-        saver.restore(sess, tf.train.latest_checkpoint(model_dir))
-        centers = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="center_loss/centers:0")[0]
-        assert numpy.sum(numpy.abs(centers.eval(sess))) > 0.0
-
-    finally:
-        try:
-            os.unlink(tfrecord_train)
-            os.unlink(tfrecord_validation)
-            shutil.rmtree(model_dir)
-        except Exception:
-            pass
-
-    # Cleaning up
-    tf.reset_default_graph()
-    assert len(tf.global_variables()) == 0
diff --git a/bob/learn/tensorflow/utils/tfrecords.py b/bob/learn/tensorflow/utils/tfrecords.py
deleted file mode 100644
index 48da0740577c2a64e4e6f59b26dac959e0a0678f..0000000000000000000000000000000000000000
--- a/bob/learn/tensorflow/utils/tfrecords.py
+++ /dev/null
@@ -1,60 +0,0 @@
-from functools import partial
-import tensorflow as tf
-
-
-DEFAULT_FEATURE = {'train/data': tf.FixedLenFeature([], tf.string),
-                   'train/label': tf.FixedLenFeature([], tf.int64)}
-
-
-def example_parser(serialized_example, feature, data_shape, data_type):
-    """Parses a single tf.Example into image and label tensors."""
-    # Decode the record read by the reader
-    features = tf.parse_single_example(serialized_example, features=feature)
-    # Convert the image data from string back to the numbers
-    image = tf.decode_raw(features['train/data'], data_type)
-    # Cast label data into int64
-    label = tf.cast(features['train/label'], tf.int64)
-    # Reshape image data into the original shape
-    image = tf.reshape(image, data_shape)
-    return image, label
-
-
-def read_and_decode(filename_queue, data_shape, data_type=tf.float32,
-                    feature=None):
-    if feature is None:
-        feature = DEFAULT_FEATURE
-    # Define a reader and read the next record
-    reader = tf.TFRecordReader()
-    _, serialized_example = reader.read(filename_queue)
-    return example_parser(serialized_example, feature, data_shape, data_type)
-
-
-def create_dataset_from_records(tfrecord_filenames, data_shape, data_type,
-                                feature=None):
-    if feature is None:
-        feature = DEFAULT_FEATURE
-    dataset = tf.contrib.data.TFRecordDataset(tfrecord_filenames)
-    parser = partial(example_parser, feature=feature, data_shape=data_shape,
-                     data_type=data_type)
-    dataset = dataset.map(parser)
-    return dataset
-
-
-def shuffle_data_and_labels(tfrecord_filenames, data_shape, data_type,
-                            batch_size, epochs=None, buffer_size=10**3):
-    dataset = create_dataset_from_records(tfrecord_filenames, data_shape,
-                                          data_type)
-    dataset = dataset.shuffle(buffer_size).batch(batch_size).repeat(epochs)
-
-    datas, labels = dataset.make_one_shot_iterator().get_next()
-    return datas, labels
-
-
-def batch_data_and_labels(tfrecord_filenames, data_shape, data_type,
-                          batch_size, epochs=1):
-    dataset = create_dataset_from_records(tfrecord_filenames, data_shape,
-                                          data_type)
-    dataset = dataset.batch(batch_size).repeat(epochs)
-
-    datas, labels = dataset.make_one_shot_iterator().get_next()
-    return datas, labels