diff --git a/bob/learn/tensorflow/examples/mnist/mnist_config.py b/bob/learn/tensorflow/examples/mnist/mnist_config.py
index 6227fcd462822776320db3c4e24a646f3ef2c721..9991e218e4bf454d28fadc2e1c3d6eaa52c89067 100644
--- a/bob/learn/tensorflow/examples/mnist/mnist_config.py
+++ b/bob/learn/tensorflow/examples/mnist/mnist_config.py
@@ -17,16 +17,15 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from bob.learn.tensorflow.utils.reproducible import session_conf
+# create reproducible nets:
+from bob.learn.tensorflow.utils.reproducible import run_config
 import tensorflow as tf
+from bob.db.mnist import Database
 
 model_dir = '/tmp/mnist_model'
 train_tfrecords = ['/tmp/mnist_data/train.tfrecords']
 eval_tfrecords = ['/tmp/mnist_data/test.tfrecords']
 
-# by default create reproducible nets:
-run_config = tf.estimator.RunConfig()
-run_config = run_config.replace(session_config=session_conf)
 run_config = run_config.replace(keep_checkpoint_max=10**3)
 run_config = run_config.replace(save_checkpoints_secs=60)
 
@@ -39,22 +38,27 @@ def input_fn(mode, batch_size=1):
         features = tf.parse_single_example(
             serialized_example,
             features={
-                'image_raw': tf.FixedLenFeature([], tf.string),
+                'data': tf.FixedLenFeature([], tf.string),
                 'label': tf.FixedLenFeature([], tf.int64),
+                'key': tf.FixedLenFeature([], tf.string),
             })
-        image = tf.decode_raw(features['image_raw'], tf.uint8)
+        image = tf.decode_raw(features['data'], tf.uint8)
         image.set_shape([28 * 28])
 
         # Normalize the values of the image from the range
         # [0, 255] to [-0.5, 0.5]
         image = tf.cast(image, tf.float32) / 255 - 0.5
         label = tf.cast(features['label'], tf.int32)
-        return image, tf.one_hot(label, 10)
+
+        key = tf.cast(features['key'], tf.string)
+        return image, tf.one_hot(label, 10), key
 
     if mode == tf.estimator.ModeKeys.TRAIN:
         tfrecords_files = train_tfrecords
+    elif mode == tf.estimator.ModeKeys.EVAL:
+        tfrecords_files = eval_tfrecords
     else:
-        assert mode == tf.estimator.ModeKeys.EVAL, 'invalid mode'
+        assert mode == tf.estimator.ModeKeys.PREDICT, 'invalid mode'
         tfrecords_files = eval_tfrecords
 
     for tfrecords_file in tfrecords_files:
@@ -73,9 +77,9 @@ def input_fn(mode, batch_size=1):
     dataset = dataset.map(
         example_parser, num_threads=1, output_buffer_size=batch_size)
     dataset = dataset.batch(batch_size)
-    images, labels = dataset.make_one_shot_iterator().get_next()
+    images, labels, keys = dataset.make_one_shot_iterator().get_next()
 
-    return images, labels
+    return {'images': images, 'keys': keys}, labels
 
 
 def train_input_fn():
@@ -86,6 +90,10 @@ def eval_input_fn():
     return input_fn(tf.estimator.ModeKeys.EVAL)
 
 
+def predict_input_fn():
+    return input_fn(tf.estimator.ModeKeys.PREDICT)
+
+
 def mnist_model(inputs, mode):
     """Takes the MNIST inputs and mode and outputs a tensor of logits."""
     # Input Layer
@@ -164,13 +172,16 @@ def mnist_model(inputs, mode):
     return logits
 
 
-def model_fn(features, labels, mode):
+def model_fn(features, labels=None, mode=tf.estimator.ModeKeys.TRAIN):
     """Model function for MNIST."""
+    keys = features['keys']
+    features = features['images']
     logits = mnist_model(features, mode)
 
     predictions = {
         'classes': tf.argmax(input=logits, axis=1),
-        'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
+        'probabilities': tf.nn.softmax(logits, name='softmax_tensor'),
+        'keys': keys,
     }
 
     if mode == tf.estimator.ModeKeys.PREDICT:
@@ -202,3 +213,22 @@ def model_fn(features, labels, mode):
         loss=loss,
         train_op=train_op,
         eval_metric_ops=metrics)
+
+
+estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir=model_dir,
+                                   params=None, config=run_config)
+
+
+output = train_tfrecords[0]
+db = Database()
+data, labels = db.data(groups='train')
+
+# output = eval_tfrecords[0]
+# db = Database()
+# data, labels = db.data(groups='test')
+
+samples = zip(data, labels, (str(i) for i in range(len(data))))
+
+
+def reader(sample):
+    return sample
diff --git a/bob/learn/tensorflow/examples/mnist/tfrecords.py b/bob/learn/tensorflow/examples/mnist/tfrecords.py
new file mode 100644
index 0000000000000000000000000000000000000000..657d5c97f992b05b35d54b70877c4e5b2594ca29
--- /dev/null
+++ b/bob/learn/tensorflow/examples/mnist/tfrecords.py
@@ -0,0 +1,61 @@
+# Required objects:
+
+# you need a database object that inherits from
+# bob.bio.base.database.BioDatabase (PAD dbs work too)
+database = Database()
+
+# the directory pointing to where the processed data is:
+data_dir = '/idiap/temp/user/database_name/sub_directory/preprocessed'
+
+# the directory to save the tfrecords in:
+output_dir = '/idiap/temp/user/database_name/sub_directory'
+
+
+# A function that converts a BioFile or a PadFile to a label:
+# Example for PAD
+def file_to_label(f):
+    return f.attack_type is None
+
+
+# Example for Bio (You may want to run this script for groups=['world'] only
+# in biometric recognition experiments.)
+CLIENT_IDS = (str(f.client_id) for f in db.all_files(groups=groups))
+CLIENT_IDS = list(set(CLIENT_IDS))
+CLIENT_IDS = dict(zip(CLIENT_IDS, range(len(CLIENT_IDS))))
+
+
+def file_to_label(f):
+    return CLIENT_IDS[str(f.client_id)]
+
+# Optional objects:
+
+
+# The groups that you want to create tfrecords for. It should be a list of
+# 'world' ('train' in bob.pad.base), 'dev', and 'eval' values. [default:
+# 'world']
+groups = ['world']
+
+# you need a reader function that reads the preprocessed files. [default:
+# bob.bio.base.utils.load]
+reader = Preprocessor().read_data
+reader = Extractor().read_feature
+# or
+from bob.bio.base.utils import load as reader
+# or a reader that casts images to uint8:
+
+
+def reader(path):
+    data = bob.bio.base.utils.load(path)
+    return data.astype("uint8")
+
+
+# extension of the preprocessed files. [default: '.hdf5']
+data_extension = '.hdf5'
+
+# Shuffle the files before writing them into a tfrecords. [default: False]
+shuffle = True
+
+# Whether the each file contains one sample or more. [default: True] If
+# this is False, the loaded samples from a file are iterated over and each
+# of them is saved as an independent feature.
+one_file_one_sample = True
diff --git a/bob/learn/tensorflow/script/db_to_tfrecords.py b/bob/learn/tensorflow/script/db_to_tfrecords.py
index aa5169935749e1022c84cca1a2e3427548a6689d..dc1cd120b6db31c0857415d44de7964bf722d53d 100644
--- a/bob/learn/tensorflow/script/db_to_tfrecords.py
+++ b/bob/learn/tensorflow/script/db_to_tfrecords.py
@@ -3,80 +3,86 @@
 """Converts Bio and PAD datasets to TFRecords file formats.
 
 Usage:
-  %(prog)s <config_files>... [--allow-missing-files]
-  %(prog)s --help
-  %(prog)s --version
+    %(prog)s [-v...] [options] <config_files>...
+    %(prog)s --help
+    %(prog)s --version
 
 Arguments:
-  <config_files>  The configuration files. The configuration files are loaded
-                  in order and they need to have several objects inside
-                  totally. See below for explanation.
+    <config_files>              The configuration files. The configuration
+                                files are loaded in order and they need to have
+                                several objects inside totally. See below for
+                                explanation.
 
 Options:
-  -h --help  show this help message and exit
-  --version  show version and exit
+    -h --help                   Show this help message and exit
+    --version                   Show version and exit
+    -o PATH, --output PATH      Name of the output file.
+    --shuffle                   If provided, it will shuffle the samples.
+    --allow-failures            If provided, the samples which fail to load are
+                                ignored.
+    --multiple-samples          If provided, it means that the data provided by
+                                reader contains multiple samples with same
+                                label and path.
+    -v, --verbose               Increases the output verbosity level
 
 The best way to use this script is to send it to the io-big queue if you are at
 Idiap:
 
-  $ jman submit -i -q q1d -- bin/python %(prog)s <config_files>...
+    $ jman submit -i -q q1d -- %(prog)s <config_files>...
 
-The configuration files should have the following objects totally:
+The configuration files should have the following objects totally::
 
-  ## Required objects:
+    # Required objects:
+    samples : a list of all samples that you want to write in the tfrecords
+              file. Whatever is inside this list is passed to the reader.
+    reader  : a function with the signature of
+              `data, label, key = reader(sample)` which takes a sample and
+              returns the loaded data, the label of the data, and a key which
+              is unique for every sample.
 
-  # you need a database object that inherits from
-  # bob.bio.base.database.BioDatabase (PAD dbs work too)
-  database = Database()
+You can also provide the command line options in the configuration file too.
+It is needed to replace "-" with "_" when they are in the configuration file.
+The ones provided by command line overwrite the values of the config file.
 
-  # the directory pointing to where the processed data is:
-  data_dir = '/idiap/temp/user/database_name/sub_directory/preprocessed'
+An example for mnist would be::
 
-  # the directory to save the tfrecords in:
-  output_dir = '/idiap/temp/user/database_name/sub_directory'
+    from bob.db.mnist import Database
+    db = Database()
+    data, labels = db.data(groups='train')
 
-  # A function that converts a BioFile or a PadFile to a label:
-  # Example for PAD
-  def file_to_label(f):
-      return f.attack_type is None
+    samples = zip(data, labels, (str(i) for i in range(len(data))))
 
-  # Example for Bio (You may want to run this script for groups=['world'] only
-  # in biometric recognition experiments.)
-  CLIENT_IDS = (str(f.client_id) for f in db.all_files(groups=groups))
-  CLIENT_IDS = list(set(CLIENT_IDS))
-  CLIENT_IDS = dict(zip(CLIENT_IDS, range(len(CLIENT_IDS))))
+    def reader(sample):
+        return sample
 
-  def file_to_label(f):
-      return CLIENT_IDS[str(f.client_id)]
+    allow_failures = True
+    output = '/tmp/mnist_train.tfrecords'
+    shuffle = True
 
-  ## Optional objects:
+An example for bob.bio.base would be::
 
-  # The groups that you want to create tfrecords for. It should be a list of
-  # 'world' ('train' in bob.pad.base), 'dev', and 'eval' values. [default:
-  # 'world']
-  groups = ['world']
+    from bob.bio.base.test.dummy.database import database
+    from bob.bio.base.test.dummy.preprocessor import preprocessor
 
-  # you need a reader function that reads the preprocessed files. [default:
-  # bob.bio.base.utils.load]
-  reader = Preprocessor().read_data
-  reader = Extractor().read_feature
-  # or
-  from bob.bio.base.utils import load as reader
-  # or a reader that casts images to uint8:
-  def reader(path):
-    data = bob.bio.base.utils.load(path)
-    return data.astype("uint8")
+    groups = 'dev'
 
-  # extension of the preprocessed files. [default: '.hdf5']
-  data_extension = '.hdf5'
+    samples = database.all_files(groups=groups)
 
-  # Shuffle the files before writing them into a tfrecords. [default: False]
-  shuffle = True
+    CLIENT_IDS = (str(f.client_id) for f in database.all_files(groups=groups))
+    CLIENT_IDS = list(set(CLIENT_IDS))
+    CLIENT_IDS = dict(zip(CLIENT_IDS, range(len(CLIENT_IDS))))
 
-  # Whether the each file contains one sample or more. [default: True] If
-  # this is False, the loaded samples from a file are iterated over and each
-  # of them is saved as an independent feature.
-  one_file_one_sample = True
+
+    def file_to_label(f):
+        return CLIENT_IDS[str(f.client_id)]
+
+
+    def reader(biofile):
+        data = preprocessor.read_original_data(
+            biofile, database.original_directory, database.original_extension)
+        label = file_to_label(biofile)
+        key = biofile.path
+        return (data, label, key)
 """
 
 from __future__ import absolute_import
@@ -85,10 +91,12 @@ from __future__ import print_function
 import random
 # import pkg_resources so that bob imports work properly:
 import pkg_resources
-
+import six
 import tensorflow as tf
 from bob.io.base import create_directories_safe
-from bob.bio.base.utils import load, read_config_file
+from bob.bio.base.utils import read_config_file
+from bob.learn.tensorflow.utils.commandline import \
+    get_from_config_or_commandline
 from bob.core.log import setup, set_verbosity_level
 logger = setup(__name__)
 
@@ -101,10 +109,11 @@ def int64_feature(value):
     return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
 
 
-def write_a_sample(writer, data, label, feature=None):
+def write_a_sample(writer, data, label, key, feature=None):
     if feature is None:
-        feature = {'train/data': bytes_feature(data.tostring()),
-                   'train/label': int64_feature(label)}
+        feature = {'data': bytes_feature(data.tostring()),
+                   'label': int64_feature(label),
+                   'key': bytes_feature(key)}
 
     example = tf.train.Example(features=tf.train.Features(feature=feature))
     writer.write(example.SerializeToString())
@@ -116,55 +125,62 @@ def main(argv=None):
     import sys
     docs = __doc__ % {'prog': os.path.basename(sys.argv[0])}
     version = pkg_resources.require('bob.learn.tensorflow')[0].version
+    defaults = docopt(docs, argv=[""])
     args = docopt(docs, argv=argv, version=version)
     config_files = args['<config_files>']
     config = read_config_file(config_files)
-    allow_missing_files = args['--allow-missing-files']
+
+    # optional arguments
+    verbosity = get_from_config_or_commandline(
+        config, 'verbose', args, defaults)
+    allow_failures = get_from_config_or_commandline(
+        config, 'allow_failures', args, defaults)
+    multiple_samples = get_from_config_or_commandline(
+        config, 'multiple_samples', args, defaults)
+    shuffle = get_from_config_or_commandline(
+        config, 'shuffle', args, defaults)
 
     # Sets-up logging
-    verbosity = getattr(config, 'verbose', 0)
     set_verbosity_level(logger, verbosity)
 
-    database = config.database
-    data_dir, output_dir = config.data_dir, config.output_dir
-    file_to_label = config.file_to_label
+    # required arguments
+    samples = config.samples
+    reader = config.reader
+    output = get_from_config_or_commandline(
+        config, 'output', args, defaults, False)
 
-    reader = getattr(config, 'reader', load)
-    groups = getattr(config, 'groups', ['world'])
-    data_extension = getattr(config, 'data_extension', '.hdf5')
-    shuffle = getattr(config, 'shuffle', False)
-    one_file_one_sample = getattr(config, 'one_file_one_sample', True)
+    if not output.endswith(".tfrecords"):
+        output += ".tfrecords"
 
-    create_directories_safe(output_dir)
-    if not isinstance(groups, (list, tuple)):
-        groups = [groups]
+    create_directories_safe(os.path.dirname(output))
 
-    for group in groups:
-        output_file = os.path.join(output_dir, '{}.tfrecords'.format(group))
-        files = database.all_files(groups=group)
+    n_samples = len(samples)
+    sample_counter = 0
+    with tf.python_io.TFRecordWriter(output) as writer:
         if shuffle:
-            random.shuffle(files)
-        n_files = len(files)
-        with tf.python_io.TFRecordWriter(output_file) as writer:
-            for i, f in enumerate(files):
-                logger.info('Processing file %d out of %d', i + 1, n_files)
-
-                path = f.make_path(data_dir, data_extension)
-                data = reader(path)                
-                if data is None:
-                  if allow_missing_files:
-                      logger.debug("... Processing original data file '{0}' was not successful".format(path))
-                      continue
-                  else:
-                      raise RuntimeError("Preprocessing of file '{0}' was not successful".format(path))
-                
-                label = file_to_label(f)
-
-                if one_file_one_sample:
-                    write_a_sample(writer, data, label)
+            random.shuffle(samples)
+        for i, sample in enumerate(samples):
+            logger.info('Processing file %d out of %d', i + 1, n_samples)
+
+            data, label, key = reader(sample)
+
+            if data is None:
+                if allow_failures:
+                    logger.debug("... Skipping `{0}`.".format(sample))
+                    continue
                 else:
-                    for sample in data:
-                        write_a_sample(writer, sample, label)
+                    raise RuntimeError(
+                        "Reading failed for `{0}`".format(sample))
+
+            if multiple_samples:
+                for sample in data:
+                    write_a_sample(writer, sample, label, key)
+                    sample_counter += 1
+            else:
+                write_a_sample(writer, data, label, key)
+                sample_counter += 1
+
+    print("Wrote {} samples into the tfrecords file.".format(sample_counter))
 
 
 if __name__ == '__main__':
diff --git a/bob/learn/tensorflow/script/eval_generic.py b/bob/learn/tensorflow/script/eval_generic.py
index f29f756707c3c643711fbb6de9062dd3adb60aba..e8432aa3cb946e29a460dbace75b25908032784e 100644
--- a/bob/learn/tensorflow/script/eval_generic.py
+++ b/bob/learn/tensorflow/script/eval_generic.py
@@ -63,7 +63,7 @@ def main(argv=None):
     model_fn = config.model_fn
     eval_input_fn = config.eval_input_fn
 
-    eval_interval_secs = getattr(config, 'eval_interval_secs', 300)
+    eval_interval_secs = getattr(config, 'eval_interval_secs', 60)
     run_once = getattr(config, 'run_once', False)
     run_config = getattr(config, 'run_config', None)
     model_params = getattr(config, 'model_params', None)
@@ -75,7 +75,7 @@ def main(argv=None):
     nn = tf.estimator.Estimator(model_fn=model_fn, model_dir=model_dir,
                                 params=model_params, config=run_config)
     if name:
-        real_name = name + '_eval'
+        real_name = 'eval_' + name
     else:
         real_name = 'eval'
     evaluated_file = os.path.join(nn.model_dir, real_name, 'evaluated')
@@ -91,7 +91,12 @@ def main(argv=None):
             continue
 
         for checkpoint_path in ckpt.all_model_checkpoint_paths:
-            global_step = str(get_global_step(checkpoint_path))
+            try:
+                global_step = str(get_global_step(checkpoint_path))
+            except Exception:
+                print('Failed to find global_step for checkpoint_path {}, '
+                      'skipping ...'.format(checkpoint_path))
+                continue
             if global_step in evaluated_steps:
                 continue
 
diff --git a/bob/learn/tensorflow/script/predict_generic.py b/bob/learn/tensorflow/script/predict_generic.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e2d874bcc8b0e419e60211ac2d335c56427b2aa
--- /dev/null
+++ b/bob/learn/tensorflow/script/predict_generic.py
@@ -0,0 +1,127 @@
+#!/usr/bin/env python
+
+"""Returns predictions of networks trained with
+tf.train.MonitoredTrainingSession
+
+Usage:
+    %(prog)s [-v...] [-k KEY]... [options] <config_files>...
+    %(prog)s --help
+    %(prog)s --version
+
+Arguments:
+    <config_files>                  The configuration files. The configuration
+                                    files are loaded in order and they need to
+                                    have several objects inside totally. See
+                                    below for explanation.
+
+Options:
+    -h --help                       Show this help message and exit
+    --version                       Show version and exit
+    -o PATH, --output-dir PATH      Name of the output file.
+    -k KEY, --predict-keys KEY      List of `str`, name of the keys to predict.
+                                    It is used if the
+                                    `EstimatorSpec.predictions` is a `dict`. If
+                                    `predict_keys` is used then rest of the
+                                    predictions will be filtered from the
+                                    dictionary. If `None`, returns all.
+    --checkpoint-path=<path>        Path of a specific checkpoint to predict.
+                                    If `None`, the latest checkpoint in
+                                    `model_dir` is used.
+    -v, --verbose                   Increases the output verbosity level
+
+The configuration files should have the following objects totally:
+
+    # Required objects:
+
+    estimator
+    predict_input_fn
+
+    # Optional objects:
+
+    hooks
+
+For an example configuration, please see:
+bob.learn.tensorflow/bob/learn/tensorflow/examples/mnist/mnist_config.py
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+# import pkg_resources so that bob imports work properly:
+import pkg_resources
+import os
+from multiprocessing import Pool
+from collections import defaultdict
+import numpy as np
+from bob.io.base import create_directories_safe
+from bob.bio.base.utils import read_config_file, save
+from bob.learn.tensorflow.utils.commandline import \
+    get_from_config_or_commandline
+from bob.core.log import setup, set_verbosity_level
+logger = setup(__name__)
+
+
+def save_predictions(pool, output_dir, key, pred_buffer):
+    outpath = os.path.join(output_dir, key + '.hdf5')
+    create_directories_safe(os.path.dirname(outpath))
+    pool.apply_async(save, (np.mean(pred_buffer[key], axis=0), outpath))
+
+
+def main(argv=None):
+    from docopt import docopt
+    import sys
+    docs = __doc__ % {'prog': os.path.basename(sys.argv[0])}
+    version = pkg_resources.require('bob.learn.tensorflow')[0].version
+    defaults = docopt(docs, argv=[""])
+    args = docopt(docs, argv=argv, version=version)
+    config_files = args['<config_files>']
+    config = read_config_file(config_files)
+
+    # optional arguments
+    verbosity = get_from_config_or_commandline(
+        config, 'verbose', args, defaults)
+    predict_keys = get_from_config_or_commandline(
+        config, 'predict_keys', args, defaults)
+    checkpoint_path = get_from_config_or_commandline(
+        config, 'checkpoint_path', args, defaults)
+    hooks = getattr(config, 'hooks', None)
+
+    # Sets-up logging
+    set_verbosity_level(logger, verbosity)
+
+    # required arguments
+    estimator = config.estimator
+    predict_input_fn = config.predict_input_fn
+    output_dir = get_from_config_or_commandline(
+        config, 'output_dir', args, defaults, False)
+
+    predictions = estimator.predict(
+        predict_input_fn,
+        predict_keys=predict_keys,
+        hooks=hooks,
+        checkpoint_path=checkpoint_path,
+    )
+
+    pool = Pool()
+    try:
+        pred_buffer = defaultdict(list)
+        for i, pred in enumerate(predictions):
+            key = pred['keys']
+            prob = pred.get('probabilities', pred.get('embeddings'))
+            pred_buffer[key].append(prob)
+            if i == 0:
+                last_key = key
+            if last_key == key:
+                continue
+            else:
+                save_predictions(pool, output_dir, last_key, pred_buffer)
+                last_key = key
+        # else below is for the for loop
+        else:
+            save_predictions(pool, output_dir, key, pred_buffer)
+    finally:
+        pool.close()
+        pool.join()
+
+
+if __name__ == '__main__':
+    main()
diff --git a/bob/learn/tensorflow/script/train_generic.py b/bob/learn/tensorflow/script/train_generic.py
index 11f7d18a421b8c5ef48196ca254116722f8c5138..e60e4917baf349422a83360220faf2722aaed66f 100644
--- a/bob/learn/tensorflow/script/train_generic.py
+++ b/bob/learn/tensorflow/script/train_generic.py
@@ -66,9 +66,7 @@ def main(argv=None):
 
     if run_config is None:
         # by default create reproducible nets:
-        from bob.learn.tensorflow.utils.reproducible import session_conf
-        run_config = tf.estimator.RunConfig()
-        run_config.replace(session_config=session_conf)
+        from bob.learn.tensorflow.utils.reproducible import run_config
 
     # Instantiate Estimator
     nn = tf.estimator.Estimator(model_fn=model_fn, model_dir=model_dir,
diff --git a/bob/learn/tensorflow/test/data/dummy_verify_config.py b/bob/learn/tensorflow/test/data/dummy_verify_config.py
index 448da020af64ac1a69b13be3e0930a283b83ea57..0b2e4661e84899536b8ede7cda576c7006124aea 100755
--- a/bob/learn/tensorflow/test/data/dummy_verify_config.py
+++ b/bob/learn/tensorflow/test/data/dummy_verify_config.py
@@ -1,15 +1,9 @@
-import os
 from bob.bio.base.test.dummy.database import database
-preprocessor = extractor = algorithm = 'dummy'
-groups = ['dev']
+from bob.bio.base.test.dummy.preprocessor import preprocessor
 
-temp_directory = result_directory = 'TEST_DIR'
-sub_directory = 'sub_directory'
+groups = 'dev'
 
-data_dir = os.path.join('TEST_DIR', sub_directory, 'preprocessed')
-
-# the directory to save the tfrecords in:
-output_dir = os.path.join('TEST_DIR', sub_directory)
+files = database.all_files(groups=groups)
 
 CLIENT_IDS = (str(f.client_id) for f in database.all_files(groups=groups))
 CLIENT_IDS = list(set(CLIENT_IDS))
@@ -18,3 +12,11 @@ CLIENT_IDS = dict(zip(CLIENT_IDS, range(len(CLIENT_IDS))))
 
 def file_to_label(f):
     return CLIENT_IDS[str(f.client_id)]
+
+
+def reader(biofile):
+    data = preprocessor.read_original_data(
+        biofile, database.original_directory, database.original_extension)
+    label = file_to_label(biofile)
+    key = biofile.path
+    return (data, label, key)
diff --git a/bob/learn/tensorflow/utils/__init__.py b/bob/learn/tensorflow/utils/__init__.py
index 3fe013e8ccb40c8512359a1774c63f0513e18075..e655e39675e37cf2b165e5a68769c42f81c6574f 100755
--- a/bob/learn/tensorflow/utils/__init__.py
+++ b/bob/learn/tensorflow/utils/__init__.py
@@ -4,3 +4,4 @@ from .session import Session
 from . import hooks
 from . import eval
 from . import tfrecords
+from . import commandline
diff --git a/bob/learn/tensorflow/utils/commandline.py b/bob/learn/tensorflow/utils/commandline.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fdb3f2fa48f4c8664df2acf3188b24ae9e2ba3e
--- /dev/null
+++ b/bob/learn/tensorflow/utils/commandline.py
@@ -0,0 +1,56 @@
+def get_from_config_or_commandline(config, keyword, args, defaults,
+                                   default_is_valid=True):
+    """Takes the value from command line, config file, and default value with
+    this precedence.
+
+    Only several command line options can be used with this function:
+    - boolean flags
+    - repeating flags (like --verbose)
+    - options where the user will never provide the default value through
+      command line. For example when [default: None]
+
+    Parameters
+    ----------
+    config : object
+        The loaded config files.
+    keyword : str
+        The keyword to load from the config file or through command line.
+    args : dict
+        The arguments parsed with docopt.
+    defaults : dict
+        The arguments parsed with docopt when ``argv=[]``.
+    default_is_valid : bool, optional
+        If False, will raise an exception if the final parsed value is the
+        default value.
+
+    Returns
+    -------
+    object
+        The bool or integer value of the corresponding keyword.
+
+    Example
+    -------
+    >>> from bob.bio.base.utils import read_config_file
+    >>> defaults = docopt(docs, argv=[""])
+    >>> args = docopt(docs, argv=argv)
+    >>> config_files = args['<config_files>']
+    >>> config = read_config_file(config_files)
+
+    >>> verbosity = get_from_config_or_commandline(config, 'verbose', args,
+    ...                                            defaults)
+
+    """
+    arg_keyword = '--' + keyword.replace('_', '-')
+
+    # load from config first
+    value = getattr(config, keyword, defaults[arg_keyword])
+
+    # override it if provided by command line arguments
+    if args[arg_keyword] != defaults[arg_keyword]:
+        value = args[arg_keyword]
+
+    if not default_is_valid and value == defaults[arg_keyword]:
+        raise ValueError(
+            "The value provided for {} is not valid.".format(keyword))
+
+    return value
diff --git a/bob/learn/tensorflow/utils/reproducible.py b/bob/learn/tensorflow/utils/reproducible.py
index 34cb4678258c75d40c889580bb30eff42c8f5242..0d9824ce6541561194ac748563807bdb39ac2beb 100644
--- a/bob/learn/tensorflow/utils/reproducible.py
+++ b/bob/learn/tensorflow/utils/reproducible.py
@@ -25,13 +25,18 @@ rn.seed(12345)
 # non-reproducible results.
 # For further details, see:
 # https://stackoverflow.com/questions/42022950/which-seeds-have-to-be-set-where-to-realize-100-reproducibility-of-training-res
-session_conf = tf.ConfigProto(intra_op_parallelism_threads=1,
-                              inter_op_parallelism_threads=1)
+session_config = tf.ConfigProto(intra_op_parallelism_threads=1,
+                                inter_op_parallelism_threads=1)
 
 # The below tf.set_random_seed() will make random number generation
 # in the TensorFlow backend have a well-defined initial state.
 # For further details, see:
 # https://www.tensorflow.org/api_docs/python/tf/set_random_seed
-tf.set_random_seed(1234)
-# sess = tf.Session(graph=tf.get_default_graph(), config=session_conf)
+tf_random_seed = 1234
+tf.set_random_seed(tf_random_seed)
+# sess = tf.Session(graph=tf.get_default_graph(), config=session_config)
 # keras.backend.set_session(sess)
+
+run_config = tf.estimator.RunConfig()
+run_config = run_config.replace(session_config=session_config)
+run_config = run_config.replace(tf_random_seed=tf_random_seed)
diff --git a/setup.py b/setup.py
index 6a0b4dbdcf9cca07279ef5a0157348188e85d9cf..6d5285796dc871e6b7bab35b8bc11af3303c5b87 100755
--- a/setup.py
+++ b/setup.py
@@ -47,13 +47,13 @@ setup(
 
         # scripts should be declared using this entry:
         'console_scripts': [
-            'compute_statistics.py = bob.learn.tensorflow.script.compute_statistics:main',
-            'train.py = bob.learn.tensorflow.script.train:main',
-            'bob_db_to_tfrecords = bob.learn.tensorflow.script.db_to_tfrecords:main',
-            'load_and_debug.py = bob.learn.tensorflow.script.load_and_debug:main',
-            'lfw_db_to_tfrecords.py = bob.learn.tensorflow.script.lfw_db_to_tfrecords:main',
+            'bob_tf_compute_statistics.py = bob.learn.tensorflow.script.compute_statistics:main',
+            'bob_tf_db_to_tfrecords = bob.learn.tensorflow.script.db_to_tfrecords:main',
+            'bob_tf_load_and_debug.py = bob.learn.tensorflow.script.load_and_debug:main',
+            'bob_tf_lfw_db_to_tfrecords.py = bob.learn.tensorflow.script.lfw_db_to_tfrecords:main',
             'bob_tf_train_generic = bob.learn.tensorflow.script.train_generic:main',
             'bob_tf_eval_generic = bob.learn.tensorflow.script.eval_generic:main',
+            'bob_tf_predict_generic = bob.learn.tensorflow.script.predict_generic:main',
         ],
 
     },