diff --git a/bob/learn/tensorflow/dataset/__init__.py b/bob/learn/tensorflow/dataset/__init__.py
index e4dee7a0fadad28f14f9b0cc5bb3b3598c3dbe26..5f904d6aa216c073024dc2cb6b9a0adb162a4335 100644
--- a/bob/learn/tensorflow/dataset/__init__.py
+++ b/bob/learn/tensorflow/dataset/__init__.py
@@ -4,17 +4,18 @@ import os
 import bob.io.base
 
 DEFAULT_FEATURE = {
-    'data': tf.FixedLenFeature([], tf.string),
-    'label': tf.FixedLenFeature([], tf.int64),
-    'key': tf.FixedLenFeature([], tf.string)
+    "data": tf.FixedLenFeature([], tf.string),
+    "label": tf.FixedLenFeature([], tf.int64),
+    "key": tf.FixedLenFeature([], tf.string),
 }
 
 
 def from_hdf5file_to_tensor(filename):
     import bob.io.image
+
     data = bob.io.image.to_matplotlib(bob.io.base.load(filename))
 
-    #reshaping to ndim == 3
+    # reshaping to ndim == 3
     if data.ndim == 2:
         data = numpy.reshape(data, (data.shape[0], data.shape[1], 1))
     data = data.astype("float32")
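
For context, DEFAULT_FEATURE above is the schema used to parse the project's TFRecord files. A minimal parsing sketch follows; the file path and the raw-byte decoding of "data" are assumptions about how the records were written, not part of this patch.

    import tensorflow as tf
    from bob.learn.tensorflow.dataset import DEFAULT_FEATURE

    # Hypothetical record source; the path is a placeholder.
    dataset = tf.data.TFRecordDataset("/path/to/train.tfrecords")
    serialized_example = dataset.make_one_shot_iterator().get_next()

    features = tf.parse_single_example(serialized_example, features=DEFAULT_FEATURE)
    image = tf.decode_raw(features["data"], tf.uint8)  # assumes "data" holds raw image bytes
    label = features["label"]
    key = features["key"]
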
@@ -25,7 +26,7 @@ def from_hdf5file_to_tensor(filename):
 def from_filename_to_tensor(filename, extension=None):
     """
     Read a file and convert it to a tensor.
-    
+
     If the file extension is something that tensorflow understands (.jpg, .bmp, .tif, ...),
     it uses `tf.image.decode_image`; otherwise it uses `bob.io.base.load`.
     """
@@ -33,19 +34,22 @@ def from_filename_to_tensor(filename, extension=None):
     if extension == "hdf5":
         return tf.py_func(from_hdf5file_to_tensor, [filename], [tf.float32])
     else:
-        return tf.cast(
-            tf.image.decode_image(tf.read_file(filename)), tf.float32)
-
-
-def append_image_augmentation(image,
-                              gray_scale=False,
-                              output_shape=None,
-                              random_flip=False,
-                              random_brightness=False,
-                              random_contrast=False,
-                              random_saturation=False,
-                              random_rotate=False,
-                              per_image_normalization=True):
+        return tf.cast(tf.image.decode_image(tf.read_file(filename)), tf.float32)
+
+
+def append_image_augmentation(
+    image,
+    gray_scale=False,
+    output_shape=None,
+    random_flip=False,
+    random_brightness=False,
+    random_contrast=False,
+    random_saturation=False,
+    random_rotate=False,
+    per_image_normalization=True,
+    random_gamma=False,
+    random_crop=False,
+):
     """
     Append some random image augmentation operations to the given tensor
 
@@ -76,37 +80,43 @@ def append_image_augmentation(image,
 
     """
 
-    # Casting to float32
-    image = tf.cast(image, tf.float32)
+    # Changing the range from 0-255 to 0-1
+    image = tf.cast(image, tf.float32) / 255
     # FORCING A SEED FOR THE RANDOM OPERATIONS
     tf.set_random_seed(0)
 
     if output_shape is not None:
-        assert len(output_shape) == 2
-        image = tf.image.resize_image_with_crop_or_pad(image, output_shape[0],
-                                                       output_shape[1])
+        if random_crop:
+            image = tf.random_crop(image, size=list(output_shape) + [3])
+        else:
+            assert len(output_shape) == 2
+            image = tf.image.resize_image_with_crop_or_pad(
+                image, output_shape[0], output_shape[1]
+            )
 
     if random_flip:
         image = tf.image.random_flip_left_right(image)
 
     if random_brightness:
-        image = tf.image.random_brightness(image, max_delta=0.5)
+        image = tf.image.random_brightness(image, max_delta=0.15)
+        image = tf.clip_by_value(image, 0, 1)
 
     if random_contrast:
-        image = tf.image.random_contrast(image, lower=0, upper=0.5)
+        image = tf.image.random_contrast(image, lower=0.85, upper=1.15)
+        image = tf.clip_by_value(image, 0, 1)
 
     if random_saturation:
-        image = tf.image.random_saturation(image, lower=0, upper=0.5)
+        image = tf.image.random_saturation(image, lower=0.85, upper=1.15)
+        image = tf.clip_by_value(image, 0, 1)
 
-    if random_rotate:
-        image = tf.contrib.image.rotate(
-            image,
-            angles=numpy.random.randint(-5, 5),
-            interpolation="BILINEAR")
+    if random_gamma:
+        image = tf.image.adjust_gamma(
+            image, gamma=tf.random.uniform(shape=[], minval=0.85, maxval=1.15)
+        )
+        image = tf.clip_by_value(image, 0, 1)
 
     if gray_scale:
         image = tf.image.rgb_to_grayscale(image, name="rgb_to_gray")
-        #self.output_shape[3] = 1
 
     # normalizing data
     if per_image_normalization:
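
A usage sketch of the new options (shapes and flag choices are illustrative only; the input is expected in the 0-255 range and is rescaled to 0-1 internally):

    import tensorflow as tf
    from bob.learn.tensorflow.dataset import append_image_augmentation

    def augment(image):
        return append_image_augmentation(
            image,
            output_shape=(160, 160),
            random_crop=True,        # random 160x160x3 crop instead of a central crop/pad
            random_flip=True,
            random_brightness=True,  # +/- 0.15 delta, clipped back to [0, 1]
            random_gamma=True,       # gamma drawn uniformly from [0.85, 1.15]
            per_image_normalization=False,
        )

    # e.g. applied element-wise over a tf.data pipeline of decoded RGB images:
    # dataset = dataset.map(augment)
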
@@ -153,20 +163,29 @@ def triplets_random_generator(input_data, input_labels):
     input_labels = numpy.array(input_labels)
     total_samples = input_data.shape[0]
 
-    indexes_per_labels = arrange_indexes_by_label(input_labels,
-                                                  possible_labels)
+    indexes_per_labels = arrange_indexes_by_label(input_labels, possible_labels)
 
     # searching for random triplets
     offset_class = 0
     for i in range(total_samples):
 
-        anchor_sample = input_data[indexes_per_labels[possible_labels[
-            offset_class]][numpy.random.randint(
-                len(indexes_per_labels[possible_labels[offset_class]]))], ...]
-
-        positive_sample = input_data[indexes_per_labels[possible_labels[
-            offset_class]][numpy.random.randint(
-                len(indexes_per_labels[possible_labels[offset_class]]))], ...]
+        anchor_sample = input_data[
+            indexes_per_labels[possible_labels[offset_class]][
+                numpy.random.randint(
+                    len(indexes_per_labels[possible_labels[offset_class]])
+                )
+            ],
+            ...,
+        ]
+
+        positive_sample = input_data[
+            indexes_per_labels[possible_labels[offset_class]][
+                numpy.random.randint(
+                    len(indexes_per_labels[possible_labels[offset_class]])
+                )
+            ],
+            ...,
+        ]
 
         # Changing the class
         offset_class += 1
@@ -174,9 +193,14 @@ def triplets_random_generator(input_data, input_labels):
         if offset_class == len(possible_labels):
             offset_class = 0
 
-        negative_sample = input_data[indexes_per_labels[possible_labels[
-            offset_class]][numpy.random.randint(
-                len(indexes_per_labels[possible_labels[offset_class]]))], ...]
+        negative_sample = input_data[
+            indexes_per_labels[possible_labels[offset_class]][
+                numpy.random.randint(
+                    len(indexes_per_labels[possible_labels[offset_class]])
+                )
+            ],
+            ...,
+        ]
 
         append(str(anchor_sample), str(positive_sample), str(negative_sample))
         # yield anchor, positive, negative
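
The nested indexing above simply draws one random sample of the class currently selected by offset_class; the same pattern on toy numpy arrays (toy data, not part of the patch):

    import numpy

    # Toy data: 6 samples of 2 features, two classes
    input_data = numpy.arange(12).reshape(6, 2)
    indexes_per_labels = {0: numpy.array([0, 2, 4]), 1: numpy.array([1, 3, 5])}

    label = 0
    # Same pattern as anchor_sample/positive_sample/negative_sample above:
    # pick a random index within the chosen class, then take that row of input_data.
    sample = input_data[
        indexes_per_labels[label][numpy.random.randint(len(indexes_per_labels[label]))], ...
    ]
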
@@ -214,17 +238,18 @@ def siamease_pairs_generator(input_data, input_labels):
     total_samples = input_data.shape[0]
 
     # Filtering the samples by label and shuffling all the indexes
-    #indexes_per_labels = dict()
+    # indexes_per_labels = dict()
     # for l in possible_labels:
     #    indexes_per_labels[l] = numpy.where(input_labels == l)[0]
     #    numpy.random.shuffle(indexes_per_labels[l])
-    indexes_per_labels = arrange_indexes_by_label(input_labels,
-                                                  possible_labels)
+    indexes_per_labels = arrange_indexes_by_label(input_labels, possible_labels)
 
     left_possible_indexes = numpy.random.choice(
-        possible_labels, total_samples, replace=True)
+        possible_labels, total_samples, replace=True
+    )
     right_possible_indexes = numpy.random.choice(
-        possible_labels, total_samples, replace=True)
+        possible_labels, total_samples, replace=True
+    )
 
     genuine = True
     for i in range(total_samples):
@@ -234,10 +259,16 @@ def siamease_pairs_generator(input_data, input_labels):
             class_index = left_possible_indexes[i]
 
             # Now selecting the samples for the pair
-            left = input_data[indexes_per_labels[class_index][
-                numpy.random.randint(len(indexes_per_labels[class_index]))]]
-            right = input_data[indexes_per_labels[class_index][
-                numpy.random.randint(len(indexes_per_labels[class_index]))]]
+            left = input_data[
+                indexes_per_labels[class_index][
+                    numpy.random.randint(len(indexes_per_labels[class_index]))
+                ]
+            ]
+            right = input_data[
+                indexes_per_labels[class_index][
+                    numpy.random.randint(len(indexes_per_labels[class_index]))
+                ]
+            ]
             append(left, right, 0)
             # yield left, right, 0
         else:
@@ -248,7 +279,9 @@ def siamease_pairs_generator(input_data, input_labels):
             # Finding the right pair
             j = i
             # TODO: Lame solution. Fix this
-            while j < total_samples:  # Here is an unidiretinal search for the negative pair
+            while (
+                j < total_samples
+            ):  # Here is a unidirectional search for the negative pair
                 if left_possible_indexes[i] != right_possible_indexes[j]:
                     class_index.append(right_possible_indexes[j])
                     break
@@ -256,12 +289,16 @@ def siamease_pairs_generator(input_data, input_labels):
 
             if j < total_samples:
                 # Now selecting the samples for the pair
-                left = input_data[indexes_per_labels[class_index[0]][
-                    numpy.random.randint(
-                        len(indexes_per_labels[class_index[0]]))]]
-                right = input_data[indexes_per_labels[class_index[1]][
-                    numpy.random.randint(
-                        len(indexes_per_labels[class_index[1]]))]]
+                left = input_data[
+                    indexes_per_labels[class_index[0]][
+                        numpy.random.randint(len(indexes_per_labels[class_index[0]]))
+                    ]
+                ]
+                right = input_data[
+                    indexes_per_labels[class_index[1]][
+                        numpy.random.randint(len(indexes_per_labels[class_index[1]]))
+                    ]
+                ]
                 append(left, right, 1)
 
         genuine = not genuine
@@ -295,8 +332,9 @@ def blocks_tensorflow(images, block_size):
     # extract image patches for each color space:
     output = []
     for i in range(3):
-        blocks = tf.extract_image_patches(images[:, :, :, i:i + 1], block_size,
-                                          block_size, [1, 1, 1, 1], "VALID")
+        blocks = tf.extract_image_patches(
+            images[:, :, :, i : i + 1], block_size, block_size, [1, 1, 1, 1], "VALID"
+        )
         if i == 0:
             n_blocks = int(numpy.prod(blocks.shape[1:3]))
         blocks = tf.reshape(blocks, output_size)
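
For reference, a small sketch of the per-channel patch extraction used in the loop above (TF1 API; the 28x28 input and 14x14 block size are illustrative only):

    import tensorflow as tf

    images = tf.zeros([2, 28, 28, 3])          # batch of 2 RGB images
    block_size = [1, 14, 14, 1]

    patches = tf.extract_image_patches(
        images[:, :, :, 0:1],                  # one color channel at a time, as in the loop
        block_size, block_size,                # ksizes == strides -> non-overlapping blocks
        [1, 1, 1, 1], "VALID"
    )
    # patches has shape [2, 2, 2, 196]: a 2x2 grid of flattened 14x14 blocks per image
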