diff --git a/bob/learn/tensorflow/dataset/tfrecords.py b/bob/learn/tensorflow/dataset/tfrecords.py
index 15b25fbaf31033b186b26e44c62e04a5849a5607..7d0f5c15249c2c6bf9c6b0856b2ade1d17309f48 100644
--- a/bob/learn/tensorflow/dataset/tfrecords.py
+++ b/bob/learn/tensorflow/dataset/tfrecords.py
@@ -1,48 +1,195 @@
+"""Utilities for TFRecords
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
 from functools import partial
+import json
+import logging
+import os
+import sys
+
 import tensorflow as tf
+
 from . import append_image_augmentation, DEFAULT_FEATURE
-import os
-import logging
+
+
 logger = logging.getLogger(__name__)
+TFRECORDS_EXT = ".tfrecords"
+
+
+def tfrecord_name_and_json_name(output):
+    output = normalize_tfrecords_path(output)
+    json_output = output[: -len(TFRECORDS_EXT)] + ".json"
+    return output, json_output
-def example_parser(serialized_example, feature, data_shape, data_type):
+
+
+def normalize_tfrecords_path(output):
+    if not output.endswith(TFRECORDS_EXT):
+        output += TFRECORDS_EXT
+    return output
+
+
+def bytes_feature(value):
+    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
+
+
+def int64_feature(value):
+    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
+
+
+def dataset_to_tfrecord(dataset, output):
+    """Writes a tf.data.Dataset into a TFRecord file.
+
+    Parameters
+    ----------
+    dataset : tf.data.Dataset
+        The tf.data.Dataset that you want to write into a TFRecord file.
+    output : str
+        Path to the TFRecord file. Besides this file, a .json file is also created.
+        This json file is needed when you want to convert the TFRecord file back
+        into a dataset.
+
+    Returns
+    -------
+    tf.Operation
+        A tf.Operation that, when run, writes the contents of the dataset to a file.
+        When running in eager mode, calling this function will write the file.
+        Otherwise, you have to call session.run() on the returned operation.
     """
-    Parses a single tf.Example into image and label tensors.
+    output, json_output = tfrecord_name_and_json_name(output)
+    # dump the structure so that we can read it back
+    meta = {
+        "output_types": repr(dataset.output_types),
+        "output_shapes": repr(dataset.output_shapes),
+    }
+    with open(json_output, "w") as f:
+        json.dump(meta, f)
+
+    # create a custom map function that serializes the dataset
+    def serialize_example_pyfunction(*args):
+        feature = {}
+        for i, f in enumerate(args):
+            key = f"feature{i}"
+            feature[key] = bytes_feature(f)
+        example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
+        return example_proto.SerializeToString()
+
+    def tf_serialize_example(*args):
+        args = tf.contrib.framework.nest.flatten(args)
+        args = [tf.serialize_tensor(f) for f in args]
+        tf_string = tf.py_func(serialize_example_pyfunction, args, tf.string)
+        return tf.reshape(tf_string, ())  # The result is a scalar
+
+    dataset = dataset.map(tf_serialize_example)
+    writer = tf.data.experimental.TFRecordWriter(output)
+    return writer.write(dataset)
+
+
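
A minimal sketch of how `dataset_to_tfrecord` above is meant to be used in graph mode (TF 1.x); the dataset and the output path are hypothetical examples, not part of this patch:

import tensorflow as tf

from bob.learn.tensorflow.dataset.tfrecords import dataset_to_tfrecord

dataset = tf.data.Dataset.from_tensor_slices(([1.0, 2.0, 3.0], [0, 1, 0]))
write_op = dataset_to_tfrecord(dataset, "/tmp/example")  # ".tfrecords" is appended
with tf.Session() as sess:
    sess.run(write_op)  # also writes the /tmp/example.json structure sidecar
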
""" + # these imports are needed so that eval can work + from tensorflow import TensorShape, Dimension + + if isinstance(tfrecord, str): + tfrecord = [tfrecord] + tfrecord = [tfrecord_name_and_json_name(path) for path in tfrecord] + json_output = tfrecord[0][1] + tfrecord = [path[0] for path in tfrecord] + raw_dataset = tf.data.TFRecordDataset(tfrecord) + + with open(json_output) as f: + meta = json.load(f) + for k, v in meta.items(): + meta[k] = eval(v) + output_types = tf.contrib.framework.nest.flatten(meta["output_types"]) + output_shapes = tf.contrib.framework.nest.flatten(meta["output_shapes"]) + feature_description = {} + for i in range(len(output_types)): + key = f"feature{i}" + feature_description[key] = tf.FixedLenFeature([], tf.string) + + def _parse_function(example_proto): + # Parse the input tf.Example proto using the dictionary above. + args = tf.parse_single_example(example_proto, feature_description) + args = tf.contrib.framework.nest.flatten(args) + args = [tf.parse_tensor(v, t) for v, t in zip(args, output_types)] + args = [tf.reshape(v, s) for v, s in zip(args, output_shapes)] + return tf.contrib.framework.nest.pack_sequence_as(meta["output_types"], args) + + return raw_dataset.map(_parse_function) + + +def write_a_sample(writer, data, label, key, feature=None, size_estimate=False): + if feature is None: + feature = { + "data": bytes_feature(data.tostring()), + "label": int64_feature(label), + "key": bytes_feature(key), + } + + example = tf.train.Example(features=tf.train.Features(feature=feature)) + example = example.SerializeToString() + if not size_estimate: + writer.write(example) + return sys.getsizeof(example) + + +def example_parser(serialized_example, feature, data_shape, data_type): + """ + Parses a single tf.Example into image and label tensors. + + """ # Decode the record read by the reader features = tf.parse_single_example(serialized_example, features=feature) # Convert the image data from string back to the numbers - image = tf.decode_raw(features['data'], data_type) + image = tf.decode_raw(features["data"], data_type) # Cast label data into int64 - label = tf.cast(features['label'], tf.int64) + label = tf.cast(features["label"], tf.int64) # Reshape image data into the original shape image = tf.reshape(image, data_shape) - key = tf.cast(features['key'], tf.string) + key = tf.cast(features["key"], tf.string) return image, label, key -def image_augmentation_parser(serialized_example, - feature, - data_shape, - data_type, - gray_scale=False, - output_shape=None, - random_flip=False, - random_brightness=False, - random_contrast=False, - random_saturation=False, - random_rotate=False, - per_image_normalization=True): +def image_augmentation_parser( + serialized_example, + feature, + data_shape, + data_type, + gray_scale=False, + output_shape=None, + random_flip=False, + random_brightness=False, + random_contrast=False, + random_saturation=False, + random_rotate=False, + per_image_normalization=True, + random_gamma=False, + random_crop=False, +): """ - Parses a single tf.Example into image and label tensors. + Parses a single tf.Example into image and label tensors. 
- """ + """ # Decode the record read by the reader features = tf.parse_single_example(serialized_example, features=feature) # Convert the image data from string back to the numbers - image = tf.decode_raw(features['data'], data_type) + image = tf.decode_raw(features["data"], data_type) # Reshape image data into the original shape image = tf.reshape(image, data_shape) @@ -57,23 +204,23 @@ def image_augmentation_parser(serialized_example, random_contrast=random_contrast, random_saturation=random_saturation, random_rotate=random_rotate, - per_image_normalization=per_image_normalization) + per_image_normalization=per_image_normalization, + random_gamma=random_gamma, + random_crop=random_crop, + ) # Cast label data into int64 - label = tf.cast(features['label'], tf.int64) - key = tf.cast(features['key'], tf.string) + label = tf.cast(features["label"], tf.int64) + key = tf.cast(features["key"], tf.string) return image, label, key -def read_and_decode(filename_queue, - data_shape, - data_type=tf.float32, - feature=None): - """ - Simples parse possible for a tfrecord. - It assumes that you have the pair **train/data** and **train/label** +def read_and_decode(filename_queue, data_shape, data_type=tf.float32, feature=None): """ + Simples parse possible for a tfrecord. + It assumes that you have the pair **train/data** and **train/label** + """ if feature is None: feature = DEFAULT_FEATURE @@ -83,73 +230,77 @@ def read_and_decode(filename_queue, return example_parser(serialized_example, feature, data_shape, data_type) -def create_dataset_from_records(tfrecord_filenames, - data_shape, - data_type, - feature=None): +def create_dataset_from_records( + tfrecord_filenames, data_shape, data_type, feature=None +): """ - Create dataset from a list of tf-record files + Create dataset from a list of tf-record files - **Parameters** + **Parameters** - tfrecord_filenames: - List containing the tf-record paths + tfrecord_filenames: + List containing the tf-record paths - data_shape: - Samples shape saved in the tf-record + data_shape: + Samples shape saved in the tf-record - data_type: - tf data type(https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types) + data_type: + tf data type(https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types) - feature: + feature: - """ + """ if feature is None: feature = DEFAULT_FEATURE dataset = tf.data.TFRecordDataset(tfrecord_filenames) parser = partial( - example_parser, - feature=feature, - data_shape=data_shape, - data_type=data_type) + example_parser, feature=feature, data_shape=data_shape, data_type=data_type + ) dataset = dataset.map(parser) return dataset def create_dataset_from_records_with_augmentation( - tfrecord_filenames, - data_shape, - data_type, - feature=None, - gray_scale=False, - output_shape=None, - random_flip=False, - random_brightness=False, - random_contrast=False, - random_saturation=False, - random_rotate=False, - per_image_normalization=True): + tfrecord_filenames, + data_shape, + data_type, + feature=None, + gray_scale=False, + output_shape=None, + random_flip=False, + random_brightness=False, + random_contrast=False, + random_saturation=False, + random_rotate=False, + per_image_normalization=True, + random_gamma=False, + random_crop=False, +): """ - Create dataset from a list of tf-record files + Create dataset from a list of tf-record files - **Parameters** + **Parameters** - tfrecord_filenames: - List containing the tf-record paths + tfrecord_filenames: + List containing the tf-record paths - data_shape: - Samples 
 def create_dataset_from_records_with_augmentation(
-        tfrecord_filenames,
-        data_shape,
-        data_type,
-        feature=None,
-        gray_scale=False,
-        output_shape=None,
-        random_flip=False,
-        random_brightness=False,
-        random_contrast=False,
-        random_saturation=False,
-        random_rotate=False,
-        per_image_normalization=True):
+    tfrecord_filenames,
+    data_shape,
+    data_type,
+    feature=None,
+    gray_scale=False,
+    output_shape=None,
+    random_flip=False,
+    random_brightness=False,
+    random_contrast=False,
+    random_saturation=False,
+    random_rotate=False,
+    per_image_normalization=True,
+    random_gamma=False,
+    random_crop=False,
+):
     """
-    Create dataset from a list of tf-record files
+    Creates a dataset, with image augmentation, from a list of tf-record files.

-    **Parameters**
+    **Parameters**

-    tfrecord_filenames:
-        List containing the tf-record paths
+    tfrecord_filenames:
+        List containing the tf-record paths (or a directory containing them)

-    data_shape:
-        Samples shape saved in the tf-record
+    data_shape:
+        Shape of the samples saved in the tf-record

-    data_type:
-        tf data type(https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types)
+    data_type:
+        tf data type (https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types)

-    feature:
+    feature:
+        Optional feature description; defaults to ``DEFAULT_FEATURE``

-    """
+    """
     if feature is None:
         feature = DEFAULT_FEATURE

+    if os.path.isdir(tfrecord_filenames):
+        tfrecord_filenames = [
+            os.path.join(tfrecord_filenames, f) for f in os.listdir(tfrecord_filenames)
+        ]
     dataset = tf.data.TFRecordDataset(tfrecord_filenames)
     parser = partial(
         image_augmentation_parser,
@@ -163,73 +314,80 @@ def create_dataset_from_records_with_augmentation(
         random_contrast=random_contrast,
         random_saturation=random_saturation,
         random_rotate=random_rotate,
-        per_image_normalization=per_image_normalization)
+        per_image_normalization=per_image_normalization,
+        random_gamma=random_gamma,
+        random_crop=random_crop,
+    )
     dataset = dataset.map(parser)
     return dataset


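Same idea with augmentation enabled; the flags map one-to-one onto `append_image_augmentation`, and the directory path below is a hypothetical example of the new `isdir` branch:

import tensorflow as tf

from bob.learn.tensorflow.dataset.tfrecords import (
    create_dataset_from_records_with_augmentation,
)

# a directory of tfrecords also works now, thanks to the isdir branch above
dataset = create_dataset_from_records_with_augmentation(
    "/tmp/tfrecords_dir",
    data_shape=(182, 182, 3),
    data_type=tf.uint8,
    random_flip=True,
    random_gamma=True,
)
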
-def shuffle_data_and_labels_image_augmentation(tfrecord_filenames,
-                                               data_shape,
-                                               data_type,
-                                               batch_size,
-                                               epochs=None,
-                                               buffer_size=10**3,
-                                               gray_scale=False,
-                                               output_shape=None,
-                                               random_flip=False,
-                                               random_brightness=False,
-                                               random_contrast=False,
-                                               random_saturation=False,
-                                               random_rotate=False,
-                                               per_image_normalization=True):
+def shuffle_data_and_labels_image_augmentation(
+    tfrecord_filenames,
+    data_shape,
+    data_type,
+    batch_size,
+    epochs=None,
+    buffer_size=10 ** 3,
+    gray_scale=False,
+    output_shape=None,
+    random_flip=False,
+    random_brightness=False,
+    random_contrast=False,
+    random_saturation=False,
+    random_rotate=False,
+    per_image_normalization=True,
+    random_gamma=False,
+    random_crop=False,
+):
     """
-    Dump random batches from a list of tf-record files and applies some image augmentation
+    Dumps random batches from a list of tf-record files and applies image augmentation.

-    **Parameters**
+    **Parameters**

-    tfrecord_filenames:
-        List containing the tf-record paths
+    tfrecord_filenames:
+        List containing the tf-record paths

-    data_shape:
-        Samples shape saved in the tf-record
+    data_shape:
+        Shape of the samples saved in the tf-record

-    data_type:
-        tf data type(https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types)
+    data_type:
+        tf data type (https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types)

-    batch_size:
-        Size of the batch
+    batch_size:
+        Size of the batch

-    epochs:
-        Number of epochs to be batched
+    epochs:
+        Number of epochs to be batched

-    buffer_size:
-        Size of the shuffle bucket
+    buffer_size:
+        Size of the shuffle buffer

-    gray_scale:
-        Convert to gray scale?
+    gray_scale:
+        Convert to gray scale?

-    output_shape:
-        If set, will randomly crop the image given the output shape
+    output_shape:
+        If set, will randomly crop the image given the output shape

-    random_flip:
-        Randomly flip an image horizontally (https://www.tensorflow.org/api_docs/python/tf/image/random_flip_left_right)
+    random_flip:
+        Randomly flip an image horizontally (https://www.tensorflow.org/api_docs/python/tf/image/random_flip_left_right)

-    random_brightness:
-        Adjust the brightness of an RGB image by a random factor (https://www.tensorflow.org/api_docs/python/tf/image/random_brightness)
+    random_brightness:
+        Adjust the brightness of an RGB image by a random factor (https://www.tensorflow.org/api_docs/python/tf/image/random_brightness)

-    random_contrast:
-        Adjust the contrast of an RGB image by a random factor (https://www.tensorflow.org/api_docs/python/tf/image/random_contrast)
+    random_contrast:
+        Adjust the contrast of an RGB image by a random factor (https://www.tensorflow.org/api_docs/python/tf/image/random_contrast)

-    random_saturation:
-        Adjust the saturation of an RGB image by a random factor (https://www.tensorflow.org/api_docs/python/tf/image/random_saturation)
+    random_saturation:
+        Adjust the saturation of an RGB image by a random factor (https://www.tensorflow.org/api_docs/python/tf/image/random_saturation)

-    random_rotate:
-        Randomly rotate face images between -5 and 5 degrees
+    random_rotate:
+        Randomly rotate face images between -5 and 5 degrees

-    per_image_normalization:
-        Linearly scales image to have zero mean and unit norm.
+    per_image_normalization:
+        Linearly scales image to have zero mean and unit norm.

+    random_gamma:
+        Adjust the gamma of an RGB image by a random factor

+    random_crop:
+        Randomly crop the image to ``output_shape``

-    """
+    """

     dataset = create_dataset_from_records_with_augmentation(
         tfrecord_filenames,
@@ -242,134 +400,135 @@ def shuffle_data_and_labels_image_augmentation(tfrecord_filenames,
         random_contrast=random_contrast,
         random_saturation=random_saturation,
         random_rotate=random_rotate,
-        per_image_normalization=per_image_normalization)
+        per_image_normalization=per_image_normalization,
+        random_gamma=random_gamma,
+        random_crop=random_crop,
+    )

     dataset = dataset.shuffle(buffer_size).batch(batch_size).repeat(epochs)
-    data, labels, key = dataset.make_one_shot_iterator().get_next()
-
-    features = dict()
-    features['data'] = data
-    features['key'] = key
+    dataset = dataset.map(lambda d, l, k: ({"data": d, "key": k}, l))

-    return features, labels
+    return dataset


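Because `shuffle_data_and_labels_image_augmentation` now returns the `tf.data.Dataset` itself instead of iterator tensors, it can serve directly as a `tf.estimator` input function; a sketch with hypothetical paths and a hypothetical `estimator` object:

import tensorflow as tf

from bob.learn.tensorflow.dataset.tfrecords import (
    shuffle_data_and_labels_image_augmentation,
)

def train_input_fn():
    return shuffle_data_and_labels_image_augmentation(
        ["/tmp/train.tfrecords"],
        data_shape=(182, 182, 3),
        data_type=tf.uint8,
        batch_size=32,
        epochs=None,  # repeat indefinitely; max_steps stops the training
        random_flip=True,
    )

# estimator.train(train_input_fn, max_steps=1000)  # hypothetical estimator
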
-def shuffle_data_and_labels(tfrecord_filenames,
-                            data_shape,
-                            data_type,
-                            batch_size,
-                            epochs=None,
-                            buffer_size=10**3):
+def shuffle_data_and_labels(
+    tfrecord_filenames,
+    data_shape,
+    data_type,
+    batch_size,
+    epochs=None,
+    buffer_size=10 ** 3,
+):
     """
-    Dump random batches from a list of tf-record files
+    Dumps random batches from a list of tf-record files.

-    **Parameters**
+    **Parameters**

-    tfrecord_filenames:
-        List containing the tf-record paths
+    tfrecord_filenames:
+        List containing the tf-record paths

-    data_shape:
-        Samples shape saved in the tf-record
+    data_shape:
+        Shape of the samples saved in the tf-record

-    data_type:
-        tf data type(https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types)
+    data_type:
+        tf data type (https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types)

-    batch_size:
-        Size of the batch
+    batch_size:
+        Size of the batch

-    epochs:
-        Number of epochs to be batched
+    epochs:
+        Number of epochs to be batched

-    buffer_size:
-        Size of the shuffle bucket
+    buffer_size:
+        Size of the shuffle buffer

-    """
+    """

-    dataset = create_dataset_from_records(tfrecord_filenames, data_shape,
-                                          data_type)
+    dataset = create_dataset_from_records(tfrecord_filenames, data_shape, data_type)
     dataset = dataset.shuffle(buffer_size).batch(batch_size).repeat(epochs)

     data, labels, key = dataset.make_one_shot_iterator().get_next()
     features = dict()
-    features['data'] = data
-    features['key'] = key
+    features["data"] = data
+    features["key"] = key

     return features, labels


-def batch_data_and_labels(tfrecord_filenames,
-                          data_shape,
-                          data_type,
-                          batch_size,
-                          epochs=1):
+def batch_data_and_labels(
+    tfrecord_filenames, data_shape, data_type, batch_size, epochs=1
+):
     """
-    Dump in order batches from a list of tf-record files
+    Dumps batches, in order, from a list of tf-record files.

-    **Parameters**
+    **Parameters**

-    tfrecord_filenames:
-        List containing the tf-record paths
+    tfrecord_filenames:
+        List containing the tf-record paths

-    data_shape:
-        Samples shape saved in the tf-record
+    data_shape:
+        Shape of the samples saved in the tf-record

-    data_type:
-        tf data type(https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types)
+    data_type:
+        tf data type (https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types)

-    batch_size:
-        Size of the batch
+    batch_size:
+        Size of the batch

-    epochs:
-        Number of epochs to be batched
+    epochs:
+        Number of epochs to be batched

-    """
-    dataset = create_dataset_from_records(tfrecord_filenames, data_shape,
-                                          data_type)
+    """
+    dataset = create_dataset_from_records(tfrecord_filenames, data_shape, data_type)
     dataset = dataset.batch(batch_size).repeat(epochs)

     data, labels, key = dataset.make_one_shot_iterator().get_next()
     features = dict()
-    features['data'] = data
-    features['key'] = key
+    features["data"] = data
+    features["key"] = key

     return features, labels


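`batch_data_and_labels`, by contrast, still returns `(features, labels)` tensors from a one-shot iterator, which suits in-order evaluation; the path and shape below are hypothetical:

import tensorflow as tf

from bob.learn.tensorflow.dataset.tfrecords import batch_data_and_labels

features, labels = batch_data_and_labels(
    ["/tmp/dev.tfrecords"],
    data_shape=(182, 182, 3),
    data_type=tf.uint8,
    batch_size=100,
    epochs=1,  # one ordered pass over the data
)
# features["data"] holds the images and features["key"] their unique keys
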
-def batch_data_and_labels_image_augmentation(tfrecord_filenames,
-                                             data_shape,
-                                             data_type,
-                                             batch_size,
-                                             epochs=1,
-                                             gray_scale=False,
-                                             output_shape=None,
-                                             random_flip=False,
-                                             random_brightness=False,
-                                             random_contrast=False,
-                                             random_saturation=False,
-                                             random_rotate=False,
-                                             per_image_normalization=True):
+def batch_data_and_labels_image_augmentation(
+    tfrecord_filenames,
+    data_shape,
+    data_type,
+    batch_size,
+    epochs=1,
+    gray_scale=False,
+    output_shape=None,
+    random_flip=False,
+    random_brightness=False,
+    random_contrast=False,
+    random_saturation=False,
+    random_rotate=False,
+    per_image_normalization=True,
+    random_gamma=False,
+    random_crop=False,
+):
     """
-    Dump in order batches from a list of tf-record files
+    Dumps batches, in order, from a list of tf-record files and applies image
+    augmentation.

-    **Parameters**
+    **Parameters**

-    tfrecord_filenames:
-        List containing the tf-record paths
+    tfrecord_filenames:
+        List containing the tf-record paths

-    data_shape:
-        Samples shape saved in the tf-record
+    data_shape:
+        Shape of the samples saved in the tf-record

-    data_type:
-        tf data type(https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types)
+    data_type:
+        tf data type (https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types)

-    batch_size:
-        Size of the batch
+    batch_size:
+        Size of the batch

-    epochs:
-        Number of epochs to be batched
+    epochs:
+        Number of epochs to be batched

-    """
+    """

     dataset = create_dataset_from_records_with_augmentation(
         tfrecord_filenames,
@@ -382,54 +541,59 @@ def batch_data_and_labels_image_augmentation(tfrecord_filenames,
         random_contrast=random_contrast,
         random_saturation=random_saturation,
         random_rotate=random_rotate,
-        per_image_normalization=per_image_normalization)
+        per_image_normalization=per_image_normalization,
+        random_gamma=random_gamma,
+        random_crop=random_crop,
+    )

     dataset = dataset.batch(batch_size).repeat(epochs)

     data, labels, key = dataset.make_one_shot_iterator().get_next()
     features = dict()
-    features['data'] = data
-    features['key'] = key
+    features["data"] = data
+    features["key"] = key

     return features, labels


 def describe_tf_record(tf_record_path, shape, batch_size=1):
     """
-    Describe the number of samples and the number of classes of a tf-record
+    Describe the number of samples and the number of classes of a tf-record.

-    Parameters
-    ----------
+    Parameters
+    ----------

-    tf_record_path: str
-        Base path containing your tf-record files
+    tf_record_path: str
+        Base path containing your tf-record files

-    shape: tuple
-        Shape inside of the tf-record
+    shape: tuple
+        Shape of the samples inside the tf-record

-    batch_size: int
-        Well, batch size
+    batch_size: int
+        Batch size used while reading the records

-    Returns
-    -------
+    Returns
+    -------

-    n_samples: int
-        Total number of samples
+    n_samples: int
+        Total number of samples

-    n_classes: int
-        Total number of classes
-
-    """
+    n_classes: int
+        Total number of classes
+    """

     tf_records = [os.path.join(tf_record_path, f) for f in os.listdir(tf_record_path)]
-    filename_queue = tf.train.string_input_producer(tf_records, num_epochs=1, name="input")
+    filename_queue = tf.train.string_input_producer(
+        tf_records, num_epochs=1, name="input"
+    )

-    feature = {'data': tf.FixedLenFeature([], tf.string),
-               'label': tf.FixedLenFeature([], tf.int64),
-               'key': tf.FixedLenFeature([], tf.string)
-               }
+    feature = {
+        "data": tf.FixedLenFeature([], tf.string),
+        "label": tf.FixedLenFeature([], tf.int64),
+        "key": tf.FixedLenFeature([], tf.string),
+    }

     # Define a reader and read the next record
     reader = tf.TFRecordReader()
@@ -440,18 +604,23 @@ def describe_tf_record(tf_record_path, shape, batch_size=1):
     features = tf.parse_single_example(serialized_example, features=feature)

     # Convert the image data from string back to the numbers
-    image = tf.decode_raw(features['data'], tf.uint8)
+    image = tf.decode_raw(features["data"], tf.uint8)

     # Cast label data into int32
-    label = tf.cast(features['label'], tf.int64)
-    img_name = tf.cast(features['key'], tf.string)
+    label = tf.cast(features["label"], tf.int64)
+    img_name = tf.cast(features["key"], tf.string)

     # Reshape image data into the original shape
     image = tf.reshape(image, shape)

     # Getting the batches in order
-    data_ph, label_ph, img_name_ph = tf.train.batch([image, label, img_name], batch_size=batch_size,
-                                                    capacity=1000, num_threads=5, name="shuffle_batch")
+    data_ph, label_ph, img_name_ph = tf.train.batch(
+        [image, label, img_name],
+        batch_size=batch_size,
+        capacity=1000,
+        num_threads=5,
+        name="shuffle_batch",
+    )

     # Start the reading
     session = tf.Session()
@@ -462,12 +631,11 @@ def describe_tf_record(tf_record_path, shape, batch_size=1):
     thread_pool = tf.train.Coordinator()
     threads = tf.train.start_queue_runners(coord=thread_pool, sess=session)

-    logger.info("Counting in %s", tf_record_path)
     labels = set()
     counter = 0
     try:
-        while(True):
+        while True:
             _, label, _ = session.run([data_ph, label_ph, img_name_ph])
             counter += len(label)

@@ -479,4 +647,3 @@ def describe_tf_record(tf_record_path, shape, batch_size=1):
         thread_pool.request_stop()

     return counter, len(labels)
-
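Counting samples and classes with `describe_tf_record`; the directory below is a hypothetical example and must contain only tfrecord files, since every entry in it is listed:

from bob.learn.tensorflow.dataset.tfrecords import describe_tf_record

n_samples, n_classes = describe_tf_record(
    "/tmp/tfrecords_dir", shape=(182, 182, 3), batch_size=1000
)
print(n_samples, n_classes)
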
diff --git a/bob/learn/tensorflow/script/db_to_tfrecords.py b/bob/learn/tensorflow/script/db_to_tfrecords.py
index 5d85dd4e0c31e39090e3fd68cc114ba9492df11f..fbef7a0013f1f6c62e2ca12b01050d771f103207 100644
--- a/bob/learn/tensorflow/script/db_to_tfrecords.py
+++ b/bob/learn/tensorflow/script/db_to_tfrecords.py
@@ -4,126 +4,90 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import logging
+import os
 import random
 import tempfile
-import os
-import sys
-import logging

 import click
 import tensorflow as tf

-from bob.io.base import create_directories_safe
+from bob.io.base import create_directories_safe, HDF5File
 from bob.extension.scripts.click_helper import (
-    verbosity_option, ConfigCommand, ResourceOption, log_parameters)
-import numpy
-from bob.learn.tensorflow.dataset.tfrecords import describe_tf_record
+    verbosity_option,
+    ConfigCommand,
+    ResourceOption,
+    log_parameters,
+)
+from bob.learn.tensorflow.dataset.tfrecords import (
+    describe_tf_record,
+    write_a_sample,
+    normalize_tfrecords_path,
+    tfrecord_name_and_json_name,
+    dataset_to_tfrecord,
+)
+from bob.learn.tensorflow.utils import bytes2human

 logger = logging.getLogger(__name__)


-def bytes_feature(value):
-    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
-
-
-def int64_feature(value):
-    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
-
-
-def write_a_sample(writer, data, label, key, feature=None,
-                   size_estimate=False):
-    if feature is None:
-        feature = {
-            'data': bytes_feature(data.tostring()),
-            'label': int64_feature(label),
-            'key': bytes_feature(key)
-        }
-
-    example = tf.train.Example(features=tf.train.Features(feature=feature))
-    example = example.SerializeToString()
-    if not size_estimate:
-        writer.write(example)
-    return sys.getsizeof(example)
-
-
-def _bytes2human(n, format='%(value).1f %(symbol)s', symbols='customary'):
-    """Convert n bytes into a human readable string based on format.
-    From: https://code.activestate.com/recipes/578019-bytes-to-human-human-to-
-    bytes-converter/
-    Author: Giampaolo Rodola' <g.rodola [AT] gmail [DOT] com>
-    License: MIT
-    symbols can be either "customary", "customary_ext", "iec" or "iec_ext",
-    see: http://goo.gl/kTQMs
-    """
-    SYMBOLS = {
-        'customary': ('B', 'K', 'M', 'G', 'T', 'P', 'E', 'Z', 'Y'),
-        'customary_ext': ('byte', 'kilo', 'mega', 'giga', 'tera', 'peta',
-                          'exa', 'zetta', 'iotta'),
-        'iec': ('Bi', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi', 'Yi'),
-        'iec_ext': ('byte', 'kibi', 'mebi', 'gibi', 'tebi', 'pebi', 'exbi',
-                    'zebi', 'yobi'),
-    }
-    n = int(n)
-    if n < 0:
-        raise ValueError("n < 0")
-    symbols = SYMBOLS[symbols]
-    prefix = {}
-    for i, s in enumerate(symbols[1:]):
-        prefix[s] = 1 << (i + 1) * 10
-    for symbol in reversed(symbols[1:]):
-        if n >= prefix[symbol]:
-            value = float(n) / prefix[symbol]
-            return format % locals()
-    return format % dict(symbol=symbols[0], value=n)
-
-
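The hand-rolled `_bytes2human` helper is dropped in favor of `bytes2human` from `bob.learn.tensorflow.utils`; judging from the call site further down, it is used like this (the exact output format is assumed to match the old "customary" symbols):

from bob.learn.tensorflow.utils import bytes2human

print(bytes2human(2 ** 20))  # expected to print something like "1.0 M"
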
-@click.command(
-    entry_point_group='bob.learn.tensorflow.config', cls=ConfigCommand)
+@click.command(entry_point_group="bob.learn.tensorflow.config", cls=ConfigCommand)
 @click.option(
-    '--samples',
+    "--samples",
     required=True,
     cls=ResourceOption,
-    help='A list of all samples that you want to write in the '
-    'tfrecords file. Whatever is inside this list is passed to '
-    'the reader.')
+    help="A list of all samples that you want to write in the "
+    "tfrecords file. Whatever is inside this list is passed to "
+    "the reader.",
+)
 @click.option(
-    '--reader',
+    "--reader",
     required=True,
     cls=ResourceOption,
-    help='a function with the signature of ``data, label, key = '
-    'reader(sample)`` which takes a sample and returns the '
-    'loaded data, the label of the data, and a key which is '
-    'unique for every sample.')
+    help="a function with the signature of ``data, label, key = "
+    "reader(sample)`` which takes a sample and returns the "
+    "loaded data, the label of the data, and a key which is "
+    "unique for every sample.",
+)
 @click.option(
-    '--output',
-    '-o',
-    required=True,
-    cls=ResourceOption,
-    help='Name of the output file.')
+    "--output", "-o", required=True, cls=ResourceOption, help="Name of the output file."
+)
 @click.option(
-    '--shuffle',
+    "--shuffle",
     is_flag=True,
     cls=ResourceOption,
-    help='If provided, it will shuffle the samples.')
+    help="If provided, it will shuffle the samples.",
+)
 @click.option(
-    '--allow-failures',
+    "--allow-failures",
     is_flag=True,
     cls=ResourceOption,
-    help='If provided, the samples which fail to load are ignored.')
+    help="If provided, the samples which fail to load are ignored.",
+)
 @click.option(
-    '--multiple-samples',
+    "--multiple-samples",
     is_flag=True,
     cls=ResourceOption,
-    help='If provided, it means that the data provided by reader contains '
-    'multiple samples with same label and path.')
+    help="If provided, it means that the data provided by reader contains "
+    "multiple samples with the same label and path.",
+)
 @click.option(
-    '--size-estimate',
+    "--size-estimate",
     is_flag=True,
     cls=ResourceOption,
-    help='If given, will print the estimated file size instead of creating '
-    'the final tfrecord file.')
+    help="If given, will print the estimated file size instead of creating "
+    "the final tfrecord file.",
+)
 @verbosity_option(cls=ResourceOption)
-def db_to_tfrecords(samples, reader, output, shuffle, allow_failures,
-                    multiple_samples, size_estimate, **kwargs):
+def db_to_tfrecords(
+    samples,
+    reader,
+    output,
+    shuffle,
+    allow_failures,
+    multiple_samples,
+    size_estimate,
+    **kwargs,
+):
     """Converts Bio and PAD datasets to TFRecords file formats.

     The best way to use this script is to send it to the io-big queue if you
@@ -173,14 +137,13 @@ def db_to_tfrecords(samples, reader, output, shuffle, allow_failures,
         key = biofile.path
         return (data, label, key)
     """
-    log_parameters(logger, ignore=('samples', ))
+    log_parameters(logger, ignore=("samples",))
     logger.debug("len(samples): %d", len(samples))

     if size_estimate:
-        output = tempfile.NamedTemporaryFile(suffix='.tfrecords').name
+        output = tempfile.NamedTemporaryFile(suffix=".tfrecords").name

-    if not output.endswith(".tfrecords"):
-        output += ".tfrecords"
+    output = normalize_tfrecords_path(output)

     if not size_estimate:
         logger.info("Writing samples to `{}'".format(output))
@@ -196,7 +159,7 @@ def db_to_tfrecords(samples, reader, output, shuffle, allow_failures,
         logger.info("Shuffling the samples before writing ...")
         random.shuffle(samples)
     for i, sample in enumerate(samples):
-        logger.info('Processing file %d out of %d', i + 1, n_samples)
+        logger.info("Processing file %d out of %d", i + 1, n_samples)

         data, label, key = reader(sample)

@@ -205,55 +168,43 @@ def db_to_tfrecords(samples, reader, output, shuffle, allow_failures,
                 logger.debug("... Skipping `{0}`.".format(sample))
                 continue
             else:
-                raise RuntimeError(
-                    "Reading failed for `{0}`".format(sample))
+                raise RuntimeError("Reading failed for `{0}`".format(sample))

         if multiple_samples:
             for sample in data:
                 total_size += write_a_sample(
-                    writer,
-                    sample,
-                    label,
-                    key,
-                    size_estimate=size_estimate)
+                    writer, sample, label, key, size_estimate=size_estimate
+                )
                 sample_count += 1
         else:
             total_size += write_a_sample(
-                writer, data, label, key, size_estimate=size_estimate)
+                writer, data, label, key, size_estimate=size_estimate
+            )
             sample_count += 1

     if not size_estimate:
-        click.echo(
-            "Wrote {} samples into the tfrecords file.".format(sample_count))
+        click.echo("Wrote {} samples into the tfrecords file.".format(sample_count))
     else:
         # delete the empty tfrecords file
         try:
             os.remove(output)
         except Exception:
             pass
-        click.echo("The total size of the tfrecords file will be roughly "
-                   "{} bytes".format(_bytes2human(total_size)))
+        click.echo(
+            "The total size of the tfrecords file will be roughly "
+            "{} bytes".format(bytes2human(total_size))
+        )


 @click.command()
-@click.argument(
-    'tf-record-path',
-    nargs=1)
-@click.argument(
-    'shape',
-    type=int,
-    nargs=-1
-)
+@click.argument("tf-record-path", nargs=1)
+@click.argument("shape", type=int, nargs=-1)
 @click.option(
-    '--batch-size',
-    help='Batch size',
-    show_default=True,
-    required=True,
-    default=1000
-)
+    "--batch-size", help="Batch size", show_default=True, required=True, default=1000
+)
 @verbosity_option(cls=ResourceOption)
 def describe_tfrecord(tf_record_path, shape, batch_size, **kwargs):
-    '''
+    """
     Very often you have a tf-record file, or a set of them, and you have no
     idea how many samples you have there. Even worse, you have no idea how
     many classes you have.
@@ -262,9 +213,58 @@ def describe_tfrecord(tf_record_path, shape, batch_size, **kwargs):

       $ %(prog)s <tf-record-path> 182 182 3

-    '''
+    """
     n_samples, n_labels = describe_tf_record(tf_record_path, shape, batch_size)
     click.echo("#############################################")
     click.echo("Number of samples {0}".format(n_samples))
     click.echo("Number of labels {0}".format(n_labels))
     click.echo("#############################################")
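The `describe-tfrecord` command can also be exercised from Python with click's test runner, mirroring how the tests below drive these commands; the path and shape are hypothetical:

from click.testing import CliRunner

from bob.learn.tensorflow.script.db_to_tfrecords import describe_tfrecord

runner = CliRunner()
result = runner.invoke(
    describe_tfrecord, args=["/tmp/tfrecords_dir", "182", "182", "3"]
)
print(result.output)
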
+ """ + log_parameters(logger) + + output, json_output = tfrecord_name_and_json_name(output) + if not force and os.path.isfile(output): + click.echo("Output file already exists: {}".format(output)) + return + + click.echo("Writing tfrecod to: {}".format(output)) + with tf.Session() as sess: + os.makedirs(os.path.dirname(output), exist_ok=True) + try: + sess.run(dataset_to_tfrecord(dataset, output)) + except Exception: + click.echo("Something failed. Deleting unfinished files.") + os.remove(output) + os.remove(json_output) + raise + click.echo("Successfully wrote all files.") diff --git a/bob/learn/tensorflow/test/data/db_to_tfrecords_config.py b/bob/learn/tensorflow/test/data/db_to_tfrecords_config.py index 276018a7ae94c6592ceaa98dc858d2df45952e51..6fa669656cafc38cf2f17d57df7389be6baa519c 100644 --- a/bob/learn/tensorflow/test/data/db_to_tfrecords_config.py +++ b/bob/learn/tensorflow/test/data/db_to_tfrecords_config.py @@ -1,7 +1,8 @@ from bob.bio.base.test.dummy.database import database from bob.bio.base.utils import read_original_data +from bob.learn.tensorflow.dataset.generator import dataset_using_generator -groups = ['dev'] +groups = ["dev"] samples = database.all_files(groups=groups) @@ -15,8 +16,13 @@ def file_to_label(f): def reader(biofile): - data = read_original_data(biofile, database.original_directory, - database.original_extension) + data = read_original_data( + biofile, database.original_directory, database.original_extension + ) label = file_to_label(biofile) key = str(biofile.path).encode("utf-8") return (data, label, key) + + +dataset = dataset_using_generator(samples, reader) +datasets = [dataset] diff --git a/bob/learn/tensorflow/test/test_db_to_tfrecords.py b/bob/learn/tensorflow/test/test_db_to_tfrecords.py index db8172d30f373d3aa90f3c12cf1373205e46a9c3..2318ca591c01c8d276940b5cbfc1359bab9c9c0d 100644 --- a/bob/learn/tensorflow/test/test_db_to_tfrecords.py +++ b/bob/learn/tensorflow/test/test_db_to_tfrecords.py @@ -2,11 +2,16 @@ import os import shutil import pkg_resources import tempfile +import tensorflow as tf +import numpy as np from click.testing import CliRunner from bob.io.base import create_directories_safe from bob.learn.tensorflow.script.db_to_tfrecords import ( - db_to_tfrecords, describe_tf_record) + db_to_tfrecords, describe_tf_record, datasets_to_tfrecords) from bob.learn.tensorflow.utils import load_mnist, create_mnist_tfrecord +from bob.extension.scripts.click_helper import assert_click_runner_result +from bob.extension.config import load +from bob.learn.tensorflow.dataset.tfrecords import dataset_from_tfrecord regenerate_reference = False @@ -14,6 +19,31 @@ dummy_config = pkg_resources.resource_filename( 'bob.learn.tensorflow', 'test/data/db_to_tfrecords_config.py') +def compare_datasets(ds1, ds2, sess=None): + if tf.executing_eagerly(): + for values1, values2 in zip(ds1, ds2): + values1 = tf.contrib.framework.nest.flatten(values1) + values2 = tf.contrib.framework.nest.flatten(values2) + for v1, v2 in zip(values1, values2): + if not tf.reduce_all(tf.math.equal(v1, v2)): + return False + else: + ds1 = ds1.make_one_shot_iterator().get_next() + ds2 = ds2.make_one_shot_iterator().get_next() + while True: + try: + values1, values2 = sess.run([ds1, ds2]) + except tf.errors.OutOfRangeError: + break + values1 = tf.contrib.framework.nest.flatten(values1) + values2 = tf.contrib.framework.nest.flatten(values2) + for v1, v2 in zip(values1, values2): + v1, v2 = np.asarray(v1), np.asarray(v2) + if not np.all(v1 == v2): + return False + return True + + def 
 def test_db_to_tfrecords():
     test_dir = tempfile.mkdtemp(prefix='bobtest_')
     output_path = os.path.join(test_dir, 'dev.tfrecords')
@@ -71,3 +101,19 @@ def test_tfrecord_counter():

     finally:
         shutil.rmtree(os.path.dirname(tfrecord_train))
+
+
+def test_datasets_to_tfrecords():
+    runner = CliRunner()
+    with runner.isolated_filesystem():
+        output_path = './test'
+        args = (dummy_config, '--output', output_path)
+        result = runner.invoke(
+            datasets_to_tfrecords, args=args, standalone_mode=False)
+        assert_click_runner_result(result)
+        # read back the tfrecord
+        with tf.Session() as sess:
+            dataset2 = dataset_from_tfrecord(output_path)
+            dataset1 = load(
+                [dummy_config], attribute_name='dataset', entry_point_group='bob')
+            assert compare_datasets(dataset1, dataset2, sess)
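The `compare_datasets` helper added above also works eagerly, without a session; a small self-check, assuming eager execution is enabled before any graph work and that the helper is imported from the test module:

import tensorflow as tf

tf.enable_eager_execution()

ds = tf.data.Dataset.from_tensor_slices([1, 2, 3])
assert compare_datasets(ds, ds)  # identical datasets compare equal
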