diff --git a/bob/learn/tensorflow/examples/mnist/mnist_config.py b/bob/learn/tensorflow/examples/mnist/mnist_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..6227fcd462822776320db3c4e24a646f3ef2c721
--- /dev/null
+++ b/bob/learn/tensorflow/examples/mnist/mnist_config.py
@@ -0,0 +1,204 @@
+#  Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+"""Convolutional Neural Network Estimator for MNIST, built with tf.layers."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from bob.learn.tensorflow.utils.reproducible import session_conf
+import tensorflow as tf
+
+model_dir = '/tmp/mnist_model'
+train_tfrecords = ['/tmp/mnist_data/train.tfrecords']
+eval_tfrecords = ['/tmp/mnist_data/test.tfrecords']
+
+# by default create reproducible nets:
+run_config = tf.estimator.RunConfig()
+run_config = run_config.replace(session_config=session_conf)
+run_config = run_config.replace(keep_checkpoint_max=10**3)
+run_config = run_config.replace(save_checkpoints_secs=60)
+
+
+def input_fn(mode, batch_size=1):
+    """A simple input_fn using the contrib.data input pipeline."""
+
+    def example_parser(serialized_example):
+        """Parses a single tf.Example into image and label tensors."""
+        features = tf.parse_single_example(
+            serialized_example,
+            features={
+                'image_raw': tf.FixedLenFeature([], tf.string),
+                'label': tf.FixedLenFeature([], tf.int64),
+            })
+        image = tf.decode_raw(features['image_raw'], tf.uint8)
+        image.set_shape([28 * 28])
+
+        # Normalize the values of the image from the range
+        # [0, 255] to [-0.5, 0.5]
+        image = tf.cast(image, tf.float32) / 255 - 0.5
+        label = tf.cast(features['label'], tf.int32)
+        return image, tf.one_hot(label, 10)
+
+    if mode == tf.estimator.ModeKeys.TRAIN:
+        tfrecords_files = train_tfrecords
+    else:
+        assert mode == tf.estimator.ModeKeys.EVAL, 'invalid mode'
+        tfrecords_files = eval_tfrecords
+
+    for tfrecords_file in tfrecords_files:
+        assert tf.gfile.Exists(tfrecords_file), (
+            'Run github.com:tensorflow/models/official/mnist/'
+            'convert_to_records.py first to convert the MNIST data to '
+            'TFRecord file format.')
+
+    dataset = tf.contrib.data.TFRecordDataset(tfrecords_files)
+
+    # For training, repeat the dataset forever
+    if mode == tf.estimator.ModeKeys.TRAIN:
+        dataset = dataset.repeat()
+
+    # Map example_parser over dataset, and batch results by up to batch_size
+    dataset = dataset.map(
+        example_parser, num_threads=1, output_buffer_size=batch_size)
+    dataset = dataset.batch(batch_size)
+    images, labels = dataset.make_one_shot_iterator().get_next()
+
+    return images, labels
+
+
+def train_input_fn():
+    return input_fn(tf.estimator.ModeKeys.TRAIN)
+
+
+def eval_input_fn():
+    return input_fn(tf.estimator.ModeKeys.EVAL)
+
+
+def mnist_model(inputs, mode):
+    """Takes the MNIST inputs and mode and outputs a tensor of logits."""
+    # Input Layer
+    # Reshape X to 4-D tensor: [batch_size, width, height, channels]
+    # MNIST images are 28x28 pixels, and have one color channel
+    inputs = tf.reshape(inputs, [-1, 28, 28, 1])
+    data_format = 'channels_last'
+
+    if tf.test.is_built_with_cuda():
+        # When running on GPU, transpose the data from channels_last (NHWC) to
+        # channels_first (NCHW) to improve performance. See
+        # https://www.tensorflow.org/performance/performance_guide#data_formats
+        data_format = 'channels_first'
+        inputs = tf.transpose(inputs, [0, 3, 1, 2])
+
+    # Convolutional Layer #1
+    # Computes 32 features using a 5x5 filter with ReLU activation.
+    # Padding is added to preserve width and height.
+    # Input Tensor Shape: [batch_size, 28, 28, 1]
+    # Output Tensor Shape: [batch_size, 28, 28, 32]
+    conv1 = tf.layers.conv2d(
+        inputs=inputs,
+        filters=32,
+        kernel_size=[5, 5],
+        padding='same',
+        activation=tf.nn.relu,
+        data_format=data_format)
+
+    # Pooling Layer #1
+    # First max pooling layer with a 2x2 filter and stride of 2
+    # Input Tensor Shape: [batch_size, 28, 28, 32]
+    # Output Tensor Shape: [batch_size, 14, 14, 32]
+    pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2,
+                                    data_format=data_format)
+
+    # Convolutional Layer #2
+    # Computes 64 features using a 5x5 filter.
+    # Padding is added to preserve width and height.
+    # Input Tensor Shape: [batch_size, 14, 14, 32]
+    # Output Tensor Shape: [batch_size, 14, 14, 64]
+    conv2 = tf.layers.conv2d(
+        inputs=pool1,
+        filters=64,
+        kernel_size=[5, 5],
+        padding='same',
+        activation=tf.nn.relu,
+        data_format=data_format)
+
+    # Pooling Layer #2
+    # Second max pooling layer with a 2x2 filter and stride of 2
+    # Input Tensor Shape: [batch_size, 14, 14, 64]
+    # Output Tensor Shape: [batch_size, 7, 7, 64]
+    pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2,
+                                    data_format=data_format)
+
+    # Flatten tensor into a batch of vectors
+    # Input Tensor Shape: [batch_size, 7, 7, 64]
+    # Output Tensor Shape: [batch_size, 7 * 7 * 64]
+    pool2_flat = tf.reshape(pool2, [-1, 7 * 7 * 64])
+
+    # Dense Layer
+    # Densely connected layer with 1024 neurons
+    # Input Tensor Shape: [batch_size, 7 * 7 * 64]
+    # Output Tensor Shape: [batch_size, 1024]
+    dense = tf.layers.dense(inputs=pool2_flat, units=1024,
+                            activation=tf.nn.relu)
+
+    # Add dropout operation; 0.6 probability that element will be kept
+    dropout = tf.layers.dropout(
+        inputs=dense, rate=0.4, training=(mode == tf.estimator.ModeKeys.TRAIN))
+
+    # Logits layer
+    # Input Tensor Shape: [batch_size, 1024]
+    # Output Tensor Shape: [batch_size, 10]
+    logits = tf.layers.dense(inputs=dropout, units=10)
+    return logits
+
+
+def model_fn(features, labels, mode):
+    """Model function for MNIST."""
+    logits = mnist_model(features, mode)
+
+    predictions = {
+        'classes': tf.argmax(input=logits, axis=1),
+        'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
+    }
+
+    if mode == tf.estimator.ModeKeys.PREDICT:
+        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
+
+    loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)
+
+    # Configure the training op
+    if mode == tf.estimator.ModeKeys.TRAIN:
+        optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)
+        train_op = optimizer.minimize(
+            loss, tf.train.get_or_create_global_step())
+    else:
+        train_op = None
+
+    accuracy = tf.metrics.accuracy(
+        tf.argmax(labels, axis=1), predictions['classes'])
+    metrics = {'accuracy': accuracy}
+
+    with tf.name_scope('train_metrics'):
+        # Create a tensor named train_accuracy for logging purposes
+        tf.summary.scalar('train_accuracy', accuracy[1])
+
+        tf.summary.scalar('train_loss', loss)
+
+    return tf.estimator.EstimatorSpec(
+        mode=mode,
+        predictions=predictions,
+        loss=loss,
+        train_op=train_op,
+        eval_metric_ops=metrics)
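+
+
+# A minimal sketch of how this configuration is meant to be consumed, assuming
+# the console scripts added in this change are installed:
+#
+#   bob_tf_train_generic mnist_config.py
+#   bob_tf_eval_generic mnist_config.py
+#
+# ``train_generic`` builds a ``tf.estimator.Estimator`` from ``model_fn`` and
+# ``model_dir`` and calls ``train(input_fn=train_input_fn)``, while
+# ``eval_generic`` repeatedly evaluates the checkpoints written to
+# ``model_dir`` using ``eval_input_fn``.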
diff --git a/bob/learn/tensorflow/network/SimpleCNN.py b/bob/learn/tensorflow/network/SimpleCNN.py
new file mode 100644
index 0000000000000000000000000000000000000000..01bec65179bc7c5f96365afb4b2e403bcaff0782
--- /dev/null
+++ b/bob/learn/tensorflow/network/SimpleCNN.py
@@ -0,0 +1,107 @@
+import tensorflow as tf
+
+
+def architecture(input_layer, mode=tf.estimator.ModeKeys.TRAIN):
+    # TODO: figure out a way to accept different input sizes
+
+    # Convolutional Layer #1
+    # Computes 32 features using a 5x5 filter with ReLU activation.
+    # Padding is added to preserve width and height.
+    # Input Tensor Shape: [batch_size, 50, 1024, 1]
+    # Output Tensor Shape: [batch_size, 50, 1024, 32]
+    conv1 = tf.layers.conv2d(
+        inputs=input_layer,
+        filters=32,
+        kernel_size=[5, 5],
+        padding="same",
+        activation=tf.nn.relu)
+
+    # Pooling Layer #1
+    # First max pooling layer with a 2x2 filter and stride of 2
+    # Input Tensor Shape: [batch_size, 50, 1024, 32]
+    # Output Tensor Shape: [batch_size, 25, 512, 32]
+    pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)
+
+    # Convolutional Layer #2
+    # Computes 64 features using a 5x5 filter.
+    # Padding is added to preserve width and height.
+    # Input Tensor Shape: [batch_size, 25, 512, 32]
+    # Output Tensor Shape: [batch_size, 25, 512, 64]
+    conv2 = tf.layers.conv2d(
+        inputs=pool1,
+        filters=64,
+        kernel_size=[5, 5],
+        padding="same",
+        activation=tf.nn.relu)
+
+    # Pooling Layer #2
+    # Second max pooling layer with a 2x2 filter and stride of 2
+    # Input Tensor Shape: [batch_size, 25, 512, 64]
+    # Output Tensor Shape: [batch_size, 12, 256, 64]
+    pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)
+
+    # Flatten tensor into a batch of vectors
+    # Input Tensor Shape: [batch_size, 12, 256, 64]
+    # Output Tensor Shape: [batch_size, 12 * 256 * 64]
+    pool2_flat = tf.reshape(pool2, [-1, 12 * 256 * 64])
+
+    # Dense Layer
+    # Densely connected layer with 1024 neurons
+    # Input Tensor Shape: [batch_size, 12 * 256 * 64]
+    # Output Tensor Shape: [batch_size, 1024]
+    dense = tf.layers.dense(
+        inputs=pool2_flat, units=1024, activation=tf.nn.relu)
+
+    # Add dropout operation; 0.6 probability that element will be kept
+    dropout = tf.layers.dropout(
+        inputs=dense, rate=0.4, training=mode == tf.estimator.ModeKeys.TRAIN)
+
+    # Logits layer
+    # Input Tensor Shape: [batch_size, 1024]
+    # Output Tensor Shape: [batch_size, 2]
+    logits = tf.layers.dense(inputs=dropout, units=2)
+
+    return logits
+
+
+def model_fn(features, labels, mode, params, config):
+    """Model function for CNN."""
+    params = params or {}
+    learning_rate = params.get('learning_rate', 0.00001)
+
+    logits = architecture(features, mode)
+
+    predictions = {
+        # Generate predictions (for PREDICT and EVAL mode)
+        "classes": tf.argmax(input=logits, axis=1),
+        # Add `softmax_tensor` to the graph. It is used for PREDICT and by the
+        # `logging_hook`.
+        "probabilities": tf.nn.softmax(logits, name="softmax_tensor")
+    }
+    if mode == tf.estimator.ModeKeys.PREDICT:
+        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
+
+    # Calculate Loss (for both TRAIN and EVAL modes)
+    loss = tf.losses.sparse_softmax_cross_entropy(
+        logits=logits, labels=labels)
+
+    with tf.name_scope('train_metrics'):
+        # Create a tensor named train_loss for logging purposes
+        tf.summary.scalar('train_loss', loss)
+
+    # Configure the Training Op (for TRAIN mode)
+    if mode == tf.estimator.ModeKeys.TRAIN:
+        optimizer = tf.train.GradientDescentOptimizer(
+            learning_rate=learning_rate)
+        train_op = optimizer.minimize(
+            loss=loss,
+            global_step=tf.train.get_global_step())
+        return tf.estimator.EstimatorSpec(
+            mode=mode, loss=loss, train_op=train_op)
+
+    # Add evaluation metrics (for EVAL mode)
+    eval_metric_ops = {
+        "accuracy": tf.metrics.accuracy(
+            labels=labels, predictions=predictions["classes"])}
+    return tf.estimator.EstimatorSpec(
+        mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)
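+
+
+# A minimal sketch (kept as a comment so nothing runs on import) of how this
+# model_fn could be used from a bob_tf_train_generic configuration file; the
+# path and hyper-parameter below are placeholders only:
+#
+#   from bob.learn.tensorflow.network.SimpleCNN import model_fn
+#   model_dir = '/tmp/simple_cnn'
+#   model_params = {'learning_rate': 1e-4}
+#   # train_input_fn must return (features, labels) where features has shape
+#   # [batch_size, 50, 1024, 1] and labels are integer labels in {0, 1}.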
diff --git a/bob/learn/tensorflow/network/__init__.py b/bob/learn/tensorflow/network/__init__.py
index cc09a91ef4f7f3d3830c3e018bc165b05df144bc..3226aa31d2fd87ca5d7f6011d3502e62431fde56 100755
--- a/bob/learn/tensorflow/network/__init__.py
+++ b/bob/learn/tensorflow/network/__init__.py
@@ -5,21 +5,24 @@ from .MLP import mlp
 from .Embedding import Embedding
 from .InceptionResnetV2 import inception_resnet_v2
 from .InceptionResnetV1 import inception_resnet_v1
+from . import SimpleCNN
 
 
 # gets sphinx autodoc done right - don't remove it
 def __appropriate__(*args):
-  """Says object was actually declared here, an not on the import module.
+    """Says object was actually declared here, an not on the import module.
 
-  Parameters:
+    Parameters:
 
-    *args: An iterable of objects to modify
+        *args: An iterable of objects to modify
 
-  Resolves `Sphinx referencing issues
-  <https://github.com/sphinx-doc/sphinx/issues/3048>`
-  """
+    Resolves `Sphinx referencing issues
+    <https://github.com/sphinx-doc/sphinx/issues/3048>`
+    """
+
+    for obj in args:
+        obj.__module__ = __name__
 
-  for obj in args: obj.__module__ = __name__
 
 __appropriate__(
     chopra,
@@ -28,5 +31,5 @@ __appropriate__(
     Embedding,
     mlp,
     )
-__all__ = [_ for _ in dir() if not _.startswith('_')]
 
+__all__ = [_ for _ in dir() if not _.startswith('_')]
diff --git a/bob/learn/tensorflow/script/db_to_tfrecords.py b/bob/learn/tensorflow/script/db_to_tfrecords.py
index 10cc5c1969e1e559ead6596dff111fe4444766a4..aa5169935749e1022c84cca1a2e3427548a6689d 100644
--- a/bob/learn/tensorflow/script/db_to_tfrecords.py
+++ b/bob/learn/tensorflow/script/db_to_tfrecords.py
@@ -8,9 +8,9 @@ Usage:
   %(prog)s --version
 
 Arguments:
-  <config_files>  The config files. The config files are loaded in order and
-                  they need to have several objects inside totally. See below
-                  for explanation.
+  <config_files>  The configuration files. The configuration files are loaded
+                  in order and, together, they need to provide several
+                  objects. See below for explanation.
 
 Options:
   -h --help  show this help message and exit
@@ -21,7 +21,7 @@ Idiap:
 
   $ jman submit -i -q q1d -- bin/python %(prog)s <config_files>...
 
-The config files should have the following objects totally:
+Together, the configuration files should provide the following objects:
 
   ## Required objects:
 
@@ -83,6 +83,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 import random
+# import pkg_resources so that bob imports work properly:
+import pkg_resources
 
 import tensorflow as tf
 from bob.io.base import create_directories_safe
@@ -91,17 +93,18 @@ from bob.core.log import setup, set_verbosity_level
 logger = setup(__name__)
 
 
-def _bytes_feature(value):
+def bytes_feature(value):
     return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
 
 
-def _int64_feature(value):
+def int64_feature(value):
     return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
 
 
-def write_a_sample(writer, data, label):
-    feature = {'train/data': _bytes_feature(data.tostring()),
-               'train/label': _int64_feature(label)}
+def write_a_sample(writer, data, label, feature=None):
+    if feature is None:
+        feature = {'train/data': bytes_feature(data.tostring()),
+                   'train/label': int64_feature(label)}
 
     example = tf.train.Example(features=tf.train.Features(feature=feature))
     writer.write(example.SerializeToString())
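+
+# Example (a sketch): the new ``feature`` argument lets callers override the
+# default serialized layout; the extra 'train/key' entry below is only a
+# hypothetical illustration of storing additional per-sample metadata:
+#
+#   feature = {'train/data': bytes_feature(data.tostring()),
+#              'train/label': int64_feature(label),
+#              'train/key': bytes_feature(key.encode('utf-8'))}
+#   write_a_sample(writer, data, label, feature=feature)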
@@ -111,7 +114,6 @@ def main(argv=None):
     from docopt import docopt
     import os
     import sys
-    import pkg_resources
     docs = __doc__ % {'prog': os.path.basename(sys.argv[0])}
     version = pkg_resources.require('bob.learn.tensorflow')[0].version
     args = docopt(docs, argv=argv, version=version)
@@ -148,8 +150,7 @@ def main(argv=None):
                 logger.info('Processing file %d out of %d', i + 1, n_files)
 
                 path = f.make_path(data_dir, data_extension)
-                data = reader(path)
-                
+                data = reader(path)
                 if data is None:
                   if allow_missing_files:
                       logger.debug("... Processing original data file '{0}' was not successful".format(path))
diff --git a/bob/learn/tensorflow/script/eval_generic.py b/bob/learn/tensorflow/script/eval_generic.py
new file mode 100644
index 0000000000000000000000000000000000000000..f29f756707c3c643711fbb6de9062dd3adb60aba
--- /dev/null
+++ b/bob/learn/tensorflow/script/eval_generic.py
@@ -0,0 +1,118 @@
+#!/usr/bin/env python
+
+"""Evaluates networks trained with tf.train.MonitoredTrainingSession
+
+Usage:
+  %(prog)s [options] <config_files>...
+  %(prog)s --help
+  %(prog)s --version
+
+Arguments:
+  <config_files>  The configuration files. The configuration files are loaded
+                  in order and, together, they need to provide several
+                  objects. See below for explanation.
+
+Options:
+  -h --help  show this help message and exit
+  --version  show version and exit
+
+Together, the configuration files should provide the following objects:
+
+  ## Required objects:
+
+  model_dir
+  model_fn
+  eval_input_fn
+
+  ## Optional objects:
+
+  eval_interval_secs
+  run_once
+  run_config
+  model_params
+  steps
+  hooks
+  eval_name
+
+For an example configuration, please see:
+bob.learn.tensorflow/bob/learn/tensorflow/examples/mnist/mnist_config.py
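+
+A minimal configuration, sketched here by re-using the objects defined in the
+MNIST example, could look like:
+
+  from bob.learn.tensorflow.examples.mnist.mnist_config import (
+      model_dir, model_fn, eval_input_fn, run_config)
+  run_once = True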
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+# import pkg_resources so that bob imports work properly:
+import pkg_resources
+import os
+import time
+import six
+import tensorflow as tf
+from bob.bio.base.utils import read_config_file
+from ..utils.eval import get_global_step
+
+
+def main(argv=None):
+    from docopt import docopt
+    import sys
+    docs = __doc__ % {'prog': os.path.basename(sys.argv[0])}
+    version = pkg_resources.require('bob.learn.tensorflow')[0].version
+    args = docopt(docs, argv=argv, version=version)
+    config_files = args['<config_files>']
+    config = read_config_file(config_files)
+
+    model_dir = config.model_dir
+    model_fn = config.model_fn
+    eval_input_fn = config.eval_input_fn
+
+    eval_interval_secs = getattr(config, 'eval_interval_secs', 300)
+    run_once = getattr(config, 'run_once', False)
+    run_config = getattr(config, 'run_config', None)
+    model_params = getattr(config, 'model_params', None)
+    steps = getattr(config, 'steps', None)
+    hooks = getattr(config, 'hooks', None)
+    name = getattr(config, 'eval_name', None)
+
+    # Instantiate Estimator
+    nn = tf.estimator.Estimator(model_fn=model_fn, model_dir=model_dir,
+                                params=model_params, config=run_config)
+    if name:
+        real_name = name + '_eval'
+    else:
+        real_name = 'eval'
+    evaluated_file = os.path.join(nn.model_dir, real_name, 'evaluated')
+    while True:
+        evaluated_steps = []
+        if os.path.exists(evaluated_file):
+            with open(evaluated_file) as f:
+                evaluated_steps = f.read().split()
+
+        ckpt = tf.train.get_checkpoint_state(nn.model_dir)
+        if (not ckpt) or (not ckpt.model_checkpoint_path):
+            time.sleep(eval_interval_secs)
+            continue
+
+        for checkpoint_path in ckpt.all_model_checkpoint_paths:
+            global_step = str(get_global_step(checkpoint_path))
+            if global_step in evaluated_steps:
+                continue
+
+            # Evaluate
+            evaluations = nn.evaluate(
+                input_fn=eval_input_fn,
+                steps=steps,
+                hooks=hooks,
+                checkpoint_path=checkpoint_path,
+                name=name,
+            )
+
+            print(', '.join('%s = %s' % (k, v)
+                            for k, v in sorted(six.iteritems(evaluations))))
+            sys.stdout.flush()
+            with open(evaluated_file, 'a') as f:
+                f.write('{}\n'.format(evaluations['global_step']))
+        if run_once:
+            break
+        time.sleep(eval_interval_secs)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/bob/learn/tensorflow/script/train_generic.py b/bob/learn/tensorflow/script/train_generic.py
new file mode 100644
index 0000000000000000000000000000000000000000..11f7d18a421b8c5ef48196ca254116722f8c5138
--- /dev/null
+++ b/bob/learn/tensorflow/script/train_generic.py
@@ -0,0 +1,83 @@
+#!/usr/bin/env python
+
+"""Trains networks using tf.train.MonitoredTrainingSession
+
+Usage:
+  %(prog)s [options] <config_files>...
+  %(prog)s --help
+  %(prog)s --version
+
+Arguments:
+  <config_files>  The configuration files. The configuration files are loaded
+                  in order and, together, they need to provide several
+                  objects. See below for explanation.
+
+Options:
+  -h --help  show this help message and exit
+  --version  show version and exit
+
+Together, the configuration files should provide the following objects:
+
+  ## Required objects:
+
+  model_fn
+  train_input_fn
+
+  ## Optional objects:
+
+  model_dir
+  run_config
+  model_params
+  hooks
+  steps
+  max_steps
+
+For an example configuration, please see:
+bob.learn.tensorflow/bob/learn/tensorflow/examples/mnist/mnist_config.py
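+
+A minimal configuration, sketched here by re-using the objects defined in the
+MNIST example, could look like:
+
+  from bob.learn.tensorflow.examples.mnist.mnist_config import (
+      model_dir, model_fn, train_input_fn, run_config)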
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+# import pkg_resources so that bob imports work properly:
+import pkg_resources
+import tensorflow as tf
+from bob.bio.base.utils import read_config_file
+
+
+def main(argv=None):
+    from docopt import docopt
+    import os
+    import sys
+    docs = __doc__ % {'prog': os.path.basename(sys.argv[0])}
+    version = pkg_resources.require('bob.learn.tensorflow')[0].version
+    args = docopt(docs, argv=argv, version=version)
+    config_files = args['<config_files>']
+    config = read_config_file(config_files)
+
+    model_fn = config.model_fn
+    train_input_fn = config.train_input_fn
+
+    model_dir = getattr(config, 'model_dir', None)
+    run_config = getattr(config, 'run_config', None)
+    model_params = getattr(config, 'model_params', None)
+    hooks = getattr(config, 'hooks', None)
+    steps = getattr(config, 'steps', None)
+    max_steps = getattr(config, 'max_steps', None)
+
+    if run_config is None:
+        # by default create reproducible nets:
+        from bob.learn.tensorflow.utils.reproducible import session_conf
+        run_config = tf.estimator.RunConfig()
+        run_config = run_config.replace(session_config=session_conf)
+
+    # Instantiate Estimator
+    nn = tf.estimator.Estimator(model_fn=model_fn, model_dir=model_dir,
+                                params=model_params, config=run_config)
+
+    # Train
+    nn.train(input_fn=train_input_fn, hooks=hooks, steps=steps,
+             max_steps=max_steps)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/bob/learn/tensorflow/test/test_db_to_tfrecords.py b/bob/learn/tensorflow/test/test_db_to_tfrecords.py
index e241159638e956bb8e1e8f4cbed32b6e78b2d032..64e9804c1366679a922d821e0a4a993106ab5579 100755
--- a/bob/learn/tensorflow/test/test_db_to_tfrecords.py
+++ b/bob/learn/tensorflow/test/test_db_to_tfrecords.py
@@ -19,7 +19,7 @@ def test_verify_and_tfrecords():
   with open(dummy_config) as f, open(config_path, 'w') as f2:
     f2.write(f.read().replace('TEST_DIR', test_dir))
 
-  parameters = [os.path.join(config_path)]
+  parameters = [config_path]
   try:
     verify(parameters)
     tfrecords(parameters)
diff --git a/bob/learn/tensorflow/test/test_estimator_scripts.py b/bob/learn/tensorflow/test/test_estimator_scripts.py
new file mode 100644
index 0000000000000000000000000000000000000000..38a8d89031f2da2eef92f3cbc73c835c52f2c083
--- /dev/null
+++ b/bob/learn/tensorflow/test/test_estimator_scripts.py
@@ -0,0 +1,134 @@
+from __future__ import print_function
+import os
+from tempfile import mkdtemp
+import shutil
+import logging
+logging.getLogger("tensorflow").setLevel(logging.WARNING)
+from bob.io.base.test_utils import datafile
+
+from bob.learn.tensorflow.script.db_to_tfrecords import main as tfrecords
+from bob.bio.base.script.verify import main as verify
+from bob.learn.tensorflow.script.train_generic import main as train_generic
+from bob.learn.tensorflow.script.eval_generic import main as eval_generic
+
+dummy_tfrecord_config = datafile('dummy_verify_config.py', __name__)
+CONFIG = '''
+import tensorflow as tf
+from bob.learn.tensorflow.utils.tfrecords import shuffle_data_and_labels, \
+    batch_data_and_labels
+
+model_dir = "%(model_dir)s"
+tfrecord_filenames = ['%(tfrecord_filenames)s']
+data_shape = (1, 112, 92)  # size of atnt images
+data_type = tf.uint8
+batch_size = 2
+epochs = 1
+learning_rate = 0.00001
+run_once = True
+
+
+def train_input_fn():
+    return shuffle_data_and_labels(tfrecord_filenames, data_shape, data_type,
+                                   batch_size, epochs=epochs)
+
+def eval_input_fn():
+    return batch_data_and_labels(tfrecord_filenames, data_shape, data_type,
+                                 batch_size, epochs=epochs)
+
+def architecture(images):
+    images = tf.cast(images, tf.float32)
+    logits = tf.reshape(images, [-1, 92 * 112])
+    logits = tf.layers.dense(inputs=logits, units=20,
+                             activation=tf.nn.relu)
+    return logits
+
+
+def model_fn(features, labels, mode, params, config):
+    logits = architecture(features)
+
+    predictions = {
+        # Generate predictions (for PREDICT and EVAL mode)
+        "classes": tf.argmax(input=logits, axis=1),
+        # Add `softmax_tensor` to the graph. It is used for PREDICT and by the
+        # `logging_hook`.
+        "probabilities": tf.nn.softmax(logits, name="softmax_tensor")
+    }
+    if mode == tf.estimator.ModeKeys.PREDICT:
+        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
+
+    # Calculate Loss (for both TRAIN and EVAL modes)
+    predictor = tf.nn.sparse_softmax_cross_entropy_with_logits(
+        logits=logits, labels=labels)
+    loss = tf.reduce_mean(predictor)
+
+    # Configure the Training Op (for TRAIN mode)
+    if mode == tf.estimator.ModeKeys.TRAIN:
+        global_step = tf.contrib.framework.get_or_create_global_step()
+        optimizer = tf.train.GradientDescentOptimizer(learning_rate)
+        train_op = optimizer.minimize(loss, global_step=global_step)
+        return tf.estimator.EstimatorSpec(mode=mode, loss=loss,
+                                          train_op=train_op)
+
+    # Add evaluation metrics (for EVAL mode)
+    eval_metric_ops = {
+        "accuracy": tf.metrics.accuracy(
+            labels=labels, predictions=predictions["classes"])}
+    return tf.estimator.EstimatorSpec(
+        mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)
+'''
+
+
+def _create_tfrecord(test_dir):
+    config_path = os.path.join(test_dir, 'tfrecordconfig.py')
+    with open(dummy_tfrecord_config) as f, open(config_path, 'w') as f2:
+        f2.write(f.read().replace('TEST_DIR', test_dir))
+    verify([config_path])
+    tfrecords([config_path])
+    return os.path.join(test_dir, 'sub_directory', 'dev.tfrecords')
+
+
+def _create_checkpoint(tmpdir, model_dir, dummy_tfrecord):
+    config = CONFIG % {'model_dir': model_dir,
+                       'tfrecord_filenames': dummy_tfrecord}
+    config_path = os.path.join(tmpdir, 'train_config.py')
+    with open(config_path, 'w') as f:
+        f.write(config)
+    train_generic([config_path])
+
+
+def _eval(tmpdir, model_dir, dummy_tfrecord):
+    config = CONFIG % {'model_dir': model_dir,
+                       'tfrecord_filenames': dummy_tfrecord}
+    config_path = os.path.join(tmpdir, 'eval_config.py')
+    with open(config_path, 'w') as f:
+        f.write(config)
+    eval_generic([config_path])
+
+
+def test_eval_once():
+    tmpdir = mkdtemp(prefix='bob_')
+    try:
+        model_dir = os.path.join(tmpdir, 'model_dir')
+        eval_dir = os.path.join(model_dir, 'eval')
+
+        print('\nCreating a dummy tfrecord')
+        dummy_tfrecord = _create_tfrecord(tmpdir)
+
+        print('Training a dummy network')
+        _create_checkpoint(tmpdir, model_dir, dummy_tfrecord)
+
+        print('Evaluating a dummy network')
+        _eval(tmpdir, model_dir, dummy_tfrecord)
+
+        evaluated_path = os.path.join(eval_dir, 'evaluated')
+        assert os.path.exists(evaluated_path), evaluated_path
+        with open(evaluated_path) as f:
+            doc = f.read()
+
+        assert '1' in doc, doc
+        assert '100' in doc, doc
+    finally:
+        try:
+            shutil.rmtree(tmpdir)
+        except Exception:
+            pass
diff --git a/bob/learn/tensorflow/utils/__init__.py b/bob/learn/tensorflow/utils/__init__.py
index e73a1da73733affd340604379e89fbecf765fac0..3fe013e8ccb40c8512359a1774c63f0513e18075 100755
--- a/bob/learn/tensorflow/utils/__init__.py
+++ b/bob/learn/tensorflow/utils/__init__.py
@@ -1,3 +1,6 @@
 from .util import *
 from .singleton import Singleton
-from .session import Session
\ No newline at end of file
+from .session import Session
+from . import hooks
+from . import eval
+from . import tfrecords
diff --git a/bob/learn/tensorflow/utils/eval.py b/bob/learn/tensorflow/utils/eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf836f6e0d1f6742b10fee3d5fa9390d255f1dab
--- /dev/null
+++ b/bob/learn/tensorflow/utils/eval.py
@@ -0,0 +1,23 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import tensorflow as tf
+
+
+def get_global_step(path):
+    """Returns the global number associated with the model checkpoint path. The
+    checkpoint must have been saved with the
+    :any:`tf.train.MonitoredTrainingSession`.
+
+    Parameters
+    ----------
+    path : str
+        The path to model checkpoint, usually ckpt.model_checkpoint_path
+
+    Returns
+    -------
+    global_step : int
+        The global step number.
+    """
+    checkpoint_reader = tf.train.NewCheckpointReader(path)
+    return checkpoint_reader.get_tensor(tf.GraphKeys.GLOBAL_STEP)
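+
+
+# Example (a sketch): read the step number of the latest checkpoint in a model
+# directory (the path below is only a placeholder):
+#
+#   ckpt = tf.train.get_checkpoint_state('/tmp/mnist_model')
+#   step = get_global_step(ckpt.model_checkpoint_path)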
diff --git a/bob/learn/tensorflow/utils/hooks.py b/bob/learn/tensorflow/utils/hooks.py
new file mode 100644
index 0000000000000000000000000000000000000000..1875e519a77a11a54f45fe6ab5726322173b17bf
--- /dev/null
+++ b/bob/learn/tensorflow/utils/hooks.py
@@ -0,0 +1,35 @@
+import tensorflow as tf
+import time
+from datetime import datetime
+
+
+class LoggerHook(tf.train.SessionRunHook):
+    """Logs loss and runtime."""
+
+    def __init__(self, loss, batch_size, log_frequency):
+        self.loss = loss
+        self.batch_size = batch_size
+        self.log_frequency = log_frequency
+
+    def begin(self):
+        self._step = -1
+        self._start_time = time.time()
+
+    def before_run(self, run_context):
+        self._step += 1
+        return tf.train.SessionRunArgs(self.loss)  # Asks for loss value.
+
+    def after_run(self, run_context, run_values):
+        if self._step % self.log_frequency == 0:
+            current_time = time.time()
+            duration = current_time - self._start_time
+            self._start_time = current_time
+
+            loss_value = run_values.results
+            examples_per_sec = self.log_frequency * self.batch_size / duration
+            sec_per_batch = float(duration / self.log_frequency)
+
+            format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
+                          'sec/batch)')
+            print(format_str % (datetime.now(), self._step, loss_value,
+                                examples_per_sec, sec_per_batch))
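+
+
+# Example (a sketch): log the loss value every 100 steps while training with a
+# monitored session; ``loss`` and ``train_op`` are assumed to be defined:
+#
+#   hooks = [LoggerHook(loss, batch_size=32, log_frequency=100)]
+#   with tf.train.MonitoredTrainingSession(hooks=hooks) as sess:
+#       while not sess.should_stop():
+#           sess.run(train_op)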
diff --git a/bob/learn/tensorflow/utils/reproducible.py b/bob/learn/tensorflow/utils/reproducible.py
new file mode 100644
index 0000000000000000000000000000000000000000..34cb4678258c75d40c889580bb30eff42c8f5242
--- /dev/null
+++ b/bob/learn/tensorflow/utils/reproducible.py
@@ -0,0 +1,37 @@
+import os
+import numpy as np
+import tensorflow as tf
+import random as rn
+# from tensorflow.contrib import keras
+
+# reproducible networks
+# The below is necessary in Python 3.2.3 onwards to
+# have reproducible behavior for certain hash-based operations.
+# See these references for further details:
+# https://docs.python.org/3.4/using/cmdline.html#envvar-PYTHONHASHSEED
+# https://github.com/fchollet/keras/issues/2280#issuecomment-306959926
+os.environ['PYTHONHASHSEED'] = '0'
+
+# The below is necessary for starting Numpy generated random numbers
+# in a well-defined initial state.
+np.random.seed(42)
+
+# The below is necessary for starting core Python generated random numbers
+# in a well-defined state.
+rn.seed(12345)
+
+# Force TensorFlow to use single thread.
+# Multiple threads are a potential source of
+# non-reproducible results.
+# For further details, see:
+# https://stackoverflow.com/questions/42022950/which-seeds-have-to-be-set-where-to-realize-100-reproducibility-of-training-res
+session_conf = tf.ConfigProto(intra_op_parallelism_threads=1,
+                              inter_op_parallelism_threads=1)
+
+# The below tf.set_random_seed() will make random number generation
+# in the TensorFlow backend have a well-defined initial state.
+# For further details, see:
+# https://www.tensorflow.org/api_docs/python/tf/set_random_seed
+tf.set_random_seed(1234)
+# sess = tf.Session(graph=tf.get_default_graph(), config=session_conf)
+# keras.backend.set_session(sess)
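+
+# Example (a sketch): the generic scripts in this package use ``session_conf``
+# to build a reproducible estimator configuration:
+#
+#   run_config = tf.estimator.RunConfig()
+#   run_config = run_config.replace(session_config=session_conf)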
diff --git a/bob/learn/tensorflow/utils/sequences.py b/bob/learn/tensorflow/utils/sequences.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa16cb82c81a515af4a70796033df749a3cb9a9b
--- /dev/null
+++ b/bob/learn/tensorflow/utils/sequences.py
@@ -0,0 +1,143 @@
+from __future__ import division
+import numpy
+from keras.utils import Sequence
+# documentation imports
+from bob.dap.base.database import PadDatabase, PadFile
+from bob.bio.base.preprocessor import Preprocessor
+
+
+class PadSequence(Sequence):
+    """A data shuffler for bob.dap.base database interfaces.
+
+    Attributes
+    ----------
+    batch_size : int
+        The number of samples to return in every batch.
+    files : list of :any:`PadFile`
+        List of file objects for a particular group and protocol.
+    labels : list of bool
+        List of labels for the files. ``True`` if bona-fide, ``False`` if
+        attack.
+    preprocessor : :any:`Preprocessor`
+        The preprocessor to be used to load and process the data.
+    """
+
+    def __init__(self, files, labels, batch_size, preprocessor,
+                 original_directory, original_extension):
+        super(PadSequence, self).__init__()
+        self.files = files
+        self.labels = labels
+        self.batch_size = int(batch_size)
+        self.preprocessor = preprocessor
+        self.original_directory = original_directory
+        self.original_extension = original_extension
+
+    def __len__(self):
+        """Number of batch in the Sequence.
+
+        Returns
+        -------
+        int
+            The number of batches in the Sequence.
+        """
+        return int(numpy.ceil(len(self.files) / self.batch_size))
+
+    def __getitem__(self, idx):
+        files = self.files[idx * self.batch_size:(idx + 1) * self.batch_size]
+        labels = self.labels[idx * self.batch_size:(idx + 1) * self.batch_size]
+        return self.load_batch(files, labels)
+
+    def load_batch(self, files, labels):
+        """Loads a batch of files and processes them.
+
+        Parameters
+        ----------
+        files : list of :any:`PadFile`
+            List of files to load.
+        labels : list of bool
+            List of labels corresponding to the files.
+
+        Returns
+        -------
+        tuple of :any:`numpy.array`
+            A tuple of (x, y): the data and their targets.
+        """
+        data, targets = [], []
+        for file_object, target in zip(files, labels):
+            loaded_data = self.preprocessor.read_original_data(
+                file_object,
+                self.original_directory,
+                self.original_extension)
+            preprocessed_data = self.preprocessor(loaded_data)
+            data.append(preprocessed_data)
+            targets.append(target)
+        return numpy.array(data), numpy.array(targets)
+
+    def on_epoch_end(self):
+        pass
+
+
+def shuffle_data(files, labels):
+    indexes = numpy.arange(len(files))
+    numpy.random.shuffle(indexes)
+    return [files[i] for i in indexes], [labels[i] for i in indexes]
+
+
+def get_pad_files_labels(database, groups):
+    """Returns the pad files and their labels.
+
+    Parameters
+    ----------
+    database : :any:`PadDatabase`
+        The database to be used. The database should have a proper
+        ``database.protocol`` attribute.
+    groups : str
+        The group to be used to return the data. One of ('world', 'dev',
+        'eval'). 'world' means training data and 'dev' means validation data.
+
+    Returns
+    -------
+    tuple
+        A tuple of (files, labels) for that particular group and protocol.
+    """
+    files = database.samples(
+        groups=groups, protocol=database.protocol)
+    labels = ((f.attack_type is None) for f in files)
+    labels = numpy.fromiter(labels, bool, len(files))
+    return files, labels
+
+
+def get_pad_sequences(database, preprocessor, batch_size,
+                      groups=('world', 'dev', 'eval'), shuffle=False,
+                      limit=None):
+    """Returns a list of :any:`Sequence` objects for the database.
+
+    Parameters
+    ----------
+    database : :any:`PadDatabase`
+        The database to be used. The database should have a proper
+        ``database.protocol`` attribute.
+    preprocessor : :any:`Preprocessor`
+        The preprocessor to be used to load and process the data.
+    batch_size : int
+        The number of samples to return in every batch.
+    groups : tuple of str
+        The groups to be used to return the data, from ('world', 'dev',
+        'eval'). 'world' means training data and 'dev' means validation data.
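+    shuffle : bool
+        If True, shuffle the files (and labels) before creating the sequences.
+    limit : int or None
+        If given, use only the first ``limit`` files of each group.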
+
+    Returns
+    -------
+    list of :any:`Sequence`
+        The requested sequences to be used.
+    """
+    seqs = []
+    for grp in groups:
+        files, labels = get_pad_files_labels(database, grp)
+        if shuffle:
+            files, labels = shuffle_data(files, labels)
+        if limit is not None:
+            files, labels = files[:limit], labels[:limit]
+        seqs.append(PadSequence(files, labels, batch_size, preprocessor,
+                                database.original_directory,
+                                database.original_extension))
+    return seqs
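+
+
+# Example (a sketch, assuming ``database`` and ``preprocessor`` are already
+# configured bob PAD objects):
+#
+#   train_seq, dev_seq, eval_seq = get_pad_sequences(
+#       database, preprocessor, batch_size=32, shuffle=True)
+#   # each sequence can then be fed to e.g. keras' ``fit_generator``.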
diff --git a/bob/learn/tensorflow/utils/tfrecords.py b/bob/learn/tensorflow/utils/tfrecords.py
new file mode 100644
index 0000000000000000000000000000000000000000..48da0740577c2a64e4e6f59b26dac959e0a0678f
--- /dev/null
+++ b/bob/learn/tensorflow/utils/tfrecords.py
@@ -0,0 +1,60 @@
+from functools import partial
+import tensorflow as tf
+
+
+DEFAULT_FEATURE = {'train/data': tf.FixedLenFeature([], tf.string),
+                   'train/label': tf.FixedLenFeature([], tf.int64)}
+
+
+def example_parser(serialized_example, feature, data_shape, data_type):
+    """Parses a single tf.Example into image and label tensors."""
+    # Decode the record read by the reader
+    features = tf.parse_single_example(serialized_example, features=feature)
+    # Convert the image data from string back to the numbers
+    image = tf.decode_raw(features['train/data'], data_type)
+    # Cast label data into int64
+    label = tf.cast(features['train/label'], tf.int64)
+    # Reshape image data into the original shape
+    image = tf.reshape(image, data_shape)
+    return image, label
+
+
+def read_and_decode(filename_queue, data_shape, data_type=tf.float32,
+                    feature=None):
+    if feature is None:
+        feature = DEFAULT_FEATURE
+    # Define a reader and read the next record
+    reader = tf.TFRecordReader()
+    _, serialized_example = reader.read(filename_queue)
+    return example_parser(serialized_example, feature, data_shape, data_type)
+
+
+def create_dataset_from_records(tfrecord_filenames, data_shape, data_type,
+                                feature=None):
+    if feature is None:
+        feature = DEFAULT_FEATURE
+    dataset = tf.contrib.data.TFRecordDataset(tfrecord_filenames)
+    parser = partial(example_parser, feature=feature, data_shape=data_shape,
+                     data_type=data_type)
+    dataset = dataset.map(parser)
+    return dataset
+
+
+def shuffle_data_and_labels(tfrecord_filenames, data_shape, data_type,
+                            batch_size, epochs=None, buffer_size=10**3):
+    dataset = create_dataset_from_records(tfrecord_filenames, data_shape,
+                                          data_type)
+    dataset = dataset.shuffle(buffer_size).batch(batch_size).repeat(epochs)
+
+    datas, labels = dataset.make_one_shot_iterator().get_next()
+    return datas, labels
+
+
+def batch_data_and_labels(tfrecord_filenames, data_shape, data_type,
+                          batch_size, epochs=1):
+    dataset = create_dataset_from_records(tfrecord_filenames, data_shape,
+                                          data_type)
+    dataset = dataset.batch(batch_size).repeat(epochs)
+
+    datas, labels = dataset.make_one_shot_iterator().get_next()
+    return datas, labels
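+
+
+# Example (a sketch, mirroring the configuration used in the tests): build the
+# training input function of an estimator from a list of tfrecord files; the
+# file name and data shape below are placeholders:
+#
+#   def train_input_fn():
+#       return shuffle_data_and_labels(['/tmp/train.tfrecords'], (1, 112, 92),
+#                                      tf.uint8, batch_size=32, epochs=1)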
diff --git a/setup.py b/setup.py
index 493222372fc3ca1dcd2f34db879ebd3277e7bd4a..6a0b4dbdcf9cca07279ef5a0157348188e85d9cf 100755
--- a/setup.py
+++ b/setup.py
@@ -51,7 +51,9 @@ setup(
             'train.py = bob.learn.tensorflow.script.train:main',
             'bob_db_to_tfrecords = bob.learn.tensorflow.script.db_to_tfrecords:main',
             'load_and_debug.py = bob.learn.tensorflow.script.load_and_debug:main',
-            'lfw_db_to_tfrecords.py = bob.learn.tensorflow.script.lfw_db_to_tfrecords:main'
+            'lfw_db_to_tfrecords.py = bob.learn.tensorflow.script.lfw_db_to_tfrecords:main',
+            'bob_tf_train_generic = bob.learn.tensorflow.script.train_generic:main',
+            'bob_tf_eval_generic = bob.learn.tensorflow.script.eval_generic:main',
         ],
 
     },