diff --git a/bob/learn/tensorflow/dataset/__init__.py b/bob/learn/tensorflow/dataset/__init__.py
index ce6f3e2f3271e486b2a30e8f1e418bc724818dc9..30d31c5630fcff0a6860e646f4989d201aa73719 100755
--- a/bob/learn/tensorflow/dataset/__init__.py
+++ b/bob/learn/tensorflow/dataset/__init__.py
@@ -1,12 +1,12 @@
 import tensorflow as tf
 import numpy
 
-DEFAULT_FEATURE = {'train/data': tf.FixedLenFeature([], tf.string),
-                   'train/label': tf.FixedLenFeature([], tf.int64)}
+DEFAULT_FEATURE = {'data': tf.FixedLenFeature([], tf.string),
+                   'label': tf.FixedLenFeature([], tf.int64),
+                   'key': tf.FixedLenFeature([], tf.string)}
 
 
-
-def append_image_augmentation(image, gray_scale=False, 
+def append_image_augmentation(image, gray_scale=False,
                               output_shape=None,
                               random_flip=False,
                               random_brightness=False,
@@ -15,11 +15,11 @@ def append_image_augmentation(image, gray_scale=False,
                               per_image_normalization=True):
     """
     Append to the current tensor some random image augmentation operation
-    
+
     **Parameters**
        gray_scale:
           Convert to gray scale?
-          
+
        output_shape:
           If set, will randomly crop the image given the output shape
 
@@ -37,16 +37,16 @@ def append_image_augmentation(image, gray_scale=False,
 
        per_image_normalization:
            Linearly scales image to have zero mean and unit norm.
-       
+
     """
 
     # Casting to float32
     image = tf.cast(image, tf.float32)
 
     if output_shape is not None:
-        assert len(output_shape) == 2        
+        assert len(output_shape) == 2
         image = tf.image.resize_image_with_crop_or_pad(image, output_shape[0], output_shape[1])
-        
+
     if random_flip:
         image = tf.image.random_flip_left_right(image)
 
@@ -68,17 +68,17 @@ def append_image_augmentation(image, gray_scale=False,
         image = tf.image.per_image_standardization(image)
 
     return image
-    
-    
+
+
 def siamease_pairs_generator(input_data, input_labels):
     """
     Giving a list of samples and a list of labels, it dumps a series of
     pairs for siamese nets.
-    
+
     **Parameters**
 
       input_data: List of whatever representing the data samples
-      
+
       input_labels: List of the labels (needs to be in EXACT same order as input_data)
     """
 
@@ -86,7 +86,7 @@ def siamease_pairs_generator(input_data, input_labels):
     left_data = []
     right_data = []
     labels = []
-    
+
     def append(left, right, label):
         """
         Just appending one element in each list
@@ -97,8 +97,8 @@ def siamease_pairs_generator(input_data, input_labels):
 
     possible_labels = list(set(input_labels))
     input_data = numpy.array(input_data)
-    input_labels = numpy.array(input_labels)    
-    total_samples = input_data.shape[0] 
+    input_labels = numpy.array(input_labels)
+    total_samples = input_data.shape[0]
 
     # Filtering the samples by label and shuffling all the indexes
     indexes_per_labels = dict()
@@ -107,7 +107,7 @@ def siamease_pairs_generator(input_data, input_labels):
         numpy.random.shuffle(indexes_per_labels[l])
 
     left_possible_indexes = numpy.random.choice(possible_labels, total_samples, replace=True)
-    right_possible_indexes = numpy.random.choice(possible_labels, total_samples, replace=True)       
+    right_possible_indexes = numpy.random.choice(possible_labels, total_samples, replace=True)
 
     genuine = True
     for i in range(total_samples):
@@ -142,6 +142,67 @@ def siamease_pairs_generator(input_data, input_labels):
                 append(left, right, 1)
 
 
-        genuine = not genuine    
+        genuine = not genuine
     return left_data, right_data, labels
 
+
+def blocks_tensorflow(images, block_size):
+    """Return all non-overlapping blocks of an image using tensorflow
+    operations.
+
+    Parameters
+    ----------
+    images : :any:`tf.Tensor`
+        The input color images. It is assumed that the image has a shape of
+        [?, H, W, C].
+    block_size : (int, int)
+        A tuple of two integers indicating the block size.
+
+    Returns
+    -------
+    blocks : :any:`tf.Tensor`
+        All the blocks in the batch dimension. The output will be of
+        size [?, block_size[0], block_size[1], C].
+    n_blocks : int
+        The number of blocks that was obtained per image.
+    """
+    # normalize block_size
+    block_size = [1] + list(block_size) + [1]
+    output_size = list(block_size)
+    output_size[0] = -1
+    # extract image patches for each color space:
+    output = []
+    for i in range(3):
+        blocks = tf.extract_image_patches(
+            images[:, :, :, i:i + 1], block_size, block_size, [1, 1, 1, 1],
+            "VALID")
+        if i == 0:
+            n_blocks = int(numpy.prod(blocks.shape[1:3]))
+        blocks = tf.reshape(blocks, output_size)
+        output.append(blocks)
+    # concatenate the colors back
+    output = tf.concat(output, axis=3)
+    return output, n_blocks
+
+
+def tf_repeat(tensor, repeats):
+    """
+    Parameters
+    ----------
+    tensor
+        A Tensor. 1-D or higher.
+    repeats
+        A list. Number of repeat for each dimension, length must be the same as
+        the number of dimensions in input
+
+    Returns
+    -------
+    A Tensor. Has the same type as input. Has the shape of tensor.shape *
+    repeats
+    """
+    with tf.variable_scope("repeat"):
+        expanded_tensor = tf.expand_dims(tensor, -1)
+        multiples = [1] + repeats
+        tiled_tensor = tf.tile(expanded_tensor, multiples=multiples)
+        repeated_tesnor = tf.reshape(tiled_tensor, tf.shape(tensor) * repeats)
+    return repeated_tesnor
diff --git a/bob/learn/tensorflow/dataset/tfrecords.py b/bob/learn/tensorflow/dataset/tfrecords.py
index 8e4713110e265dbe6969899d680c6d668301ea8f..f3458ba9271b57f2afbf5376acc61ec3e859d7af 100644
--- a/bob/learn/tensorflow/dataset/tfrecords.py
+++ b/bob/learn/tensorflow/dataset/tfrecords.py
@@ -6,31 +6,31 @@ from . import append_image_augmentation, DEFAULT_FEATURE
 def example_parser(serialized_example, feature, data_shape, data_type):
     """
     Parses a single tf.Example into image and label tensors.
-    
+
     """
     # Decode the record read by the reader
     features = tf.parse_single_example(serialized_example, features=feature)
     # Convert the image data from string back to the numbers
-    image = tf.decode_raw(features['train/data'], data_type)
+    image = tf.decode_raw(features['data'], data_type)
     # Cast label data into int64
-    label = tf.cast(features['train/label'], tf.int64)
+    label = tf.cast(features['label'], tf.int64)
     # Reshape image data into the original shape
     image = tf.reshape(image, data_shape)
-    return image, label
+    key = tf.cast(features['key'], tf.string)
+    return image, label, key
 
 
 def image_augmentation_parser(serialized_example, feature, data_shape, data_type,
-                              gray_scale=False, 
+                              gray_scale=False,
                               output_shape=None,
                               random_flip=False,
                               random_brightness=False,
                               random_contrast=False,
                               random_saturation=False,
                               per_image_normalization=True):
-
     """
     Parses a single tf.Example into image and label tensors.
-    
+
     """
     # Decode the record read by the reader
     features = tf.parse_single_example(serialized_example, features=feature)
@@ -39,8 +39,8 @@ def image_augmentation_parser(serialized_example, feature, data_shape, data_type
 
     # Reshape image data into the original shape
     image = tf.reshape(image, data_shape)
-    
-    #Applying image augmentation
+
+    # Applying image augmentation
     image = append_image_augmentation(image, gray_scale=gray_scale,
                                       output_shape=output_shape,
                                       random_flip=random_flip,
@@ -48,7 +48,7 @@ def image_augmentation_parser(serialized_example, feature, data_shape, data_type
                                       random_contrast=random_contrast,
                                       random_saturation=random_saturation,
                                       per_image_normalization=per_image_normalization)
-    
+
     # Cast label data into int64
     label = tf.cast(features['train/label'], tf.int64)
     return image, label
@@ -56,12 +56,11 @@ def image_augmentation_parser(serialized_example, feature, data_shape, data_type
 
 def read_and_decode(filename_queue, data_shape, data_type=tf.float32,
                     feature=None):
-                    
     """
     Simples parse possible for a tfrecord.
     It assumes that you have the pair **train/data** and **train/label**
     """
-                    
+
     if feature is None:
         feature = DEFAULT_FEATURE
     # Define a reader and read the next record
@@ -74,22 +73,22 @@ def create_dataset_from_records(tfrecord_filenames, data_shape, data_type,
                                 feature=None):
     """
     Create dataset from a list of tf-record files
-    
+
     **Parameters**
-    
-       tfrecord_filenames: 
+
+       tfrecord_filenames:
           List containing the tf-record paths
 
        data_shape:
           Samples shape saved in the tf-record
-          
+
        data_type:
           tf data type(https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types)
-          
+
        feature:
-    
+
     """
-                                
+
     if feature is None:
         feature = DEFAULT_FEATURE
     dataset = tf.contrib.data.TFRecordDataset(tfrecord_filenames)
@@ -100,39 +99,38 @@ def create_dataset_from_records(tfrecord_filenames, data_shape, data_type,
 
 
 def create_dataset_from_records_with_augmentation(tfrecord_filenames, data_shape, data_type,
-                                feature=None,
-                                gray_scale=False, 
-                                output_shape=None,
-                                random_flip=False,
-                                random_brightness=False,
-                                random_contrast=False,
-                                random_saturation=False,
-                                per_image_normalization=True):
+                                                  feature=None,
+                                                  gray_scale=False,
+                                                  output_shape=None,
+                                                  random_flip=False,
+                                                  random_brightness=False,
+                                                  random_contrast=False,
+                                                  random_saturation=False,
+                                                  per_image_normalization=True):
     """
     Create dataset from a list of tf-record files
-    
+
     **Parameters**
-    
-       tfrecord_filenames: 
+
+       tfrecord_filenames:
           List containing the tf-record paths
 
        data_shape:
           Samples shape saved in the tf-record
-          
+
        data_type:
           tf data type(https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types)
-          
+
        feature:
-    
+
     """
-                                
-                                
+
     if feature is None:
         feature = DEFAULT_FEATURE
     dataset = tf.contrib.data.TFRecordDataset(tfrecord_filenames)
     parser = partial(image_augmentation_parser, feature=feature, data_shape=data_shape,
                      data_type=data_type,
-                     gray_scale=gray_scale, 
+                     gray_scale=gray_scale,
                      output_shape=output_shape,
                      random_flip=random_flip,
                      random_brightness=random_brightness,
@@ -144,40 +142,40 @@ def create_dataset_from_records_with_augmentation(tfrecord_filenames, data_shape
 
 
 def shuffle_data_and_labels_image_augmentation(tfrecord_filenames, data_shape, data_type,
-                                              batch_size, epochs=None, buffer_size=10**3,
-                                              gray_scale=False, 
-                                              output_shape=None,
-                                              random_flip=False,
-                                              random_brightness=False,
-                                              random_contrast=False,
-                                              random_saturation=False,
-                                              per_image_normalization=True):
+                                               batch_size, epochs=None, buffer_size=10**3,
+                                               gray_scale=False,
+                                               output_shape=None,
+                                               random_flip=False,
+                                               random_brightness=False,
+                                               random_contrast=False,
+                                               random_saturation=False,
+                                               per_image_normalization=True):
     """
     Dump random batches from a list of tf-record files and applies some image augmentation
 
     **Parameters**
 
-       tfrecord_filenames: 
+       tfrecord_filenames:
           List containing the tf-record paths
 
        data_shape:
           Samples shape saved in the tf-record
-          
+
        data_type:
           tf data type(https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types)
-     
+
        batch_size:
           Size of the batch
-          
+
        epochs:
            Number of epochs to be batched
-       
+
        buffer_size:
             Size of the shuffle bucket
 
        gray_scale:
           Convert to gray scale?
-          
+
        output_shape:
           If set, will randomly crop the image given the output shape
 
@@ -194,20 +192,20 @@ def shuffle_data_and_labels_image_augmentation(tfrecord_filenames, data_shape, d
            Adjust the saturation of an RGB image by a random factor (https://www.tensorflow.org/api_docs/python/tf/image/random_saturation)
 
        per_image_normalization:
-           Linearly scales image to have zero mean and unit norm.            
-     
-    """                            
+           Linearly scales image to have zero mean and unit norm.
+
+    """
 
     dataset = create_dataset_from_records_with_augmentation(tfrecord_filenames, data_shape,
-                                          data_type,
-                                          gray_scale=gray_scale, 
-                                          output_shape=output_shape,
-                                          random_flip=random_flip,
-                                          random_brightness=random_brightness,
-                                          random_contrast=random_contrast,
-                                          random_saturation=random_saturation,
-                                          per_image_normalization=per_image_normalization)
-                                          
+                                                            data_type,
+                                                            gray_scale=gray_scale,
+                                                            output_shape=output_shape,
+                                                            random_flip=random_flip,
+                                                            random_brightness=random_brightness,
+                                                            random_contrast=random_contrast,
+                                                            random_saturation=random_saturation,
+                                                            per_image_normalization=per_image_normalization)
+
     dataset = dataset.shuffle(buffer_size).batch(batch_size).repeat(epochs)
 
     data, labels = dataset.make_one_shot_iterator().get_next()
@@ -221,25 +219,25 @@ def shuffle_data_and_labels(tfrecord_filenames, data_shape, data_type,
 
     **Parameters**
 
-       tfrecord_filenames: 
+       tfrecord_filenames:
           List containing the tf-record paths
 
        data_shape:
           Samples shape saved in the tf-record
-          
+
        data_type:
           tf data type(https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types)
-     
+
        batch_size:
           Size of the batch
-          
+
        epochs:
            Number of epochs to be batched
-       
+
        buffer_size:
             Size of the shuffle bucket
-     
-    """                            
+
+    """
 
     dataset = create_dataset_from_records(tfrecord_filenames, data_shape,
                                           data_type)
@@ -256,26 +254,25 @@ def batch_data_and_labels(tfrecord_filenames, data_shape, data_type,
 
     **Parameters**
 
-       tfrecord_filenames: 
+       tfrecord_filenames:
           List containing the tf-record paths
 
        data_shape:
           Samples shape saved in the tf-record
-          
+
        data_type:
           tf data type(https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types)
-     
+
        batch_size:
           Size of the batch
-          
+
        epochs:
            Number of epochs to be batched
-     
-    """                             
+
+    """
     dataset = create_dataset_from_records(tfrecord_filenames, data_shape,
                                           data_type)
     dataset = dataset.batch(batch_size).repeat(epochs)
 
     data, labels = dataset.make_one_shot_iterator().get_next()
     return data, labels
-
diff --git a/bob/learn/tensorflow/network/SimpleCNN.py b/bob/learn/tensorflow/network/SimpleCNN.py
index 01bec65179bc7c5f96365afb4b2e403bcaff0782..eb7a98a8c61b89f8cd745dfc5e07c2ec1583a7fd 100644
--- a/bob/learn/tensorflow/network/SimpleCNN.py
+++ b/bob/learn/tensorflow/network/SimpleCNN.py
@@ -1,54 +1,57 @@
 import tensorflow as tf
+from ..utils import get_available_gpus, to_channels_first
 
 
-def architecture(input_layer, mode=tf.estimator.ModeKeys.TRAIN):
-    # TODO: figure out a way to accept different input sizes
+def architecture(input_layer, mode=tf.estimator.ModeKeys.TRAIN,
+                 kernerl_size=(3, 3), n_classes=2):
+    data_format = 'channels_last'
+    if len(get_available_gpus()) != 0:
+        # When running on GPU, transpose the data from channels_last (NHWC) to
+        # channels_first (NCHW) to improve performance. See
+        # https://www.tensorflow.org/performance/performance_guide#data_formats
+        input_layer = to_channels_first('input_layer')
+        data_format = 'channels_first'
 
     # Convolutional Layer #1
-    # Computes 32 features using a 5x5 filter with ReLU activation.
+    # Computes 32 features using a kernerl_size filter with ReLU activation.
     # Padding is added to preserve width and height.
-    # Input Tensor Shape: [batch_size, 50, 1024, 1]
-    # Output Tensor Shape: [batch_size, 50, 1024, 32]
     conv1 = tf.layers.conv2d(
         inputs=input_layer,
         filters=32,
-        kernel_size=[5, 5],
+        kernel_size=kernerl_size,
         padding="same",
-        activation=tf.nn.relu)
+        activation=tf.nn.relu,
+        data_format=data_format)
 
     # Pooling Layer #1
     # First max pooling layer with a 2x2 filter and stride of 2
     # Input Tensor Shape: [batch_size, 50, 1024, 32]
     # Output Tensor Shape: [batch_size, 25, 512, 32]
-    pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)
+    pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2,
+                                    data_format=data_format)
 
     # Convolutional Layer #2
-    # Computes 64 features using a 5x5 filter.
+    # Computes 64 features using a kernerl_size filter.
     # Padding is added to preserve width and height.
-    # Input Tensor Shape: [batch_size, 25, 512, 32]
-    # Output Tensor Shape: [batch_size, 25, 512, 64]
     conv2 = tf.layers.conv2d(
         inputs=pool1,
         filters=64,
-        kernel_size=[5, 5],
+        kernel_size=kernerl_size,
         padding="same",
-        activation=tf.nn.relu)
+        activation=tf.nn.relu,
+        data_format=data_format)
 
     # Pooling Layer #2
     # Second max pooling layer with a 2x2 filter and stride of 2
-    # Input Tensor Shape: [batch_size, 25, 512, 64]
-    # Output Tensor Shape: [batch_size, 12, 256, 64]
-    pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)
+    pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2,
+                                    data_format=data_format)
 
     # Flatten tensor into a batch of vectors
-    # Input Tensor Shape: [batch_size, 12, 256, 64]
-    # Output Tensor Shape: [batch_size, 12 * 256 * 64]
-    pool2_flat = tf.reshape(pool2, [-1, 12 * 256 * 64])
+    # TODO: use tf.layers.flatten in tensorflow 1.4 above
+    pool2_flat = tf.contrib.layers.flatten(pool2)
 
     # Dense Layer
     # Densely connected layer with 1024 neurons
-    # Input Tensor Shape: [batch_size, 12 * 256 * 64]
-    # Output Tensor Shape: [batch_size, 1024]
     dense = tf.layers.dense(
         inputs=pool2_flat, units=1024, activation=tf.nn.relu)
 
@@ -59,24 +62,30 @@ def architecture(input_layer, mode=tf.estimator.ModeKeys.TRAIN):
     # Logits layer
     # Input Tensor Shape: [batch_size, 1024]
     # Output Tensor Shape: [batch_size, 2]
-    logits = tf.layers.dense(inputs=dropout, units=2)
+    logits = tf.layers.dense(inputs=dropout, units=n_classes)
 
     return logits
 
 
-def model_fn(features, labels, mode, params, config):
+def model_fn(features, labels, mode, params=None, config=None):
     """Model function for CNN."""
     params = params or {}
-    learning_rate = params.get('learning_rate', 0.00001)
+    learning_rate = params.get('learning_rate', 1e-5)
+    kernerl_size = params.get('kernerl_size', (3, 3))
+    n_classes = params.get('n_classes', 2)
 
-    logits = architecture(features, mode)
+    data = features['data']
+    keys = features['keys']
+    logits = architecture(
+        data, mode, kernerl_size=kernerl_size, n_classes=n_classes)
 
     predictions = {
         # Generate predictions (for PREDICT and EVAL mode)
         "classes": tf.argmax(input=logits, axis=1),
         # Add `softmax_tensor` to the graph. It is used for PREDICT and by the
         # `logging_hook`.
-        "probabilities": tf.nn.softmax(logits, name="softmax_tensor")
+        "probabilities": tf.nn.softmax(logits, name="softmax_tensor"),
+        'keys': keys,
     }
     if mode == tf.estimator.ModeKeys.PREDICT:
         return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
diff --git a/bob/learn/tensorflow/script/db_to_tfrecords.py b/bob/learn/tensorflow/script/db_to_tfrecords.py
index dc1cd120b6db31c0857415d44de7964bf722d53d..d69ddb43ca89da4e7680466de77f64535ae3f99a 100644
--- a/bob/learn/tensorflow/script/db_to_tfrecords.py
+++ b/bob/learn/tensorflow/script/db_to_tfrecords.py
@@ -62,7 +62,7 @@ An example for mnist would be::
 An example for bob.bio.base would be::
 
     from bob.bio.base.test.dummy.database import database
-    from bob.bio.base.test.dummy.preprocessor import preprocessor
+    from bob.bio.base.utils import read_original_data
 
     groups = 'dev'
 
@@ -78,7 +78,7 @@ An example for bob.bio.base would be::
 
 
     def reader(biofile):
-        data = preprocessor.read_original_data(
+        data = read_original_data(
             biofile, database.original_directory, database.original_extension)
         label = file_to_label(biofile)
         key = biofile.path
diff --git a/bob/learn/tensorflow/utils/util.py b/bob/learn/tensorflow/utils/util.py
index d3a9b9943e82907629491205f0407cd149e43437..6fdaeda95fb143844202bf21c5efc5808b250f1f 100755
--- a/bob/learn/tensorflow/utils/util.py
+++ b/bob/learn/tensorflow/utils/util.py
@@ -1,11 +1,11 @@
 #!/usr/bin/env python
 # vim: set fileencoding=utf-8 :
 # @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
-# @date: Wed 11 May 2016 09:39:36 CEST 
+# @date: Wed 11 May 2016 09:39:36 CEST
 
 import numpy
 import tensorflow as tf
-numpy.random.seed(10)
+from tensorflow.python.client import device_lib
 
 
 def compute_euclidean_distance(x, y):
@@ -13,7 +13,7 @@ def compute_euclidean_distance(x, y):
     Computes the euclidean distance between two tensorflow variables
     """
 
-    with tf.name_scope('euclidean_distance') as scope:
+    with tf.name_scope('euclidean_distance'):
         d = tf.sqrt(tf.reduce_sum(tf.square(tf.subtract(x, y)), 1))
         return d
 
@@ -34,14 +34,15 @@ def load_mnist(perc_train=0.9):
     numpy.random.shuffle(indexes)
 
     # Spliting train and validation
-    n_train = int(perc_train*indexes.shape[0])
+    n_train = int(perc_train * indexes.shape[0])
     n_validation = total_samples - n_train
 
     train_data = data[0:n_train, :].astype("float32") * 0.00390625
     train_labels = labels[0:n_train]
 
-    validation_data = data[n_train:n_train+n_validation, :].astype("float32") * 0.00390625
-    validation_labels = labels[n_train:n_train+n_validation]
+    validation_data = data[n_train:n_train +
+                           n_validation, :].astype("float32") * 0.00390625
+    validation_labels = labels[n_train:n_train + n_validation]
 
     return train_data, train_labels, validation_data, validation_labels
 
@@ -59,17 +60,18 @@ def create_mnist_tfrecord(tfrecords_filename, data, labels, n_samples=6000):
     for i in range(n_samples):
         img = data[i]
         img_raw = img.tostring()
-        
+
         feature = {'train/data': _bytes_feature(img_raw),
                    'train/label': _int64_feature(labels[i])
-                  }
-        
+                   }
+
         example = tf.train.Example(features=tf.train.Features(feature=feature))
         writer.write(example.SerializeToString())
     writer.close()
 
 
-def compute_eer(data_train, labels_train, data_validation, labels_validation, n_classes):
+def compute_eer(data_train, labels_train, data_validation, labels_validation,
+                n_classes):
     import bob.measure
     from scipy.spatial.distance import cosine
 
@@ -87,13 +89,15 @@ def compute_eer(data_train, labels_train, data_validation, labels_validation, n_
         # Positive scoring
         indexes = labels_validation == i
         positive_data = data_validation[indexes, :]
-        p = [cosine(models[i], positive_data[j]) for j in range(positive_data.shape[0])]
+        p = [cosine(models[i], positive_data[j])
+             for j in range(positive_data.shape[0])]
         positive_scores = numpy.hstack((positive_scores, p))
 
         # negative scoring
         indexes = labels_validation != i
         negative_data = data_validation[indexes, :]
-        n = [cosine(models[i], negative_data[j]) for j in range(negative_data.shape[0])]
+        n = [cosine(models[i], negative_data[j])
+             for j in range(negative_data.shape[0])]
         negative_scores = numpy.hstack((negative_scores, n))
 
     # Computing performance based on EER
@@ -107,7 +111,8 @@ def compute_eer(data_train, labels_train, data_validation, labels_validation, n_
     return eer
 
 
-def compute_accuracy(data_train, labels_train, data_validation, labels_validation, n_classes):
+def compute_accuracy(data_train, labels_train, data_validation,
+                     labels_validation, n_classes):
     from scipy.spatial.distance import cosine
 
     # Creating client models
@@ -120,7 +125,7 @@ def compute_accuracy(data_train, labels_train, data_validation, labels_validatio
     tp = 0
     for i in range(data_validation.shape[0]):
 
-        d = data_validation[i,:]
+        d = data_validation[i, :]
         l = labels_validation[i]
 
         scores = [cosine(m, d) for m in models]
@@ -130,21 +135,26 @@ def compute_accuracy(data_train, labels_train, data_validation, labels_validatio
             tp += 1
 
     return (float(tp) / data_validation.shape[0]) * 100
-    
-    
+
+
 def debug_embbeding(image, architecture, embbeding_dim=2, feature_layer="fc3"):
     """
     """
     import tensorflow as tf
     from bob.learn.tensorflow.utils.session import Session
-    
-    session = Session.instance(new=False).session    
-    inference_graph = architecture.compute_graph(architecture.inference_placeholder, feature_layer=feature_layer, training=False)
+
+    session = Session.instance(new=False).session
+    inference_graph = architecture.compute_graph(
+        architecture.inference_placeholder, feature_layer=feature_layer,
+        training=False)
 
     embeddings = numpy.zeros(shape=(image.shape[0], embbeding_dim))
     for i in range(image.shape[0]):
-        feed_dict = {architecture.inference_placeholder: image[i:i+1, :,:,:]}
-        embedding = session.run([tf.nn.l2_normalize(inference_graph, 1, 1e-10)], feed_dict=feed_dict)[0]
+        feed_dict = {
+            architecture.inference_placeholder: image[i:i + 1, :, :, :]}
+        embedding = session.run(
+            [tf.nn.l2_normalize(inference_graph, 1, 1e-10)],
+            feed_dict=feed_dict)[0]
         embedding = numpy.reshape(embedding, numpy.prod(embedding.shape[1:]))
         embeddings[i] = embedding
 
@@ -157,15 +167,15 @@ def cdist(A):
     as in scipy.spation.distance.cdist
     """
     with tf.variable_scope('Pairwisedistance'):
-        #ones_1 = tf.ones(shape=(1, A.shape.as_list()[0]))
-        ones_1 = tf.reshape(tf.cast(tf.ones_like(A), tf.float32)[:, 0], [1, -1])
+        ones_1 = tf.reshape(
+            tf.cast(tf.ones_like(A), tf.float32)[:, 0], [1, -1])
         p1 = tf.matmul(
             tf.expand_dims(tf.reduce_sum(tf.square(A), 1), 1),
             ones_1
         )
 
-        #ones_2 = tf.ones(shape=(A.shape.as_list()[0], 1))
-        ones_2 = tf.reshape(tf.cast(tf.ones_like(A), tf.float32)[:, 0], [-1, 1])
+        ones_2 = tf.reshape(
+            tf.cast(tf.ones_like(A), tf.float32)[:, 0], [-1, 1])
         p2 = tf.transpose(tf.matmul(
             tf.reshape(tf.reduce_sum(tf.square(A), 1), shape=[-1, 1]),
             ones_2,
@@ -181,7 +191,8 @@ def predict_using_tensors(embedding, labels, num=None):
     embeddings using tensors
     """
 
-    # Fitting the main diagonal with infs (removing comparisons with the same sample)
+    # Fitting the main diagonal with infs (removing comparisons with the same
+    # sample)
     inf = tf.cast(tf.ones_like(labels), tf.float32) * numpy.inf
 
     distances = cdist(embedding)
@@ -192,30 +203,34 @@ def predict_using_tensors(embedding, labels, num=None):
 
 def compute_embedding_accuracy_tensors(embedding, labels, num=None):
     """
-    Compute the accuracy through exhaustive comparisons between the embeddings using tensors
+    Compute the accuracy through exhaustive comparisons between the embeddings
+    using tensors
     """
 
-    # Fitting the main diagonal with infs (removing comparisons with the same sample)
+    # Fitting the main diagonal with infs (removing comparisons with the same
+    # sample)
     predictions = predict_using_tensors(embedding, labels, num=num)
-    matching = [tf.equal(p, l) for p, l in zip(tf.unstack(predictions, num=num), tf.unstack(labels, num=num))]
+    matching = [tf.equal(p, l) for p, l in zip(tf.unstack(
+        predictions, num=num), tf.unstack(labels, num=num))]
 
-    return tf.reduce_sum(tf.cast(matching, tf.uint8))/len(predictions)
+    return tf.reduce_sum(tf.cast(matching, tf.uint8)) / len(predictions)
 
 
 def compute_embedding_accuracy(embedding, labels):
     """
-    Compute the accuracy through exhaustive comparisons between the embeddings 
+    Compute the accuracy through exhaustive comparisons between the embeddings
     """
 
     from scipy.spatial.distance import cdist
-    
+
     distances = cdist(embedding, embedding)
-    
+
     n_samples = embedding.shape[0]
 
-    # Fitting the main diagonal with infs (removing comparisons with the same sample)
+    # Fitting the main diagonal with infs (removing comparisons with the same
+    # sample)
     numpy.fill_diagonal(distances, numpy.inf)
-    
+
     indexes = distances.argmin(axis=1)
 
     # Computing the argmin excluding comparisons with the same samples
@@ -226,8 +241,84 @@ def compute_embedding_accuracy(embedding, labels):
     # Getting the original positions of the indexes in the 1-axis
     #corrected_indexes = [ i if i<j else i+1 for i, j in zip(valid_indexes, range(n_samples))]
 
-    matching = [ labels[i]==labels[j] for i,j in zip(range(n_samples), indexes)]    
-    accuracy = sum(matching)/float(n_samples)
-    
+    matching = [labels[i] == labels[j]
+                for i, j in zip(range(n_samples), indexes)]
+    accuracy = sum(matching) / float(n_samples)
+
     return accuracy
 
+
+def get_available_gpus():
+    """Returns the number of GPU devices that are available.
+
+    Returns
+    -------
+    [str]
+        The names of available GPU devices.
+    """
+    local_device_protos = device_lib.list_local_devices()
+    return [x.name for x in local_device_protos if x.device_type == 'GPU']
+
+
+def to_channels_last(image):
+    """Converts the image to channel_last format. This is the same format as in
+    matplotlib, skimage, and etc.
+
+    Parameters
+    ----------
+    image : :any:`tf.Tensor`
+        At least a 3 dimensional image. If the dimension is more than 3, the
+        last 3 dimensions are assumed to be [C, H, W].
+
+    Returns
+    -------
+    image : :any:`tf.Tensor`
+        The image in [..., H, W, C] format.
+
+    Raises
+    ------
+    ValueError
+        If dim of image is less than 3.
+    """
+    ndim = len(image.shape)
+    if ndim < 3:
+        raise ValueError("The image needs to be at least 3 dimensional but it "
+                         "was {}".format(ndim))
+    axis_order = [1, 2, 0]
+    shift = ndim - 3
+    axis_order = list(range(ndim - 3)) + [n + shift for n in axis_order]
+    return tf.transpose(image, axis_order)
+
+
+def to_channels_first(image):
+    """Converts the image to channel_first format. This is the same format as
+    in bob.io.image and bob.io.video.
+
+    Parameters
+    ----------
+    image : :any:`tf.Tensor`
+        At least a 3 dimensional image. If the dimension is more than 3, the
+        last 3 dimensions are assumed to be [H, W, C].
+
+    Returns
+    -------
+    image : :any:`tf.Tensor`
+        The image in [..., C, H, W] format.
+
+    Raises
+    ------
+    ValueError
+        If dim of image is less than 3.
+    """
+    ndim = len(image.shape)
+    if ndim < 3:
+        raise ValueError("The image needs to be at least 3 dimensional but it "
+                         "was {}".format(ndim))
+    axis_order = [2, 0, 1]
+    shift = ndim - 3
+    axis_order = list(range(ndim - 3)) + [n + shift for n in axis_order]
+    return tf.transpose(image, axis_order)
+
+
+to_skimage = to_matplotlib = to_channels_last
+to_bob = to_channels_first