diff --git a/bob/learn/tensorflow/dataset/tfrecords.py b/bob/learn/tensorflow/dataset/tfrecords.py
index 84c21ab9f822bfe5311a7f1c2a44e7924daee048..faa0938a23db82957339149dcf914ce270a1208e 100644
--- a/bob/learn/tensorflow/dataset/tfrecords.py
+++ b/bob/learn/tensorflow/dataset/tfrecords.py
@@ -291,4 +291,55 @@ def batch_data_and_labels(tfrecord_filenames, data_shape, data_type,
     features['key'] = key
     
     return features, labels
+
+
+def batch_data_and_labels_image_augmentation(tfrecord_filenames, data_shape, data_type,
+                                             batch_size, epochs=1,
+                                             gray_scale=False,
+                                             output_shape=None,
+                                             random_flip=False,
+                                             random_brightness=False,
+                                             random_contrast=False,
+                                             random_saturation=False,
+                                             per_image_normalization=True):
+    """
+    Dump in-order batches from a list of tf-record files, applying the
+    image augmentations selected through the keyword flags.
+
+    **Parameters**
+
+       tfrecord_filenames: List containing the tf-record paths
+
+       data_shape: Samples shape saved in the tf-record
+
+       data_type: tf data type(https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types)
+
+       batch_size: Size of the batch
+
+       epochs: Number of epochs to be batched
+
+       gray_scale, output_shape, random_flip, random_brightness,
+       random_contrast, random_saturation, per_image_normalization:
+          Pre-processing/augmentation options, forwarded unchanged to
+          create_dataset_from_records_with_augmentation
+    """
+
+    dataset = create_dataset_from_records_with_augmentation(tfrecord_filenames, data_shape,
+                                                            data_type,
+                                                            gray_scale=gray_scale,
+                                                            output_shape=output_shape,
+                                                            random_flip=random_flip,
+                                                            random_brightness=random_brightness,
+                                                            random_contrast=random_contrast,
+                                                            random_saturation=random_saturation,
+                                                            per_image_normalization=per_image_normalization)
+
+    dataset = dataset.batch(batch_size).repeat(epochs)
+
+    data, labels, key = dataset.make_one_shot_iterator().get_next()
+    features = dict()
+    features['data'] = data
+    features['key'] = key
+
+    return features, labels