diff --git a/bob/learn/tensorflow/data/__init__.py b/bob/learn/tensorflow/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d31417512b35396459e77db40c3448d4372a34e
--- /dev/null
+++ b/bob/learn/tensorflow/data/__init__.py
@@ -0,0 +1,23 @@
+from .generator import Generator, dataset_using_generator
+from .tfrecords import dataset_to_tfrecord, dataset_from_tfrecord, TFRECORDS_EXT
+
+# gets sphinx autodoc done right - don't remove it
+def __appropriate__(*args):
+    """Says object was actually declared here, an not on the import module.
+
+    Parameters:
+
+      *args: An iterable of objects to modify
+
+    Resolves `Sphinx referencing issues
+    <https://github.com/sphinx-doc/sphinx/issues/3048>`
+    """
+
+    for obj in args:
+        obj.__module__ = __name__
+
+
+__appropriate__(
+    Generator,
+)
+__all__ = [_ for _ in dir() if not _.startswith("_")]
diff --git a/bob/learn/tensorflow/dataset/generator.py b/bob/learn/tensorflow/data/generator.py
similarity index 100%
rename from bob/learn/tensorflow/dataset/generator.py
rename to bob/learn/tensorflow/data/generator.py
diff --git a/bob/learn/tensorflow/dataset/tfrecords.py b/bob/learn/tensorflow/data/tfrecords.py
similarity index 62%
rename from bob/learn/tensorflow/dataset/tfrecords.py
rename to bob/learn/tensorflow/data/tfrecords.py
index 1000f273906cefe234d7b6a050ac2299363e1536..bacf49cf8c10bbecd8f613a52033c104c212290a 100644
--- a/bob/learn/tensorflow/dataset/tfrecords.py
+++ b/bob/learn/tensorflow/data/tfrecords.py
@@ -5,15 +5,10 @@ from __future__ import division
 from __future__ import print_function
 
 import json
-import logging
-import os
-from functools import partial
 
 import tensorflow as tf
 
-from . import DEFAULT_FEATURE
 
-logger = logging.getLogger(__name__)
 TFRECORDS_EXT = ".tfrecords"
 
 
@@ -140,89 +135,3 @@ def dataset_from_tfrecord(tfrecord, num_parallel_reads=None):
         return tf.nest.pack_sequence_as(meta["output_types"], args)
 
     return raw_dataset.map(_parse_function)
-
-
-# def write_a_sample(writer, data, label, key, feature=None, size_estimate=False):
-#     if feature is None:
-#         feature = {
-#             "data": bytes_feature(data.tostring()),
-#             "label": int64_feature(label),
-#             "key": bytes_feature(key),
-#         }
-
-#     example = tf.train.Example(features=tf.train.Features(feature=feature))
-#     example = example.SerializeToString()
-#     if not size_estimate:
-#         writer.write(example)
-#     return sys.getsizeof(example)
-
-
-# def example_parser(serialized_example, feature, data_shape, data_type):
-#     """
-#     Parses a single tf.Example into image and label tensors.
-
-#     """
-#     # Decode the record read by the reader
-#     features = tf.io.parse_single_example(
-#         serialized=serialized_example, features=feature
-#     )
-#     # Convert the image data from string back to the numbers
-#     image = tf.io.decode_raw(features["data"], data_type)
-#     # Cast label data into int64
-#     label = tf.cast(features["label"], tf.int64)
-#     # Reshape image data into the original shape
-#     image = tf.reshape(image, data_shape)
-#     key = tf.cast(features["key"], tf.string)
-#     return image, label, key
-
-
-# def image_augmentation_parser(
-#     serialized_example,
-#     feature,
-#     data_shape,
-#     data_type,
-#     gray_scale=False,
-#     output_shape=None,
-#     random_flip=False,
-#     random_brightness=False,
-#     random_contrast=False,
-#     random_saturation=False,
-#     random_rotate=False,
-#     per_image_normalization=True,
-#     random_gamma=False,
-#     random_crop=False,
-# ):
-#     """
-#     Parses a single tf.Example into image and label tensors.
-
-#     """
-#     # Decode the record read by the reader
-#     features = tf.io.parse_single_example(
-#         serialized=serialized_example, features=feature
-#     )
-#     # Convert the image data from string back to the numbers
-#     image = tf.io.decode_raw(features["data"], data_type)
-
-#     # Reshape image data into the original shape
-#     image = tf.reshape(image, data_shape)
-
-#     # Applying image augmentation
-#     image = append_image_augmentation(
-#         image,
-#         gray_scale=gray_scale,
-#         output_shape=output_shape,
-#         random_flip=random_flip,
-#         random_brightness=random_brightness,
-#         random_contrast=random_contrast,
-#         random_saturation=random_saturation,
-#         random_rotate=random_rotate,
-#         per_image_normalization=per_image_normalization,
-#         random_gamma=random_gamma,
-#         random_crop=random_crop,
-#     )
-
-#     # Cast label data into int64
-#     label = tf.cast(features["label"], tf.int64)
-#     key = tf.cast(features["key"], tf.string)
-
-#     return image, label, key
diff --git a/bob/learn/tensorflow/dataset/__init__.py b/bob/learn/tensorflow/dataset/__init__.py
deleted file mode 100644
index 0612be7aa5ceab6281d2a5975b4617eca041feb8..0000000000000000000000000000000000000000
--- a/bob/learn/tensorflow/dataset/__init__.py
+++ /dev/null
@@ -1,421 +0,0 @@
-import os
-
-import numpy
-import tensorflow as tf
-
-import bob.io.base
-
-DEFAULT_FEATURE = {
-    "data": tf.io.FixedLenFeature([], tf.string),
-    "label": tf.io.FixedLenFeature([], tf.int64),
-    "key": tf.io.FixedLenFeature([], tf.string),
-}
-
-
-def from_hdf5file_to_tensor(filename):
-    import bob.io.image
-
-    data = bob.io.image.to_matplotlib(bob.io.base.load(filename))
-
-    # reshaping to ndim == 3
-    if data.ndim == 2:
-        data = numpy.reshape(data, (data.shape[0], data.shape[1], 1))
-    data = data.astype("float32")
-
-    return data
-
-
-def from_filename_to_tensor(filename, extension=None):
-    """
-    Read a file and it convert it to tensor.
-
-    If the file extension is something that tensorflow understands (.jpg, .bmp, .tif,...),
-    it uses the `tf.image.decode_image` otherwise it uses `bob.io.base.load`
-    """
-
-    if extension == "hdf5":
-        return tf.compat.v1.py_func(from_hdf5file_to_tensor, [filename], [tf.float32])
-    else:
-        return tf.cast(tf.image.decode_image(tf.io.read_file(filename)), tf.float32)
-
-
-# def append_image_augmentation(
-#     image,
-#     gray_scale=False,
-#     output_shape=None,
-#     random_flip=False,
-#     random_brightness=False,
-#     random_contrast=False,
-#     random_saturation=False,
-#     random_rotate=False,
-#     per_image_normalization=True,
-#     random_gamma=False,
-#     random_crop=False,
-# ):
-#     """
-#     Append to the current tensor some random image augmentation operation
-
-#     **Parameters**
-#        gray_scale:
-#           Convert to gray scale?
-
-#        output_shape:
-#           If set, will randomly crop the image given the output shape
-
-#        random_flip:
-#           Randomly flip an image horizontally  (https://www.tensorflow.org/api_docs/python/tf/image/random_flip_left_right)
-
-#        random_brightness:
-#            Adjust the brightness of an RGB image by a random factor (https://www.tensorflow.org/api_docs/python/tf/image/random_brightness)
-
-#        random_contrast:
-#            Adjust the contrast of an RGB image by a random factor (https://www.tensorflow.org/api_docs/python/tf/image/random_contrast)
-
-#        random_saturation:
-#            Adjust the saturation of an RGB image by a random factor (https://www.tensorflow.org/api_docs/python/tf/image/random_saturation)
-
-#        random_rotate:
-#            Randomly rotate face images between -5 and 5 degrees
-
-#        per_image_normalization:
-#            Linearly scales image to have zero mean and unit norm.
-
-#     """
-
-#     # Changing the range from 0-255 to 0-1
-#     image = tf.cast(image, tf.float32) / 255
-#     # FORCING A SEED FOR THE RANDOM OPERATIONS
-#     tf.compat.v1.set_random_seed(0)
-
-#     if output_shape is not None:
-#         assert len(output_shape) == 2
-#         if random_crop:
-#             image = tf.image.random_crop(image, size=list(output_shape) + [3])
-#         else:
-#             image = tf.image.resize_with_crop_or_pad(
-#                 image, output_shape[0], output_shape[1]
-#             )
-
-#     if random_flip:
-#         image = tf.image.random_flip_left_right(image)
-
-#     if random_brightness:
-#         image = tf.image.random_brightness(image, max_delta=0.15)
-#         image = tf.clip_by_value(image, 0, 1)
-
-#     if random_contrast:
-#         image = tf.image.random_contrast(image, lower=0.85, upper=1.15)
-#         image = tf.clip_by_value(image, 0, 1)
-
-#     if random_saturation:
-#         image = tf.image.random_saturation(image, lower=0.85, upper=1.15)
-#         image = tf.clip_by_value(image, 0, 1)
-
-#     if random_gamma:
-#         image = tf.image.adjust_gamma(
-#             image, gamma=tf.random.uniform(shape=[], minval=0.85, maxval=1.15)
-#         )
-#         image = tf.clip_by_value(image, 0, 1)
-
-#     if random_rotate:
-#         # from https://stackoverflow.com/a/53855704/1286165
-#         degree = 0.08726646259971647  # math.pi * 5 /180
-#         random_angles = tf.random.uniform(shape=(1,), minval=-degree, maxval=degree)
-#         image = tf.contrib.image.transform(
-#             image,
-#             tf.contrib.image.angles_to_projective_transforms(
-#                 random_angles,
-#                 tf.cast(tf.shape(input=image)[-3], tf.float32),
-#                 tf.cast(tf.shape(input=image)[-2], tf.float32),
-#             ),
-#         )
-
-#     if gray_scale:
-#         image = tf.image.rgb_to_grayscale(image, name="rgb_to_gray")
-
-#     # normalizing data
-#     if per_image_normalization:
-#         image = tf.image.per_image_standardization(image)
-
-#     return image
-
-
-def arrange_indexes_by_label(input_labels, possible_labels):
-
-    # Shuffling all the indexes
-    indexes_per_labels = dict()
-    for l in possible_labels:
-        indexes_per_labels[l] = numpy.where(input_labels == l)[0]
-        numpy.random.shuffle(indexes_per_labels[l])
-    return indexes_per_labels
-
-
-def triplets_random_generator(input_data, input_labels):
-    """
-    Giving a list of samples and a list of labels, it dumps a series of
-    triplets for triple nets.
-
-    **Parameters**
-
-      input_data: List of whatever representing the data samples
-
-      input_labels: List of the labels (needs to be in EXACT same order as input_data)
-    """
-    anchor = []
-    positive = []
-    negative = []
-
-    def append(anchor_sample, positive_sample, negative_sample):
-        """
-        Just appending one element in each list
-        """
-        anchor.append(anchor_sample)
-        positive.append(positive_sample)
-        negative.append(negative_sample)
-
-    possible_labels = list(set(input_labels))
-    input_data = numpy.array(input_data)
-    input_labels = numpy.array(input_labels)
-    total_samples = input_data.shape[0]
-
-    indexes_per_labels = arrange_indexes_by_label(input_labels, possible_labels)
-
-    # searching for random triplets
-    offset_class = 0
-    for i in range(total_samples):
-
-        anchor_sample = input_data[
-            indexes_per_labels[possible_labels[offset_class]][
-                numpy.random.randint(
-                    len(indexes_per_labels[possible_labels[offset_class]])
-                )
-            ],
-            ...,
-        ]
-
-        positive_sample = input_data[
-            indexes_per_labels[possible_labels[offset_class]][
-                numpy.random.randint(
-                    len(indexes_per_labels[possible_labels[offset_class]])
-                )
-            ],
-            ...,
-        ]
-
-        # Changing the class
-        offset_class += 1
-
-        if offset_class == len(possible_labels):
-            offset_class = 0
-
-        negative_sample = input_data[
-            indexes_per_labels[possible_labels[offset_class]][
-                numpy.random.randint(
-                    len(indexes_per_labels[possible_labels[offset_class]])
-                )
-            ],
-            ...,
-        ]
-
-        append(str(anchor_sample), str(positive_sample), str(negative_sample))
-        # yield anchor, positive, negative
-    return anchor, positive, negative
-
-
-def siamease_pairs_generator(input_data, input_labels):
-    """
-    Giving a list of samples and a list of labels, it dumps a series of
-    pairs for siamese nets.
-
-    **Parameters**
-
-      input_data: List of whatever representing the data samples
-
-      input_labels: List of the labels (needs to be in EXACT same order as input_data)
-    """
-
-    # Lists that will be returned
-    left_data = []
-    right_data = []
-    labels = []
-
-    def append(left, right, label):
-        """
-        Just appending one element in each list
-        """
-        left_data.append(left)
-        right_data.append(right)
-        labels.append(label)
-
-    possible_labels = list(set(input_labels))
-    input_data = numpy.array(input_data)
-    input_labels = numpy.array(input_labels)
-    total_samples = input_data.shape[0]
-
-    # Filtering the samples by label and shuffling all the indexes
-    # indexes_per_labels = dict()
-    # for l in possible_labels:
-    #    indexes_per_labels[l] = numpy.where(input_labels == l)[0]
-    #    numpy.random.shuffle(indexes_per_labels[l])
-    indexes_per_labels = arrange_indexes_by_label(input_labels, possible_labels)
-
-    left_possible_indexes = numpy.random.choice(
-        possible_labels, total_samples, replace=True
-    )
-    right_possible_indexes = numpy.random.choice(
-        possible_labels, total_samples, replace=True
-    )
-
-    genuine = True
-    for i in range(total_samples):
-
-        if genuine:
-            # Selecting the class
-            class_index = left_possible_indexes[i]
-
-            # Now selecting the samples for the pair
-            left = input_data[
-                indexes_per_labels[class_index][
-                    numpy.random.randint(len(indexes_per_labels[class_index]))
-                ]
-            ]
-            right = input_data[
-                indexes_per_labels[class_index][
-                    numpy.random.randint(len(indexes_per_labels[class_index]))
-                ]
-            ]
-            append(left, right, 0)
-            # yield left, right, 0
-        else:
-            # Selecting the 2 classes
-            class_index = list()
-            class_index.append(left_possible_indexes[i])
-
-            # Finding the right pair
-            j = i
-            # TODO: Lame solution. Fix this
-            while (
-                j < total_samples
-            ):  # Here is an unidiretinal search for the negative pair
-                if left_possible_indexes[i] != right_possible_indexes[j]:
-                    class_index.append(right_possible_indexes[j])
-                    break
-                j += 1
-
-            if j < total_samples:
-                # Now selecting the samples for the pair
-                left = input_data[
-                    indexes_per_labels[class_index[0]][
-                        numpy.random.randint(len(indexes_per_labels[class_index[0]]))
-                    ]
-                ]
-                right = input_data[
-                    indexes_per_labels[class_index[1]][
-                        numpy.random.randint(len(indexes_per_labels[class_index[1]]))
-                    ]
-                ]
-                append(left, right, 1)
-
-        genuine = not genuine
-    return left_data, right_data, labels
-
-
-def blocks_tensorflow(images, block_size):
-    """Return all non-overlapping blocks of an image using tensorflow
-    operations.
-
-    Parameters
-    ----------
-    images : `tf.Tensor`
-        The input color images. It is assumed that the image has a shape of
-        [?, H, W, C].
-    block_size : (int, int)
-        A tuple of two integers indicating the block size.
-
-    Returns
-    -------
-    blocks : `tf.Tensor`
-        All the blocks in the batch dimension. The output will be of
-        size [?, block_size[0], block_size[1], C].
-    n_blocks : int
-        The number of blocks that was obtained per image.
-    """
-    # normalize block_size
-    block_size = [1] + list(block_size) + [1]
-    output_size = list(block_size)
-    output_size[0] = -1
-    output_size[-1] = images.shape[-1]
-    blocks = tf.image.extract_patches(
-        images, block_size, block_size, [1, 1, 1, 1], "VALID"
-    )
-    n_blocks = int(numpy.prod(blocks.shape[1:3]))
-    output = tf.reshape(blocks, output_size)
-    return output, n_blocks
-
-
-def tf_repeat(tensor, repeats):
-    """
-    Parameters
-    ----------
-    tensor
-        A Tensor. 1-D or higher.
-    repeats
-        A list. Number of repeat for each dimension, length must be the same as
-        the number of dimensions in input
-
-    Returns
-    -------
-    A Tensor. Has the same type as input. Has the shape of tensor.shape *
-    repeats
-    """
-    with tf.compat.v1.variable_scope("repeat"):
-        expanded_tensor = tf.expand_dims(tensor, -1)
-        multiples = [1] + repeats
-        tiled_tensor = tf.tile(expanded_tensor, multiples=multiples)
-        repeated_tesnor = tf.reshape(tiled_tensor, tf.shape(input=tensor) * repeats)
-    return repeated_tesnor
-
-
-def all_patches(image, label, key, size):
-    """Extracts all patches of an image
-
-    Parameters
-    ----------
-    image:
-        The image should be channels_last format and already batched.
-
-    label:
-        The label for the image
-
-    key:
-        The key for the image
-
-    size: (int, int)
-        The height and width of the blocks.
-
-    Returns
-    -------
-    blocks:
-       The non-overlapping blocks of size from image and labels and keys are
-       repeated.
-
-    label:
-
-    key:
-    """
-    blocks, n_blocks = blocks_tensorflow(image, size)
-
-    # duplicate label and key as n_blocks
-    def repeats(shape):
-        r = shape.as_list()
-        for i in range(len(r)):
-            if i == 0:
-                r[i] = n_blocks
-            else:
-                r[i] = 1
-        return r
-
-    label = tf_repeat(label, repeats(label.shape))
-    key = tf_repeat(key, repeats(key.shape))
-
-    return blocks, label, key
diff --git a/bob/learn/tensorflow/dataset/image.py b/bob/learn/tensorflow/dataset/image.py
deleted file mode 100644
index d805a44fc4992f6517da98590f9abf829fcbdaee..0000000000000000000000000000000000000000
--- a/bob/learn/tensorflow/dataset/image.py
+++ /dev/null
@@ -1,233 +0,0 @@
-#!/usr/bin/env python
-# vim: set fileencoding=utf-8 :
-# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
-
-from functools import partial
-
-import tensorflow as tf
-
-from . import append_image_augmentation
-from . import from_filename_to_tensor
-
-
-def shuffle_data_and_labels_image_augmentation(
-    filenames,
-    labels,
-    data_shape,
-    data_type,
-    batch_size,
-    epochs=None,
-    buffer_size=10 ** 3,
-    gray_scale=False,
-    output_shape=None,
-    random_flip=False,
-    random_brightness=False,
-    random_contrast=False,
-    random_saturation=False,
-    random_rotate=False,
-    per_image_normalization=True,
-    extension=None,
-):
-    """
-    Dump random batches from a list of image paths and labels:
-
-    The list of files and labels should be in the same order e.g.
-    filenames = ['class_1_img1', 'class_1_img2', 'class_2_img1']
-    labels = [0, 0, 1]
-
-    **Parameters**
-
-       filenames:
-          List containing the path of the images
-
-       labels:
-          List containing the labels (needs to be in EXACT same order as filenames)
-
-       data_shape:
-          Samples shape saved in the tf-record
-
-       data_type:
-          tf data type(https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types)
-
-       batch_size:
-          Size of the batch
-
-       epochs:
-           Number of epochs to be batched
-
-       buffer_size:
-            Size of the shuffle bucket
-
-       gray_scale:
-          Convert to gray scale?
-
-       output_shape:
-          If set, will randomly crop the image given the output shape
-
-       random_flip:
-          Randomly flip an image horizontally  (https://www.tensorflow.org/api_docs/python/tf/image/random_flip_left_right)
-
-       random_brightness:
-           Adjust the brightness of an RGB image by a random factor (https://www.tensorflow.org/api_docs/python/tf/image/random_brightness)
-
-       random_contrast:
-           Adjust the contrast of an RGB image by a random factor (https://www.tensorflow.org/api_docs/python/tf/image/random_contrast)
-
-       random_saturation:
-           Adjust the saturation of an RGB image by a random factor (https://www.tensorflow.org/api_docs/python/tf/image/random_saturation)
-
-       random_rotate:
-           Randomly rotate face images between -5 and 5 degrees
-
-       per_image_normalization:
-           Linearly scales image to have zero mean and unit norm.
-
-       extension:
-           If None, will load files using `tf.image.decode..` if set to `hdf5`, will load with `bob.io.base.load`
-
-    """
-
-    dataset = create_dataset_from_path_augmentation(
-        filenames,
-        labels,
-        data_shape,
-        data_type,
-        gray_scale=gray_scale,
-        output_shape=output_shape,
-        random_flip=random_flip,
-        random_brightness=random_brightness,
-        random_contrast=random_contrast,
-        random_saturation=random_saturation,
-        random_rotate=random_rotate,
-        per_image_normalization=per_image_normalization,
-        extension=extension,
-    )
-
-    dataset = dataset.shuffle(buffer_size).batch(batch_size).repeat(epochs)
-
-    data, labels = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()
-    return data, labels
-
-
-def create_dataset_from_path_augmentation(
-    filenames,
-    labels,
-    data_shape,
-    data_type,
-    gray_scale=False,
-    output_shape=None,
-    random_flip=False,
-    random_brightness=False,
-    random_contrast=False,
-    random_saturation=False,
-    random_rotate=False,
-    per_image_normalization=True,
-    extension=None,
-):
-    """
-    Create dataset from a list of tf-record files
-
-    **Parameters**
-
-       filenames:
-          List containing the path of the images
-
-       labels:
-          List containing the labels (needs to be in EXACT same order as filenames)
-
-       data_shape:
-          Samples shape saved in the tf-record
-
-       data_type:
-          tf data type(https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types)
-
-       feature:
-
-    """
-
-    parser = partial(
-        image_augmentation_parser,
-        data_shape=data_shape,
-        data_type=data_type,
-        gray_scale=gray_scale,
-        output_shape=output_shape,
-        random_flip=random_flip,
-        random_brightness=random_brightness,
-        random_contrast=random_contrast,
-        random_saturation=random_saturation,
-        random_rotate=random_rotate,
-        per_image_normalization=per_image_normalization,
-        extension=extension,
-    )
-
-    dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
-    dataset = dataset.map(parser)
-    return dataset
-
-
-def image_augmentation_parser(
-    filename,
-    label,
-    data_shape,
-    data_type,
-    gray_scale=False,
-    output_shape=None,
-    random_flip=False,
-    random_brightness=False,
-    random_contrast=False,
-    random_saturation=False,
-    random_rotate=False,
-    per_image_normalization=True,
-    extension=None,
-):
-    """
-    Parses a single tf.Example into image and label tensors.
-    """
-
-    # Convert the image data from string back to the numbers
-    image = from_filename_to_tensor(filename, extension=extension)
-
-    # Reshape image data into the original shape
-    image = tf.reshape(image, data_shape)
-
-    # Applying image augmentation
-    image = append_image_augmentation(
-        image,
-        gray_scale=gray_scale,
-        output_shape=output_shape,
-        random_flip=random_flip,
-        random_brightness=random_brightness,
-        random_contrast=random_contrast,
-        random_saturation=random_saturation,
-        random_rotate=random_rotate,
-        per_image_normalization=per_image_normalization,
-    )
-
-    label = tf.cast(label, tf.int64)
-    features = dict()
-    features["data"] = image
-    features["key"] = filename
-
-    return features, label
-
-
-def load_pngs(img_path, img_shape):
-    """Read png files using tensorflow API
-    You must know the shape of the image beforehand to use this function.
-
-    Parameters
-    ----------
-    img_path : str
-        Path to the image
-    img_shape : list
-        A list or tuple that contains image's shape in channels_last format
-
-    Returns
-    -------
-    object
-        The loaded png file
-    """
-    img_raw = tf.io.read_file(img_path)
-    img_tensor = tf.image.decode_png(img_raw, channels=img_shape[-1])
-    img_final = tf.reshape(img_tensor, img_shape)
-    return img_final
diff --git a/bob/learn/tensorflow/image/filter.py b/bob/learn/tensorflow/image/filter.py
deleted file mode 100644
index f77fbcd18896c070e10af8daff020a2e34337bab..0000000000000000000000000000000000000000
--- a/bob/learn/tensorflow/image/filter.py
+++ /dev/null
@@ -1,38 +0,0 @@
-import tensorflow as tf
-
-
-def gaussian_kernel(size: int, mean: float, std: float):
-    """Makes 2D gaussian Kernel for convolution.
-    Code adapted from: https://stackoverflow.com/a/52012658/1286165"""
-
-    d = tf.compat.v1.distributions.Normal(mean, std)
-
-    vals = d.prob(tf.range(start=-size, limit=size + 1, dtype=tf.float32))
-
-    gauss_kernel = tf.einsum("i,j->ij", vals, vals)
-
-    return gauss_kernel / tf.reduce_sum(input_tensor=gauss_kernel)
-
-
-class GaussianFilter:
-    """A class for blurring images"""
-
-    def __init__(self, size=13, mean=0.0, std=3.0, **kwargs):
-        super().__init__(**kwargs)
-        self.size = size
-        self.mean = mean
-        self.std = std
-        self.gauss_kernel = gaussian_kernel(size, mean, std)[:, :, None, None]
-
-    def __call__(self, image):
-        shape = tf.shape(input=image)
-        image = tf.reshape(image, [-1, shape[-3], shape[-2], shape[-1]])
-        input_channels = shape[-1]
-        gauss_kernel = tf.tile(self.gauss_kernel, [1, 1, input_channels, 1])
-        return tf.nn.depthwise_conv2d(
-            input=image,
-            filter=gauss_kernel,
-            strides=[1, 1, 1, 1],
-            padding="SAME",
-            data_format="NHWC",
-        )
diff --git a/bob/learn/tensorflow/losses/BaseLoss.py b/bob/learn/tensorflow/losses/BaseLoss.py
deleted file mode 100644
index f048b88a704a9b3eeebcabaa2df3a33404e757ba..0000000000000000000000000000000000000000
--- a/bob/learn/tensorflow/losses/BaseLoss.py
+++ /dev/null
@@ -1,106 +0,0 @@
-#!/usr/bin/env python
-# vim: set fileencoding=utf-8 :
-# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
-
-import logging
-
-import tensorflow as tf
-
-logger = logging.getLogger(__name__)
-
-
-# def mean_cross_entropy_loss(logits, labels, add_regularization_losses=True):
-#     """
-#     Simple CrossEntropy loss.
-#     Basically it wrapps the function tf.nn.sparse_softmax_cross_entropy_with_logits.
-
-#     **Parameters**
-#       logits:
-#       labels:
-#       add_regularization_losses: Regulize the loss???
-
-#     """
-
-#     with tf.compat.v1.variable_scope('cross_entropy_loss'):
-#         cross_loss = tf.reduce_mean(
-#             input_tensor=tf.nn.sparse_softmax_cross_entropy_with_logits(
-#                 logits=logits, labels=labels),
-#             name="cross_entropy_loss")
-
-#         tf.compat.v1.summary.scalar('cross_entropy_loss', cross_loss)
-#         tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.LOSSES, cross_loss)
-
-#         if add_regularization_losses:
-#             regularization_losses = tf.compat.v1.get_collection(
-#                 tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES)
-
-#             total_loss = tf.add_n(
-#                 [cross_loss] + regularization_losses, name="total_loss")
-#             return total_loss
-#         else:
-#             return cross_loss
-
-
-def mean_cross_entropy_center_loss(
-    logits, prelogits, labels, n_classes, alpha=0.9, factor=0.01
-):
-    """
-    Implementation of the CrossEntropy + Center Loss from the paper
-    "A Discriminative Feature Learning Approach for Deep Face Recognition"(http://ydwen.github.io/papers/WenECCV16.pdf)
-
-    **Parameters**
-      logits:
-      prelogits:
-      labels:
-      n_classes: Number of classes of your task
-      alpha: Alpha factor ((1-alpha)*centers-prelogits)
-      factor: Weight factor of the center loss
-
-    """
-    # Cross entropy
-    with tf.compat.v1.variable_scope("cross_entropy_loss"):
-        cross_loss = tf.reduce_mean(
-            input_tensor=tf.nn.sparse_softmax_cross_entropy_with_logits(
-                logits=logits, labels=labels
-            ),
-            name="cross_entropy_loss",
-        )
-        tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.LOSSES, cross_loss)
-        tf.compat.v1.summary.scalar("loss_cross_entropy", cross_loss)
-
-    # Appending center loss
-    with tf.compat.v1.variable_scope("center_loss"):
-        n_features = prelogits.get_shape()[1]
-
-        centers = tf.compat.v1.get_variable(
-            "centers",
-            [n_classes, n_features],
-            dtype=tf.float32,
-            initializer=tf.compat.v1.constant_initializer(0),
-            trainable=False,
-        )
-
-        # label = tf.reshape(labels, [-1])
-        centers_batch = tf.gather(centers, labels)
-        diff = (1 - alpha) * (centers_batch - prelogits)
-        centers = tf.compat.v1.scatter_sub(centers, labels, diff)
-        center_loss = tf.reduce_mean(input_tensor=tf.square(prelogits - centers_batch))
-        tf.compat.v1.add_to_collection(
-            tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES, center_loss * factor
-        )
-        tf.compat.v1.summary.scalar("loss_center", center_loss)
-
-    # Adding the regularizers in the loss
-    with tf.compat.v1.variable_scope("total_loss"):
-        regularization_losses = tf.compat.v1.get_collection(
-            tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES
-        )
-        total_loss = tf.add_n([cross_loss] + regularization_losses, name="total_loss")
-        tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.LOSSES, total_loss)
-        tf.compat.v1.summary.scalar("loss_total", total_loss)
-
-    loss = dict()
-    loss["loss"] = total_loss
-    loss["centers"] = centers
-
-    return loss
diff --git a/bob/learn/tensorflow/losses/ContrastiveLoss.py b/bob/learn/tensorflow/losses/ContrastiveLoss.py
deleted file mode 100644
index 416c8e583a560c0855b84eaeebe85e43ae05ce26..0000000000000000000000000000000000000000
--- a/bob/learn/tensorflow/losses/ContrastiveLoss.py
+++ /dev/null
@@ -1,80 +0,0 @@
-#!/usr/bin/env python
-# vim: set fileencoding=utf-8 :
-# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
-
-import logging
-
-import tensorflow as tf
-
-from bob.learn.tensorflow.utils import compute_euclidean_distance
-
-logger = logging.getLogger(__name__)
-
-
-def contrastive_loss(left_embedding, right_embedding, labels, contrastive_margin=2.0):
-    """
-    Compute the contrastive loss as in
-
-    http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
-
-    :math:`L = 0.5 * (1-Y) * D^2 + 0.5 * (Y) * {max(0, margin - D)}^2`
-
-    where, `0` are assign for pairs from the same class and `1` from pairs from different classes.
-
-
-    **Parameters**
-
-    left_feature:
-      First element of the pair
-
-    right_feature:
-      Second element of the pair
-
-    labels:
-      Label of the pair (0 or 1)
-
-    margin:
-      Contrastive margin
-
-    """
-
-    with tf.compat.v1.name_scope("contrastive_loss"):
-        labels = tf.cast(labels, dtype=tf.float32)
-
-        left_embedding = tf.nn.l2_normalize(left_embedding, 1)
-        right_embedding = tf.nn.l2_normalize(right_embedding, 1)
-
-        d = compute_euclidean_distance(left_embedding, right_embedding)
-
-        with tf.compat.v1.name_scope("within_class"):
-            one = tf.constant(1.0)
-            within_class = tf.multiply(one - labels, tf.square(d))  # (1-Y)*(d^2)
-            within_class_loss = tf.reduce_mean(
-                input_tensor=within_class, name="within_class"
-            )
-            tf.compat.v1.add_to_collection(
-                tf.compat.v1.GraphKeys.LOSSES, within_class_loss
-            )
-
-        with tf.compat.v1.name_scope("between_class"):
-            max_part = tf.square(tf.maximum(contrastive_margin - d, 0))
-            between_class = tf.multiply(
-                labels, max_part
-            )  # (Y) * max((margin - d)^2, 0)
-            between_class_loss = tf.reduce_mean(
-                input_tensor=between_class, name="between_class"
-            )
-            tf.compat.v1.add_to_collection(
-                tf.compat.v1.GraphKeys.LOSSES, between_class_loss
-            )
-
-        with tf.compat.v1.name_scope("total_loss"):
-            loss = 0.5 * (within_class + between_class)
-            loss = tf.reduce_mean(input_tensor=loss, name="contrastive_loss")
-            tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.LOSSES, loss)
-
-        tf.compat.v1.summary.scalar("contrastive_loss", loss)
-        tf.compat.v1.summary.scalar("between_class", between_class_loss)
-        tf.compat.v1.summary.scalar("within_class", within_class_loss)
-
-        return loss
diff --git a/bob/learn/tensorflow/losses/StyleLoss.py b/bob/learn/tensorflow/losses/StyleLoss.py
deleted file mode 100644
index 46028aa28517cb282885dbaed8b821b36aba75ac..0000000000000000000000000000000000000000
--- a/bob/learn/tensorflow/losses/StyleLoss.py
+++ /dev/null
@@ -1,106 +0,0 @@
-#!/usr/bin/env python
-# vim: set fileencoding=utf-8 :
-# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
-
-import functools
-import logging
-
-import tensorflow as tf
-
-logger = logging.getLogger(__name__)
-
-
-def content_loss(noises, content_features):
-    r"""
-
-    Implements the content loss from:
-
-    Gatys, Leon A., Alexander S. Ecker, and Matthias Bethge. "A neural algorithm of artistic style." arXiv preprint arXiv:1508.06576 (2015).
-
-    For a given noise signal :math:`n`, content image :math:`c` and convolved with the DCNN :math:`\phi` until the layer :math:`l` the content loss is defined as:
-
-    :math:`L(n,c) = \sum_{l=?}^{?}({\phi^l(n) - \phi^l(c)})^2`
-
-
-    Parameters
-    ----------
-
-     noises: :any:`list`
-        A list of tf.Tensor containing all the noises convolved
-
-     content_features: :any:`list`
-        A list of numpy.array containing all the content_features convolved
-
-    """
-
-    content_losses = []
-    for n, c in zip(noises, content_features):
-        content_losses.append((2 * tf.nn.l2_loss(n - c) / c.size))
-    return functools.reduce(tf.add, content_losses)
-
-
-def linear_gram_style_loss(noises, gram_style_features):
-    r"""
-
-    Implements the style loss from:
-
-    Gatys, Leon A., Alexander S. Ecker, and Matthias Bethge. "A neural algorithm of artistic style." arXiv preprint arXiv:1508.06576 (2015).
-
-    For a given noise signal :math:`n`, content image :math:`c` and convolved with the DCNN :math:`\phi` until the layer :math:`l` the STYLE loss is defined as
-
-    :math:`L(n,c) = \\sum_{l=?}^{?}\\frac{({\phi^l(n)^T*\\phi^l(n) - \\phi^l(c)^T*\\phi^l(c)})^2}{N*M}`
-
-
-    Parameters
-    ----------
-
-     noises: :any:`list`
-        A list of tf.Tensor containing all the noises convolved
-
-     gram_style_features: :any:`list`
-        A list of numpy.array containing all the content_features convolved
-
-    """
-
-    style_losses = []
-    for n, s in zip(noises, gram_style_features):
-        style_losses.append((2 * tf.nn.l2_loss(n - s)) / s.size)
-
-    return functools.reduce(tf.add, style_losses)
-
-
-def denoising_loss(noise):
-    """
-    Computes the denoising loss as in:
-
-    Gatys, Leon A., Alexander S. Ecker, and Matthias Bethge. "A neural algorithm of artistic style." arXiv preprint arXiv:1508.06576 (2015).
-
-    Parameters
-    ----------
-
-       noise:
-          Input noise
-
-    """
-
-    def _tensor_size(tensor):
-        from operator import mul
-
-        return functools.reduce(mul, (d.value for d in tensor.get_shape()), 1)
-
-    shape = noise.get_shape().as_list()
-
-    noise_y_size = _tensor_size(noise[:, 1:, :, :])
-    noise_x_size = _tensor_size(noise[:, :, 1:, :])
-    denoise_loss = 2 * (
-        (
-            tf.nn.l2_loss(noise[:, 1:, :, :] - noise[:, : shape[1] - 1, :, :])
-            / noise_y_size
-        )
-        + (
-            tf.nn.l2_loss(noise[:, :, 1:, :] - noise[:, :, : shape[2] - 1, :])
-            / noise_x_size
-        )
-    )
-
-    return denoise_loss
diff --git a/bob/learn/tensorflow/losses/TripletLoss.py b/bob/learn/tensorflow/losses/TripletLoss.py
deleted file mode 100644
index eee7529df05dba9b832b84a4081f30c6e4064db6..0000000000000000000000000000000000000000
--- a/bob/learn/tensorflow/losses/TripletLoss.py
+++ /dev/null
@@ -1,199 +0,0 @@
-#!/usr/bin/env python
-# vim: set fileencoding=utf-8 :
-# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
-
-import logging
-
-import tensorflow as tf
-
-from bob.learn.tensorflow.utils import compute_euclidean_distance
-
-logger = logging.getLogger(__name__)
-
-
-def triplet_loss(anchor_embedding, positive_embedding, negative_embedding, margin=5.0):
-    """
-    Compute the triplet loss as in
-
-    Schroff, Florian, Dmitry Kalenichenko, and James Philbin.
-    "Facenet: A unified embedding for face recognition and clustering."
-    Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2015.
-
-    :math:`L  = sum(  |f_a - f_p|^2 - |f_a - f_n|^2  + \lambda)`
-
-    **Parameters**
-
-    left_feature:
-      First element of the pair
-
-    right_feature:
-      Second element of the pair
-
-    label:
-      Label of the pair (0 or 1)
-
-    margin:
-      Contrastive margin
-
-    """
-
-    with tf.compat.v1.name_scope("triplet_loss"):
-        # Normalize
-        anchor_embedding = tf.nn.l2_normalize(anchor_embedding, 1, 1e-10, name="anchor")
-        positive_embedding = tf.nn.l2_normalize(
-            positive_embedding, 1, 1e-10, name="positive"
-        )
-        negative_embedding = tf.nn.l2_normalize(
-            negative_embedding, 1, 1e-10, name="negative"
-        )
-
-        d_positive = tf.reduce_sum(
-            input_tensor=tf.square(tf.subtract(anchor_embedding, positive_embedding)),
-            axis=1,
-        )
-        d_negative = tf.reduce_sum(
-            input_tensor=tf.square(tf.subtract(anchor_embedding, negative_embedding)),
-            axis=1,
-        )
-
-        basic_loss = tf.add(tf.subtract(d_positive, d_negative), margin)
-
-        with tf.compat.v1.name_scope("TripletLoss"):
-            # Between
-            between_class_loss = tf.reduce_mean(input_tensor=d_negative)
-            tf.compat.v1.summary.scalar("loss_between_class", between_class_loss)
-            tf.compat.v1.add_to_collection(
-                tf.compat.v1.GraphKeys.LOSSES, between_class_loss
-            )
-
-            # Within
-            within_class_loss = tf.reduce_mean(input_tensor=d_positive)
-            tf.compat.v1.summary.scalar("loss_within_class", within_class_loss)
-            tf.compat.v1.add_to_collection(
-                tf.compat.v1.GraphKeys.LOSSES, within_class_loss
-            )
-
-            # Total loss
-            loss = tf.reduce_mean(
-                input_tensor=tf.maximum(basic_loss, 0.0), axis=0, name="total_loss"
-            )
-            tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.LOSSES, loss)
-            tf.compat.v1.summary.scalar("loss_triplet", loss)
-
-        return loss
-
-
-def triplet_fisher_loss(anchor_embedding, positive_embedding, negative_embedding):
-
-    with tf.compat.v1.name_scope("triplet_loss"):
-        # Normalize
-        anchor_embedding = tf.nn.l2_normalize(anchor_embedding, 1, 1e-10, name="anchor")
-        positive_embedding = tf.nn.l2_normalize(
-            positive_embedding, 1, 1e-10, name="positive"
-        )
-        negative_embedding = tf.nn.l2_normalize(
-            negative_embedding, 1, 1e-10, name="negative"
-        )
-
-        average_class = tf.reduce_mean(input_tensor=anchor_embedding, axis=0)
-        average_total = tf.compat.v1.div(
-            tf.add(
-                tf.reduce_mean(input_tensor=anchor_embedding, axis=0),
-                tf.reduce_mean(input_tensor=negative_embedding, axis=0),
-            ),
-            2,
-        )
-
-        length = anchor_embedding.get_shape().as_list()[0]
-        dim = anchor_embedding.get_shape().as_list()[1]
-        split_positive = tf.unstack(positive_embedding, num=length, axis=0)
-        split_negative = tf.unstack(negative_embedding, num=length, axis=0)
-
-        Sw = None
-        Sb = None
-        for s in zip(split_positive, split_negative):
-            positive = s[0]
-            negative = s[1]
-
-            buffer_sw = tf.reshape(tf.subtract(positive, average_class), shape=(dim, 1))
-            buffer_sw = tf.matmul(buffer_sw, tf.reshape(buffer_sw, shape=(1, dim)))
-
-            buffer_sb = tf.reshape(tf.subtract(negative, average_total), shape=(dim, 1))
-            buffer_sb = tf.matmul(buffer_sb, tf.reshape(buffer_sb, shape=(1, dim)))
-
-            if Sw is None:
-                Sw = buffer_sw
-                Sb = buffer_sb
-            else:
-                Sw = tf.add(Sw, buffer_sw)
-                Sb = tf.add(Sb, buffer_sb)
-
-        # Sw = tf.trace(Sw)
-        # Sb = tf.trace(Sb)
-        # loss = tf.trace(tf.div(Sb, Sw))
-        loss = tf.linalg.trace(
-            tf.compat.v1.div(Sw, Sb), name=tf.compat.v1.GraphKeys.LOSSES
-        )
-
-        return loss, tf.linalg.trace(Sb), tf.linalg.trace(Sw)
-
-
-def triplet_average_loss(
-    anchor_embedding, positive_embedding, negative_embedding, margin=5.0
-):
-    """
-    Compute the triplet loss as in
-
-    Schroff, Florian, Dmitry Kalenichenko, and James Philbin.
-    "Facenet: A unified embedding for face recognition and clustering."
-    Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2015.
-
-    :math:`L  = sum(  |f_a - f_p|^2 - |f_a - f_n|^2  + \lambda)`
-
-    **Parameters**
-
-    left_feature:
-      First element of the pair
-
-    right_feature:
-      Second element of the pair
-
-    label:
-      Label of the pair (0 or 1)
-
-    margin:
-      Contrastive margin
-
-    """
-
-    with tf.compat.v1.name_scope("triplet_loss"):
-        # Normalize
-        anchor_embedding = tf.nn.l2_normalize(anchor_embedding, 1, 1e-10, name="anchor")
-        positive_embedding = tf.nn.l2_normalize(
-            positive_embedding, 1, 1e-10, name="positive"
-        )
-        negative_embedding = tf.nn.l2_normalize(
-            negative_embedding, 1, 1e-10, name="negative"
-        )
-
-        anchor_mean = tf.reduce_mean(input_tensor=anchor_embedding, axis=0)
-
-        d_positive = tf.reduce_sum(
-            input_tensor=tf.square(tf.subtract(anchor_mean, positive_embedding)), axis=1
-        )
-        d_negative = tf.reduce_sum(
-            input_tensor=tf.square(tf.subtract(anchor_mean, negative_embedding)), axis=1
-        )
-
-        basic_loss = tf.add(tf.subtract(d_positive, d_negative), margin)
-        loss = tf.reduce_mean(
-            input_tensor=tf.maximum(basic_loss, 0.0),
-            axis=0,
-            name=tf.compat.v1.GraphKeys.LOSSES,
-        )
-
-        return (
-            loss,
-            tf.reduce_mean(input_tensor=d_negative),
-            tf.reduce_mean(input_tensor=d_positive),
-        )
diff --git a/bob/learn/tensorflow/losses/__init__.py b/bob/learn/tensorflow/losses/__init__.py
index 348d773fa1fd3c1f404d2fc2f5aad27fd75c4676..2bfcbd152f92206f5d818bc535833673cefa564f 100644
--- a/bob/learn/tensorflow/losses/__init__.py
+++ b/bob/learn/tensorflow/losses/__init__.py
@@ -1,18 +1,4 @@
-# from .BaseLoss import mean_cross_entropy_loss, mean_cross_entropy_center_loss
-from .center_loss import CenterLoss
-from .ContrastiveLoss import contrastive_loss
-from .mmd import *
-from .pairwise_confusion import total_pairwise_confusion
-from .pixel_wise import PixelWise
-from .StyleLoss import content_loss
-from .StyleLoss import denoising_loss
-from .StyleLoss import linear_gram_style_loss
-from .TripletLoss import triplet_average_loss
-from .TripletLoss import triplet_fisher_loss
-from .TripletLoss import triplet_loss
-from .utils import *
-from .vat import VATLoss
-
+from .center_loss import CenterLoss, CenterLossLayer
 
 # gets sphinx autodoc done right - don't remove it
 def __appropriate__(*args):
@@ -31,13 +17,7 @@ def __appropriate__(*args):
 
 
 __appropriate__(
-    # mean_cross_entropy_loss,
-    # mean_cross_entropy_center_loss,
-    contrastive_loss,
-    triplet_loss,
-    triplet_average_loss,
-    triplet_fisher_loss,
-    VATLoss,
-    PixelWise,
+    CenterLoss,
+    CenterLossLayer,
 )
 __all__ = [_ for _ in dir() if not _.startswith("_")]
diff --git a/bob/learn/tensorflow/losses/center_loss.py b/bob/learn/tensorflow/losses/center_loss.py
index 89b4e0250b402bc89be99c01051e163db335074d..894a461200639042094c6b9e29e1721eee1478cb 100644
--- a/bob/learn/tensorflow/losses/center_loss.py
+++ b/bob/learn/tensorflow/losses/center_loss.py
@@ -1,23 +1,88 @@
 import tensorflow as tf
 
 
-class CenterLoss(tf.keras.losses.Loss):
-    """Center loss."""
+class CenterLossLayer(tf.keras.layers.Layer):
+    """A layer to be added in the model if you want to use CenterLoss
 
-    def __init__(self, n_classes, n_features, alpha=0.9, name="center_loss", **kwargs):
-        super().__init__(name=name, **kwargs)
+    Attributes
+    ----------
+    centers : tf.Variable
+        The variable that keeps track of centers.
+    n_classes : int
+        Number of classes of the task.
+    n_features : int
+        The size of prelogits.
+    """
+
+    def __init__(self, n_classes, n_features, **kwargs):
+        super().__init__(**kwargs)
         self.n_classes = n_classes
         self.n_features = n_features
-        self.alpha = alpha
-
         self.centers = tf.Variable(
-            tf.zeros([n_classes, n_features]), name="centers", trainable=False
+            tf.zeros([n_classes, n_features]),
+            name="centers",
+            trainable=False,
+            # in a distributed strategy, we want updates to this variable to be summed.
+            aggregation=tf.VariableAggregation.SUM,
         )
 
-    def call(self, y_true, y_pred):
-        sparse_labels, prelogits = tf.reshape(y_true, (-1,)), y_pred
+    def call(self, x):
+        # pass through layer
+        return tf.identity(x)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update({"n_classes": self.n_classes, "n_features": self.n_features})
+        return config
+
+
+class CenterLoss(tf.keras.losses.Loss):
+    """Center loss.
+    Introduced in: A Discriminative Feature Learning Approach for Deep Face Recognition
+    https://ydwen.github.io/papers/WenECCV16.pdf
+
+    .. warning::
+
+        This loss MUST NOT BE CALLED during evaluation as it will update the centers!
+        This loss only works with sparse labels.
+        This loss must be used with CenterLossLayer embedded into the model
+
+    Attributes
+    ----------
+    alpha : float
+        The moving average coefficient for updating centers in each batch.
+    centers : tf.Variable
+        The variable that keeps track of centers.
+    centers_layer
+        The layer that keeps track of centers.
+    """
+
+    def __init__(
+        self,
+        centers_layer,
+        alpha=0.9,
+        update_centers=True,
+        name="center_loss",
+        **kwargs
+    ):
+        super().__init__(name=name, **kwargs)
+        self.centers_layer = centers_layer
+        self.centers = self.centers_layer.centers
+        self.alpha = alpha
+        self.update_centers = update_centers
+
+    def call(self, sparse_labels, prelogits):
+        sparse_labels = tf.reshape(sparse_labels, (-1,))
         centers_batch = tf.gather(self.centers, sparse_labels)
-        diff = (1 - self.alpha) * (centers_batch - prelogits)
-        center_loss = tf.reduce_mean(input_tensor=tf.square(prelogits - centers_batch))
-        self.centers.assign(tf.tensor_scatter_nd_sub(self.centers, sparse_labels[:, None], diff))
+        # the reduction of batch dimension will be done by the parent class
+        center_loss = tf.keras.losses.mean_squared_error(prelogits, centers_batch)
+
+        # update centers
+        if self.update_centers:
+            diff = (1 - self.alpha) * (centers_batch - prelogits)
+            updates = tf.scatter_nd(sparse_labels[:, None], diff, self.centers.shape)
+            # using assign_sub will make sure updates are added during distributed
+            # training
+            self.centers.assign_sub(updates)
+
         return center_loss
diff --git a/bob/learn/tensorflow/losses/mmd.py b/bob/learn/tensorflow/losses/mmd.py
deleted file mode 100644
index a48aaba224088d0af63a5c3b5555ed16cae826c7..0000000000000000000000000000000000000000
--- a/bob/learn/tensorflow/losses/mmd.py
+++ /dev/null
@@ -1,32 +0,0 @@
-import tensorflow as tf
-
-
-def compute_kernel(x, y):
-    """Gaussian kernel."""
-    x_size = tf.shape(input=x)[0]
-    y_size = tf.shape(input=y)[0]
-    dim = tf.shape(input=x)[1]
-    tiled_x = tf.tile(
-        tf.reshape(x, tf.stack([x_size, 1, dim])), tf.stack([1, y_size, 1])
-    )
-    tiled_y = tf.tile(
-        tf.reshape(y, tf.stack([1, y_size, dim])), tf.stack([x_size, 1, 1])
-    )
-    return tf.exp(
-        -tf.reduce_mean(input_tensor=tf.square(tiled_x - tiled_y), axis=2)
-        / tf.cast(dim, tf.float32)
-    )
-
-
-def mmd(x, y):
-    """Maximum Mean Discrepancy with Gaussian kernel.
-    See: https://stats.stackexchange.com/a/276618/49433
-    """
-    x_kernel = compute_kernel(x, x)
-    y_kernel = compute_kernel(y, y)
-    xy_kernel = compute_kernel(x, y)
-    return (
-        tf.reduce_mean(input_tensor=x_kernel)
-        + tf.reduce_mean(input_tensor=y_kernel)
-        - 2 * tf.reduce_mean(input_tensor=xy_kernel)
-    )
diff --git a/bob/learn/tensorflow/losses/pairwise_confusion.py b/bob/learn/tensorflow/losses/pairwise_confusion.py
deleted file mode 100644
index 6c4c9613b70bea78414beefe842b930dfb608059..0000000000000000000000000000000000000000
--- a/bob/learn/tensorflow/losses/pairwise_confusion.py
+++ /dev/null
@@ -1,19 +0,0 @@
-import tensorflow as tf
-
-from ..utils import pdist_safe
-from ..utils import upper_triangle
-
-
-def total_pairwise_confusion(prelogits, name=None):
-    """Total Pairwise Confusion Loss
-
-    [1]X. Tu et al., “Learning Generalizable and Identity-Discriminative
-    Representations for Face Anti-Spoofing,” arXiv preprint arXiv:1901.05602, 2019.
-    """
-    # compute L2 norm between all prelogits and sum them.
-    with tf.compat.v1.name_scope(name, default_name="total_pairwise_confusion"):
-        prelogits = tf.reshape(prelogits, (tf.shape(input=prelogits)[0], -1))
-        loss_tpc = tf.reduce_mean(input_tensor=upper_triangle(pdist_safe(prelogits)))
-
-    tf.compat.v1.summary.scalar("loss_tpc", loss_tpc)
-    return loss_tpc
diff --git a/bob/learn/tensorflow/losses/pixel_wise.py b/bob/learn/tensorflow/losses/pixel_wise.py
deleted file mode 100644
index 09cb4c41469596f721bb4d04a5db600603d06ecb..0000000000000000000000000000000000000000
--- a/bob/learn/tensorflow/losses/pixel_wise.py
+++ /dev/null
@@ -1,62 +0,0 @@
-import tensorflow as tf
-
-from ..dataset import tf_repeat
-from .utils import balanced_sigmoid_cross_entropy_loss_weights
-from .utils import balanced_softmax_cross_entropy_loss_weights
-
-
-class PixelWise:
-    """A pixel wise loss which is just a cross entropy loss but applied to all pixels"""
-
-    def __init__(
-        self, balance_weights=True, n_one_hot_labels=None, label_smoothing=0.5, **kwargs
-    ):
-        super(PixelWise, self).__init__(**kwargs)
-        self.balance_weights = balance_weights
-        self.n_one_hot_labels = n_one_hot_labels
-        self.label_smoothing = label_smoothing
-
-    def __call__(self, labels, logits):
-        with tf.compat.v1.name_scope("PixelWiseLoss"):
-            flatten = tf.keras.layers.Flatten()
-            logits = flatten(logits)
-            n_pixels = logits.get_shape()[-1]
-            weights = 1.0
-            if self.balance_weights and self.n_one_hot_labels:
-                # use labels to figure out the required loss
-                weights = balanced_softmax_cross_entropy_loss_weights(
-                    labels, dtype=logits.dtype
-                )
-                # repeat weights for all pixels
-                weights = tf_repeat(weights[:, None], [1, n_pixels])
-                weights = tf.reshape(weights, (-1,))
-            elif self.balance_weights and not self.n_one_hot_labels:
-                # use labels to figure out the required loss
-                weights = balanced_sigmoid_cross_entropy_loss_weights(
-                    labels, dtype=logits.dtype
-                )
-                # repeat weights for all pixels
-                weights = tf_repeat(weights[:, None], [1, n_pixels])
-
-            if self.n_one_hot_labels:
-                labels = tf_repeat(labels, [n_pixels, 1])
-                labels = tf.reshape(labels, (-1, self.n_one_hot_labels))
-                # reshape logits too as softmax_cross_entropy is buggy and cannot really
-                # handle higher dimensions
-                logits = tf.reshape(logits, (-1, self.n_one_hot_labels))
-                loss_fn = tf.compat.v1.losses.softmax_cross_entropy
-            else:
-                labels = tf.reshape(labels, (-1, 1))
-                labels = tf_repeat(labels, [n_pixels, 1])
-                labels = tf.reshape(labels, (-1, n_pixels))
-                loss_fn = tf.compat.v1.losses.sigmoid_cross_entropy
-
-            loss_pixel_wise = loss_fn(
-                labels,
-                logits=logits,
-                weights=weights,
-                label_smoothing=self.label_smoothing,
-                reduction=tf.compat.v1.losses.Reduction.MEAN,
-            )
-        tf.compat.v1.summary.scalar("loss_pixel_wise", loss_pixel_wise)
-        return loss_pixel_wise
diff --git a/bob/learn/tensorflow/losses/utils.py b/bob/learn/tensorflow/losses/utils.py
deleted file mode 100644
index 013de65aac1136d903f6714ad0d2dfe5e7291f51..0000000000000000000000000000000000000000
--- a/bob/learn/tensorflow/losses/utils.py
+++ /dev/null
@@ -1,144 +0,0 @@
-#!/usr/bin/env python
-# vim: set fileencoding=utf-8 :
-# @author: Amir Mohammadi <amir.mohammadi@idiap.ch>
-
-import tensorflow as tf
-
-
-def balanced_softmax_cross_entropy_loss_weights(labels, dtype="float32"):
-    """Computes weights that normalizes your loss per class.
-
-    Labels must be a batch of one-hot encoded labels. The function takes labels and
-    computes the weights per batch. Weights will be smaller for classes that have more
-    samples in this batch. This is useful if you unbalanced classes in your dataset or
-    batch.
-
-    Parameters
-    ----------
-    labels : ``tf.Tensor``
-        Labels of your current input. The shape must be [batch_size, n_classes]. If your
-        labels are not one-hot encoded, you can use ``tf.one_hot`` to convert them first
-        before giving them to this function.
-    dtype : ``tf.dtype``
-        The dtype that weights will have. It should be float. Best is to provide
-        logits.dtype as input.
-
-    Returns
-    -------
-    ``tf.Tensor``
-        Computed weights that will cancel your dataset imbalance per batch.
-
-    Examples
-    --------
-    >>> import numpy
-    >>> import tensorflow as tf
-    >>> from bob.learn.tensorflow.loss import balanced_softmax_cross_entropy_loss_weights
-    >>> labels = numpy.array([[1, 0, 0],
-    ...                 [1, 0, 0],
-    ...                 [0, 0, 1],
-    ...                 [0, 1, 0],
-    ...                 [0, 0, 1],
-    ...                 [1, 0, 0],
-    ...                 [1, 0, 0],
-    ...                 [0, 0, 1],
-    ...                 [1, 0, 0],
-    ...                 [1, 0, 0],
-    ...                 [1, 0, 0],
-    ...                 [1, 0, 0],
-    ...                 [1, 0, 0],
-    ...                 [1, 0, 0],
-    ...                 [0, 1, 0],
-    ...                 [1, 0, 0],
-    ...                 [0, 1, 0],
-    ...                 [1, 0, 0],
-    ...                 [0, 0, 1],
-    ...                 [0, 0, 1],
-    ...                 [1, 0, 0],
-    ...                 [0, 0, 1],
-    ...                 [1, 0, 0],
-    ...                 [1, 0, 0],
-    ...                 [0, 1, 0],
-    ...                 [1, 0, 0],
-    ...                 [1, 0, 0],
-    ...                 [1, 0, 0],
-    ...                 [0, 1, 0],
-    ...                 [1, 0, 0],
-    ...                 [0, 0, 1],
-    ...                 [1, 0, 0]], dtype="int32")
-    >>> session = tf.Session() # Eager execution is also possible check https://www.tensorflow.org/guide/eager
-    >>> session.run(tf.reduce_sum(labels, axis=0))
-    array([20,  5,  7], dtype=int32)
-    >>> session.run(balanced_softmax_cross_entropy_loss_weights(labels, dtype='float32'))
-    array([0.53333336, 0.53333336, 1.5238096 , 2.1333334 , 1.5238096 ,
-           0.53333336, 0.53333336, 1.5238096 , 0.53333336, 0.53333336,
-           0.53333336, 0.53333336, 0.53333336, 0.53333336, 2.1333334 ,
-           0.53333336, 2.1333334 , 0.53333336, 1.5238096 , 1.5238096 ,
-           0.53333336, 1.5238096 , 0.53333336, 0.53333336, 2.1333334 ,
-           0.53333336, 0.53333336, 0.53333336, 2.1333334 , 0.53333336,
-           1.5238096 , 0.53333336], dtype=float32)
-
-    You would use it like this:
-
-    >>> #weights = balanced_softmax_cross_entropy_loss_weights(labels, dtype=logits.dtype)
-    >>> #loss = tf.losses.softmax_cross_entropy(logits=logits, labels=labels, weights=weights)
-    """
-    shape = tf.cast(tf.shape(input=labels), dtype=dtype)
-    batch_size, n_classes = shape[0], shape[1]
-    weights = tf.cast(tf.reduce_sum(input_tensor=labels, axis=0), dtype=dtype)
-    weights = batch_size / weights / n_classes
-    weights = tf.gather(weights, tf.argmax(input=labels, axis=1))
-    return weights
-
-
-def balanced_sigmoid_cross_entropy_loss_weights(labels, dtype="float32"):
-    """Computes weights that normalizes your loss per class.
-
-    Labels must be a batch of binary labels. The function takes labels and
-    computes the weights per batch. Weights will be smaller for the class that have more
-    samples in this batch. This is useful if you unbalanced classes in your dataset or
-    batch.
-
-    Parameters
-    ----------
-    labels : ``tf.Tensor``
-        Labels of your current input. The shape must be [batch_size] and values must be
-        either 0 or 1.
-    dtype : ``tf.dtype``
-        The dtype that weights will have. It should be float. Best is to provide
-        logits.dtype as input.
-
-    Returns
-    -------
-    ``tf.Tensor``
-        Computed weights that will cancel your dataset imbalance per batch.
-
-    Examples
-    --------
-    >>> import numpy
-    >>> import tensorflow as tf
-    >>> from bob.learn.tensorflow.loss import balanced_sigmoid_cross_entropy_loss_weights
-    >>> labels = numpy.array([1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0,
-    ...                 1, 1, 0, 1, 1, 1, 0, 1, 0, 1], dtype="int32")
-    >>> sum(labels), len(labels)
-    (20, 32)
-    >>> session = tf.Session() # Eager execution is also possible check https://www.tensorflow.org/guide/eager
-    >>> session.run(balanced_sigmoid_cross_entropy_loss_weights(labels, dtype='float32'))
-    array([0.8      , 0.8      , 1.3333334, 1.3333334, 1.3333334, 0.8      ,
-           0.8      , 1.3333334, 0.8      , 0.8      , 0.8      , 0.8      ,
-           0.8      , 0.8      , 1.3333334, 0.8      , 1.3333334, 0.8      ,
-           1.3333334, 1.3333334, 0.8      , 1.3333334, 0.8      , 0.8      ,
-           1.3333334, 0.8      , 0.8      , 0.8      , 1.3333334, 0.8      ,
-           1.3333334, 0.8      ], dtype=float32)
-
-    You would use it like this:
-
-    >>> #weights = balanced_sigmoid_cross_entropy_loss_weights(labels, dtype=logits.dtype)
-    >>> #loss = tf.losses.sigmoid_cross_entropy(logits=logits, labels=labels, weights=weights)
-    """
-    labels = tf.cast(labels, dtype="int32")
-    batch_size = tf.cast(tf.shape(input=labels)[0], dtype=dtype)
-    weights = tf.cast(tf.reduce_sum(input_tensor=labels), dtype=dtype)
-    weights = tf.convert_to_tensor(value=[batch_size - weights, weights])
-    weights = batch_size / weights / 2
-    weights = tf.gather(weights, labels)
-    return weights
diff --git a/bob/learn/tensorflow/losses/vat.py b/bob/learn/tensorflow/losses/vat.py
deleted file mode 100644
index 194d54440e62b990962c4a161b58cd35830f0117..0000000000000000000000000000000000000000
--- a/bob/learn/tensorflow/losses/vat.py
+++ /dev/null
@@ -1,172 +0,0 @@
-# Adapted from https://github.com/takerum/vat_tf Its license:
-#
-# MIT License
-#
-# Copyright (c) 2017 Takeru Miyato
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-# SOFTWARE.
-
-
-from functools import partial
-
-import tensorflow as tf
-
-
-def get_normalized_vector(d):
-    d /= 1e-12 + tf.reduce_max(
-        input_tensor=tf.abs(d), axis=list(range(1, len(d.get_shape()))), keepdims=True
-    )
-    d /= tf.sqrt(
-        1e-6
-        + tf.reduce_sum(
-            input_tensor=tf.pow(d, 2.0),
-            axis=list(range(1, len(d.get_shape()))),
-            keepdims=True,
-        )
-    )
-    return d
-
-
-def logsoftmax(x):
-    xdev = x - tf.reduce_max(input_tensor=x, axis=1, keepdims=True)
-    lsm = xdev - tf.math.log(
-        tf.reduce_sum(input_tensor=tf.exp(xdev), axis=1, keepdims=True)
-    )
-    return lsm
-
-
-def kl_divergence_with_logit(q_logit, p_logit):
-    q = tf.nn.softmax(q_logit)
-    qlogq = tf.reduce_mean(
-        input_tensor=tf.reduce_sum(input_tensor=q * logsoftmax(q_logit), axis=1)
-    )
-    qlogp = tf.reduce_mean(
-        input_tensor=tf.reduce_sum(input_tensor=q * logsoftmax(p_logit), axis=1)
-    )
-    return qlogq - qlogp
-
-
-def entropy_y_x(logit):
-    p = tf.nn.softmax(logit)
-    return -tf.reduce_mean(
-        input_tensor=tf.reduce_sum(input_tensor=p * logsoftmax(logit), axis=1)
-    )
-
-
-class VATLoss:
-    """A class to hold parameters for Virtual Adversarial Training (VAT) Loss
-    and perform it.
-
-    Attributes
-    ----------
-    epsilon : float
-        norm length for (virtual) adversarial training
-    method : str
-        The method for calculating the loss: ``vatent`` for VAT loss + entropy
-        and ``vat`` for only VAT loss.
-    num_power_iterations : int
-        the number of power iterations
-    xi : float
-        small constant for finite difference
-    """
-
-    def __init__(
-        self, epsilon=8.0, xi=1e-6, num_power_iterations=1, method="vatent", **kwargs
-    ):
-        super(VATLoss, self).__init__(**kwargs)
-        self.epsilon = epsilon
-        self.xi = xi
-        self.num_power_iterations = num_power_iterations
-        self.method = method
-
-    def __call__(self, features, logits, architecture, mode):
-        """Computes the VAT loss for unlabeled features.
-        If you are doing semi-supervised learning, only pass the unlabeled
-        features and their logits here.
-
-        Parameters
-        ----------
-        features : object
-            Tensor representing the (unlabeled) features
-        logits : object
-            Tensor representing the logits of (unlabeled) features.
-        architecture : object
-            A callable that constructs the model. It should accept ``mode`` and
-            ``reuse`` as keyword arguments. The features will be given as the
-            first input.
-        mode : str
-            One of tf.estimator.ModeKeys.{TRAIN,EVAL} strings.
-
-        Returns
-        -------
-        object
-            The loss.
-
-        Raises
-        ------
-        NotImplementedError
-            If self.method is not ``vat`` or ``vatent``.
-        """
-        if mode != tf.estimator.ModeKeys.TRAIN:
-            return 0.0
-        architecture = partial(architecture, reuse=True)
-        with tf.compat.v1.variable_scope(tf.compat.v1.get_variable_scope(), reuse=True):
-            vat_loss = self.virtual_adversarial_loss(
-                features, logits, architecture, mode
-            )
-            tf.compat.v1.summary.scalar("loss_VAT", vat_loss)
-            tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.LOSSES, vat_loss)
-            if self.method == "vat":
-                loss = vat_loss
-            elif self.method == "vatent":
-                ent_loss = entropy_y_x(logits)
-                tf.compat.v1.summary.scalar("loss_entropy", ent_loss)
-                tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.LOSSES, ent_loss)
-                loss = vat_loss + ent_loss
-            else:
-                raise ValueError
-            return loss
-
-    def virtual_adversarial_loss(
-        self, features, logits, architecture, mode, name="vat_loss_op"
-    ):
-        r_vadv = self.generate_virtual_adversarial_perturbation(
-            features, logits, architecture, mode
-        )
-        logit_p = tf.stop_gradient(logits)
-        adversarial_input = features + r_vadv
-        tf.compat.v1.summary.image("Adversarial_Image", adversarial_input)
-        logit_m = architecture(adversarial_input, mode=mode)[0]
-        loss = kl_divergence_with_logit(logit_p, logit_m)
-        return tf.identity(loss, name=name)
-
-    def generate_virtual_adversarial_perturbation(
-        self, features, logits, architecture, mode
-    ):
-        d = tf.random.normal(shape=tf.shape(input=features))
-
-        for _ in range(self.num_power_iterations):
-            d = self.xi * get_normalized_vector(d)
-            logit_p = logits
-            logit_m = architecture(features + d, mode=mode)[0]
-            dist = kl_divergence_with_logit(logit_p, logit_m)
-            grad = tf.gradients(ys=dist, xs=[d], aggregation_method=2)[0]
-            d = tf.stop_gradient(grad)
-
-        return self.epsilon * get_normalized_vector(d)
diff --git a/bob/learn/tensorflow/image/__init__.py b/bob/learn/tensorflow/metrics/__init__.py
similarity index 80%
rename from bob/learn/tensorflow/image/__init__.py
rename to bob/learn/tensorflow/metrics/__init__.py
index c8a2e5ae2e1795bd6dc7b66befe78c635d3fb522..55cec2bd10e2e0778b2ea858ab32b5e14ebe1b27 100644
--- a/bob/learn/tensorflow/image/__init__.py
+++ b/bob/learn/tensorflow/metrics/__init__.py
@@ -1,6 +1,4 @@
-from .filter import GaussianFilter
-from .filter import gaussian_kernel
-
+from .embedding_accuracy import EmbeddingAccuracy
 
 # gets sphinx autodoc done right - don't remove it
 def __appropriate__(*args):
@@ -13,9 +11,10 @@ def __appropriate__(*args):
     Resolves `Sphinx referencing issues
     <https://github.com/sphinx-doc/sphinx/issues/3048>`
     """
+
     for obj in args:
         obj.__module__ = __name__
 
 
-__appropriate__(GaussianFilter)
+__appropriate__(EmbeddingAccuracy)
 __all__ = [_ for _ in dir() if not _.startswith("_")]
diff --git a/bob/learn/tensorflow/metrics/embedding_accuracy.py b/bob/learn/tensorflow/metrics/embedding_accuracy.py
new file mode 100644
index 0000000000000000000000000000000000000000..20ac7294a1000ed09b688d08bbd4be9f1eb35b60
--- /dev/null
+++ b/bob/learn/tensorflow/metrics/embedding_accuracy.py
@@ -0,0 +1,39 @@
+import numpy as np
+import tensorflow as tf
+import tensorflow.keras.backend as K
+from tensorflow.python.keras.metrics import MeanMetricWrapper
+
+from ..utils import pdist
+
+
+def predict_using_tensors(embedding, labels):
+    """
+    Compute the predictions through exhaustive comparisons between
+    embeddings using tensors
+    """
+
+    # Fitting the main diagonal with infs (removing comparisons with the same
+    # sample)
+    inf = tf.cast(tf.ones_like(labels), tf.float32) * np.inf
+
+    distances = pdist(embedding)
+    distances = tf.linalg.set_diag(distances, inf)
+    indexes = tf.argmin(input=distances, axis=1)
+    return tf.gather(labels, indexes)
+
+
+def accuracy_from_embeddings(labels, prelogits):
+    labels = tf.reshape(labels, (-1,))
+    embeddings = tf.nn.l2_normalize(prelogits, 1)
+    predictions = predict_using_tensors(embeddings, labels)
+    return tf.cast(tf.math.equal(labels, predictions), K.floatx())
+
+
+class EmbeddingAccuracy(MeanMetricWrapper):
+    """Calculates accuracy from labels and prelogits.
+    This class relies on the fact that, in each batch, at least two images are
+    available from each class(identity).
+    """
+
+    def __init__(self, name="embedding_accuracy", dtype=None):
+        super().__init__(accuracy_from_embeddings, name, dtype=dtype)
diff --git a/bob/learn/tensorflow/models/__init__.py b/bob/learn/tensorflow/models/__init__.py
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..d18ceb93bc0a1be07c5d9ffc52688f8695e83717 100644
--- a/bob/learn/tensorflow/models/__init__.py
+++ b/bob/learn/tensorflow/models/__init__.py
@@ -0,0 +1,24 @@
+from .alexnet import AlexNet_simplified
+from .densenet import DenseNet
+
+# gets sphinx autodoc done right - don't remove it
+def __appropriate__(*args):
+    """Says object was actually declared here, an not on the import module.
+
+    Parameters:
+
+      *args: An iterable of objects to modify
+
+    Resolves `Sphinx referencing issues
+    <https://github.com/sphinx-doc/sphinx/issues/3048>`
+    """
+
+    for obj in args:
+        obj.__module__ = __name__
+
+
+__appropriate__(
+    AlexNet_simplified,
+    DenseNet,
+)
+__all__ = [_ for _ in dir() if not _.startswith("_")]
diff --git a/bob/learn/tensorflow/models/inception_resnet_v2.py b/bob/learn/tensorflow/models/inception_resnet_v2.py
index bfab700a721e539a428a4339312555e574523b47..ad5b8629f52b27496aa5c92aeaa2e3793965042a 100644
--- a/bob/learn/tensorflow/models/inception_resnet_v2.py
+++ b/bob/learn/tensorflow/models/inception_resnet_v2.py
@@ -15,14 +15,25 @@ from tensorflow.keras.layers import Dropout
 from tensorflow.keras.layers import GlobalAvgPool2D
 from tensorflow.keras.layers import GlobalMaxPool2D
 from tensorflow.keras.layers import Input
-from tensorflow.keras.layers import Lambda
 from tensorflow.keras.layers import MaxPool2D
 from tensorflow.keras.models import Model
+from tensorflow.keras.models import Sequential
+
+from ..utils import SequentialLayer
 
 logger = logging.getLogger(__name__)
 
 
-class Conv2D_BN(tf.keras.Sequential):
+def Conv2D_BN(
+    filters,
+    kernel_size,
+    strides=1,
+    padding="same",
+    activation="relu",
+    use_bias=False,
+    name=None,
+    **kwargs,
+):
     """Utility class to apply conv + BN.
 
     # Arguments
@@ -53,49 +64,28 @@ class Conv2D_BN(tf.keras.Sequential):
         and `name + '/BatchNorm'` for the batch norm layer.
     """
 
-    def __init__(
-        self,
-        filters,
-        kernel_size,
-        strides=1,
-        padding="same",
-        activation="relu",
-        use_bias=False,
-        name=None,
-        **kwargs,
-    ):
-
-        self.filters = filters
-        self.kernel_size = kernel_size
-        self.strides = strides
-        self.padding = padding
-        self.activation = activation
-        self.use_bias = use_bias
-
-        layers = [
-            Conv2D(
-                filters,
-                kernel_size,
-                strides=strides,
-                padding=padding,
-                use_bias=use_bias,
-                name=name,
-            )
-        ]
+    layers = [
+        Conv2D(
+            filters,
+            kernel_size,
+            strides=strides,
+            padding=padding,
+            use_bias=use_bias,
+            name="Conv2D",
+        )
+    ]
 
-        if not use_bias:
-            bn_axis = 1 if K.image_data_format() == "channels_first" else 3
-            bn_name = None if name is None else name + "/BatchNorm"
-            layers += [BatchNormalization(axis=bn_axis, scale=False, name=bn_name)]
+    if not use_bias:
+        bn_axis = 1 if K.image_data_format() == "channels_first" else 3
+        layers += [BatchNormalization(axis=bn_axis, scale=False, name="BatchNorm")]
 
-        if activation is not None:
-            ac_name = None if name is None else name + "/Act"
-            layers += [Activation(activation, name=ac_name)]
+    if activation is not None:
+        layers += [Activation(activation, name="Act")]
 
-        super().__init__(layers, name=name, **kwargs)
+    return SequentialLayer(layers, name=name, **kwargs)
 
 
-class ScaledResidual(tf.keras.Model):
+class ScaledResidual(tf.keras.layers.Layer):
     """A scaled residual connection layer"""
 
     def __init__(self, scale, name="scaled_residual", **kwargs):
@@ -105,8 +95,13 @@ class ScaledResidual(tf.keras.Model):
     def call(self, inputs, training=None):
         return inputs[0] + inputs[1] * self.scale
 
+    def get_config(self):
+        config = super().get_config()
+        config.update({"scale": self.scale, "name": self.name})
+        return config
+
 
-class InceptionResnetBlock(tf.keras.Model):
+class InceptionResnetBlock(tf.keras.layers.Layer):
     """An Inception-ResNet block.
 
     This class builds 3 types of Inception-ResNet blocks mentioned
@@ -164,24 +159,24 @@ class InceptionResnetBlock(tf.keras.Model):
         self.n = n
 
         if block_type == "block35":
-            branch_0 = [Conv2D_BN(32 // n, 1, name="branch0_conv1")]
-            branch_1 = [Conv2D_BN(32 // n, 1, name="branch1_conv1")]
-            branch_1 += [Conv2D_BN(32 // n, 3, name="branch1_conv2")]
-            branch_2 = [Conv2D_BN(32 // n, 1, name="branch2_conv1")]
-            branch_2 += [Conv2D_BN(48 // n, 3, name="branch2_conv2")]
-            branch_2 += [Conv2D_BN(64 // n, 3, name="branch2_conv3")]
+            branch_0 = [Conv2D_BN(32 // n, 1, name="Branch_0/Conv2d_1x1")]
+            branch_1 = [Conv2D_BN(32 // n, 1, name="Branch_1/Conv2d_0a_1x1")]
+            branch_1 += [Conv2D_BN(32 // n, 3, name="Branch_1/Conv2d_0b_3x3")]
+            branch_2 = [Conv2D_BN(32 // n, 1, name="Branch_2/Conv2d_0a_1x1")]
+            branch_2 += [Conv2D_BN(48 // n, 3, name="Branch_2/Conv2d_0b_3x3")]
+            branch_2 += [Conv2D_BN(64 // n, 3, name="Branch_2/Conv2d_0c_3x3")]
             branches = [branch_0, branch_1, branch_2]
         elif block_type == "block17":
-            branch_0 = [Conv2D_BN(192 // n, 1, name="branch0_conv1")]
-            branch_1 = [Conv2D_BN(128 // n, 1, name="branch1_conv1")]
-            branch_1 += [Conv2D_BN(160 // n, (1, 7), name="branch1_conv2")]
-            branch_1 += [Conv2D_BN(192 // n, (7, 1), name="branch1_conv3")]
+            branch_0 = [Conv2D_BN(192 // n, 1, name="Branch_0/Conv2d_1x1")]
+            branch_1 = [Conv2D_BN(128 // n, 1, name="Branch_1/Conv2d_0a_1x1")]
+            branch_1 += [Conv2D_BN(160 // n, (1, 7), name="Branch_1/Conv2d_0b_1x7")]
+            branch_1 += [Conv2D_BN(192 // n, (7, 1), name="Branch_1/Conv2d_0c_7x1")]
             branches = [branch_0, branch_1]
         elif block_type == "block8":
-            branch_0 = [Conv2D_BN(192 // n, 1, name="branch0_conv1")]
-            branch_1 = [Conv2D_BN(192 // n, 1, name="branch1_conv1")]
-            branch_1 += [Conv2D_BN(224 // n, (1, 3), name="branch1_conv2")]
-            branch_1 += [Conv2D_BN(256 // n, (3, 1), name="branch1_conv3")]
+            branch_0 = [Conv2D_BN(192 // n, 1, name="Branch_0/Conv2d_1x1")]
+            branch_1 = [Conv2D_BN(192 // n, 1, name="Branch_1/Conv2d_0a_1x1")]
+            branch_1 += [Conv2D_BN(224 // n, (1, 3), name="Branch_1/Conv2d_0b_1x3")]
+            branch_1 += [Conv2D_BN(256 // n, (3, 1), name="Branch_1/Conv2d_0c_3x1")]
             branches = [branch_0, branch_1]
         else:
             raise ValueError(
@@ -195,18 +190,9 @@ class InceptionResnetBlock(tf.keras.Model):
         channel_axis = 1 if K.image_data_format() == "channels_first" else 3
         self.concat = Concatenate(axis=channel_axis, name="concatenate")
         self.up_conv = Conv2D_BN(
-            n_channels, 1, activation=None, use_bias=True, name="up_conv"
+            n_channels, 1, activation=None, use_bias=True, name="Conv2d_1x1"
         )
 
-        # output_shape = (None, None, n_channels)
-        # if K.image_data_format() == "channels_first":
-        #     output_shape = (n_channels, None, None)
-        # self.residual = Lambda(
-        #     lambda inputs, scale: inputs[0] + inputs[1] * scale,
-        #     output_shape=output_shape,
-        #     arguments={"scale": scale},
-        #     name="residual_scale",
-        # )
         self.residual = ScaledResidual(scale)
         self.act = lambda x: x
         if activation is not None:
@@ -228,8 +214,26 @@ class InceptionResnetBlock(tf.keras.Model):
 
         return x
 
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                name: getattr(self, name)
+                for name in [
+                    "n_channels",
+                    "scale",
+                    "block_type",
+                    "block_idx",
+                    "activation",
+                    "n",
+                    "name",
+                ]
+            }
+        )
+        return config
+
 
-class ReductionA(tf.keras.Model):
+class ReductionA(tf.keras.layers.Layer):
     """A Reduction A block for InceptionResnetV2"""
 
     def __init__(
@@ -257,19 +261,19 @@ class ReductionA(tf.keras.Model):
                 3,
                 strides=1 if use_atrous else 2,
                 padding=padding,
-                name="branch1_conv1",
+                name="Branch_0/Conv2d_1a_3x3",
             )
         ]
 
         branch_2 = [
-            Conv2D_BN(k, 1, name="branch2_conv1"),
-            Conv2D_BN(kl, 3, name="branch2_conv2"),
+            Conv2D_BN(k, 1, name="Branch_1/Conv2d_0a_1x1"),
+            Conv2D_BN(kl, 3, name="Branch_1/Conv2d_0b_3x3"),
             Conv2D_BN(
                 km,
                 3,
                 strides=1 if use_atrous else 2,
                 padding=padding,
-                name="branch2_conv3",
+                name="Branch_1/Conv2d_1a_3x3",
             ),
         ]
 
@@ -278,7 +282,7 @@ class ReductionA(tf.keras.Model):
                 3,
                 strides=1 if use_atrous else 2,
                 padding=padding,
-                name="branch3_pool1",
+                name="Branch_2/MaxPool_1a_3x3",
             )
         ]
         self.branches = [branch_1, branch_2, branch_pool]
@@ -298,8 +302,18 @@ class ReductionA(tf.keras.Model):
 
         return self.concat(branch_outputs)
 
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                name: getattr(self, name)
+                for name in ["padding", "k", "kl", "km", "n", "use_atrous", "name"]
+            }
+        )
+        return config
+
 
-class ReductionB(tf.keras.Model):
+class ReductionB(tf.keras.layers.Layer):
     """A Reduction B block for InceptionResnetV2"""
 
     def __init__(
@@ -326,22 +340,24 @@ class ReductionB(tf.keras.Model):
         self.pq = pq
 
         branch_1 = [
-            Conv2D_BN(n, 1, name="branch1_conv1"),
-            Conv2D_BN(no, 3, strides=2, padding=padding, name="branch1_conv2"),
+            Conv2D_BN(n, 1, name="Branch_0/Conv2d_0a_1x1"),
+            Conv2D_BN(no, 3, strides=2, padding=padding, name="Branch_0/Conv2d_1a_3x3"),
         ]
 
         branch_2 = [
-            Conv2D_BN(p, 1, name="branch2_conv1"),
-            Conv2D_BN(pq, 3, strides=2, padding=padding, name="branch2_conv2"),
+            Conv2D_BN(p, 1, name="Branch_1/Conv2d_0a_1x1"),
+            Conv2D_BN(pq, 3, strides=2, padding=padding, name="Branch_1/Conv2d_1a_3x3"),
         ]
 
         branch_3 = [
-            Conv2D_BN(k, 1, name="branch3_conv1"),
-            Conv2D_BN(kl, 3, name="branch3_conv2"),
-            Conv2D_BN(km, 3, strides=2, padding=padding, name="branch3_conv3"),
+            Conv2D_BN(k, 1, name="Branch_2/Conv2d_0a_1x1"),
+            Conv2D_BN(kl, 3, name="Branch_2/Conv2d_0b_3x3"),
+            Conv2D_BN(km, 3, strides=2, padding=padding, name="Branch_2/Conv2d_1a_3x3"),
         ]
 
-        branch_pool = [MaxPool2D(3, strides=2, padding=padding, name=f"branch4_pool1")]
+        branch_pool = [
+            MaxPool2D(3, strides=2, padding=padding, name="Branch_3/MaxPool_1a_3x3")
+        ]
         self.branches = [branch_1, branch_2, branch_3, branch_pool]
         channel_axis = 1 if K.image_data_format() == "channels_first" else 3
         self.concat = Concatenate(axis=channel_axis, name=f"{name}/mixed")
@@ -359,38 +375,48 @@ class ReductionB(tf.keras.Model):
 
         return self.concat(branch_outputs)
 
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                name: getattr(self, name)
+                for name in ["padding", "k", "kl", "km", "n", "no", "p", "pq", "name"]
+            }
+        )
+        return config
+
 
-class InceptionA(tf.keras.Model):
+class InceptionA(tf.keras.layers.Layer):
     def __init__(self, pool_filters, name="inception_a", **kwargs):
         super().__init__(name=name, **kwargs)
         self.pool_filters = pool_filters
 
         self.branch1x1 = Conv2D_BN(
-            96, kernel_size=1, padding="same", name="branch1_conv1"
+            96, kernel_size=1, padding="same", name="Branch_0/Conv2d_1x1"
         )
 
         self.branch3x3dbl_1 = Conv2D_BN(
-            64, kernel_size=1, padding="same", name="branch2_conv1"
+            64, kernel_size=1, padding="same", name="Branch_2/Conv2d_0a_1x1"
         )
         self.branch3x3dbl_2 = Conv2D_BN(
-            96, kernel_size=3, padding="same", name="branch2_conv2"
+            96, kernel_size=3, padding="same", name="Branch_2/Conv2d_0b_3x3"
         )
         self.branch3x3dbl_3 = Conv2D_BN(
-            96, kernel_size=3, padding="same", name="branch2_conv3"
+            96, kernel_size=3, padding="same", name="Branch_2/Conv2d_0c_3x3"
         )
 
         self.branch5x5_1 = Conv2D_BN(
-            48, kernel_size=1, padding="same", name="branch3_conv1"
+            48, kernel_size=1, padding="same", name="Branch_1/Conv2d_0a_1x1"
         )
         self.branch5x5_2 = Conv2D_BN(
-            64, kernel_size=5, padding="same", name="branch3_conv2"
+            64, kernel_size=5, padding="same", name="Branch_1/Conv2d_0b_5x5"
         )
 
         self.branch_pool_1 = AvgPool2D(
-            pool_size=3, strides=1, padding="same", name="branch4_pool1"
+            pool_size=3, strides=1, padding="same", name="Branch_3/AvgPool_0a_3x3"
         )
         self.branch_pool_2 = Conv2D_BN(
-            pool_filters, kernel_size=1, padding="same", name="branch4_conv1"
+            pool_filters, kernel_size=1, padding="same", name="Branch_3/Conv2d_0b_1x1"
         )
 
         channel_axis = 1 if K.image_data_format() == "channels_first" else 3
@@ -412,6 +438,11 @@ class InceptionA(tf.keras.Model):
         outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
         return self.concat(outputs)
 
+    def get_config(self):
+        config = super().get_config()
+        config.update({"pool_filters": self.pool_filters, "name": self.name})
+        return config
+
 
 def InceptionResNetV2(
     include_top=True,
@@ -419,18 +450,17 @@ def InceptionResNetV2(
     input_shape=None,
     pooling=None,
     classes=1000,
+    bottleneck=False,
+    dropout_rate=0.2,
+    name="InceptionResnetV2",
     **kwargs,
 ):
     """Instantiates the Inception-ResNet v2 architecture.
-    Optionally loads weights pre-trained on ImageNet.
     Note that the data format convention used by the model is
     the one specified in your Keras config at `~/.keras/keras.json`.
     # Arguments
         include_top: whether to include the fully-connected
             layer at the top of the network.
-        weights: one of `None` (random initialization),
-              'imagenet' (pre-training on ImageNet),
-              or the path to the weights file to be loaded.
         input_tensor: optional Keras tensor (i.e. output of `tf.keras.Input()`)
             to use as image input for the model.
         input_shape: optional shape tuple, only to be specified
@@ -467,97 +497,118 @@ def InceptionResNetV2(
         else:
             img_input = input_tensor
 
-    # Stem block: 35 x 35 x 192
-    x = Conv2D_BN(32, 3, strides=2, padding="valid")(img_input)
-    x = Conv2D_BN(32, 3, padding="valid")(x)
-    x = Conv2D_BN(64, 3)(x)
-    x = MaxPool2D(3, strides=2)(x)
-    x = Conv2D_BN(80, 1, padding="valid")(x)
-    x = Conv2D_BN(192, 3, padding="valid")(x)
-    x = MaxPool2D(3, strides=2)(x)
-
-    # Mixed 5b (Inception-A block): 35 x 35 x 320
-    # branch_0 = Conv2D_BN(96, 1)(x)
-    # branch_1 = Conv2D_BN(48, 1)(x)
-    # branch_1 = Conv2D_BN(64, 5)(branch_1)
-    # branch_2 = Conv2D_BN(64, 1)(x)
-    # branch_2 = Conv2D_BN(96, 3)(branch_2)
-    # branch_2 = Conv2D_BN(96, 3)(branch_2)
-    # branch_pool = AvgPool2D(3, strides=1, padding="same")(x)
-    # branch_pool = Conv2D_BN(64, 1)(branch_pool)
-    # branches = [branch_0, branch_1, branch_2, branch_pool]
-    # channel_axis = 1 if K.image_data_format() == "channels_first" else 3
-    # x = Concatenate(axis=channel_axis, name="mixed_5b")(branches)
-    x = InceptionA(pool_filters=64)(x)
+    layers = [
+        # Stem block: 35 x 35 x 192
+        Conv2D_BN(32, 3, strides=2, padding="valid", name="Conv2d_1a_3x3"),
+        Conv2D_BN(32, 3, padding="valid", name="Conv2d_2a_3x3"),
+        Conv2D_BN(64, 3, name="Conv2d_2b_3x3"),
+        MaxPool2D(3, strides=2, name="MaxPool_3a_3x3"),
+        Conv2D_BN(80, 1, padding="valid", name="Conv2d_3b_1x1"),
+        Conv2D_BN(192, 3, padding="valid", name="Conv2d_4a_3x3"),
+        MaxPool2D(3, strides=2, name="MaxPool_5a_3x3"),
+        # Mixed 5b (Inception-A block): 35 x 35 x 320
+        InceptionA(pool_filters=64, name="Mixed_5b"),
+    ]
 
     # 10x block35 (Inception-ResNet-A block): 35 x 35 x 320
     for block_idx in range(1, 11):
-        x = InceptionResnetBlock(
-            n_channels=320,
-            scale=0.17,
-            block_type="block35",
-            block_idx=block_idx,
-            name=f"block35_{block_idx}",
-        )(x)
+        layers.append(
+            InceptionResnetBlock(
+                n_channels=320,
+                scale=0.17,
+                block_type="block35",
+                block_idx=block_idx,
+                name=f"block35_{block_idx}",
+            )
+        )
 
     # Mixed 6a (Reduction-A block): 17 x 17 x 1088
-    x = ReductionA(padding="valid", n=384, k=256, kl=256, km=384, use_atrous=False)(x)
+    layers.append(
+        ReductionA(
+            padding="valid",
+            n=384,
+            k=256,
+            kl=256,
+            km=384,
+            use_atrous=False,
+            name="Mixed_6a",
+        )
+    )
 
     # 20x block17 (Inception-ResNet-B block): 17 x 17 x 1088
     for block_idx in range(1, 21):
-        x = InceptionResnetBlock(
-            n_channels=1088,
-            scale=0.1,
-            block_type="block17",
-            block_idx=block_idx,
-            name=f"block17_{block_idx}",
-        )(x)
+        layers.append(
+            InceptionResnetBlock(
+                n_channels=1088,
+                scale=0.1,
+                block_type="block17",
+                block_idx=block_idx,
+                name=f"block17_{block_idx}",
+            )
+        )
 
     # Mixed 7a (Reduction-B block): 8 x 8 x 2080
-    x = ReductionB(
-        padding="valid", n=256, no=384, p=256, pq=288, k=256, kl=288, km=320
-    )(x)
+    layers.append(
+        ReductionB(
+            padding="valid",
+            n=256,
+            no=384,
+            p=256,
+            pq=288,
+            k=256,
+            kl=288,
+            km=320,
+            name="Mixed_7a",
+        )
+    )
 
     # 10x block8 (Inception-ResNet-C block): 8 x 8 x 2080
     for block_idx in range(1, 10):
-        x = InceptionResnetBlock(
+        layers.append(
+            InceptionResnetBlock(
+                n_channels=2080,
+                scale=0.2,
+                block_type="block8",
+                block_idx=block_idx,
+                name=f"block8_{block_idx}",
+            )
+        )
+    layers.append(
+        InceptionResnetBlock(
             n_channels=2080,
-            scale=0.2,
+            scale=1.0,
+            activation=None,
             block_type="block8",
-            block_idx=block_idx,
-            name=f"block8_{block_idx}",
-        )(x)
-    x = InceptionResnetBlock(
-        n_channels=2080,
-        scale=1.0,
-        activation=None,
-        block_type="block8",
-        block_idx=10,
-        name=f"block8_{block_idx+1}",
-    )(x)
+            block_idx=10,
+            name=f"block8_{block_idx+1}",
+        )
+    )
 
     # Final convolution block: 8 x 8 x 1536
-    x = Conv2D_BN(1536, 1, name="conv_7b")(x)
+    layers.append(Conv2D_BN(1536, 1, name="Conv2d_7b_1x1"))
 
-    if include_top:
-        # Classification block
-        x = GlobalAvgPool2D(name="avg_pool")(x)
-        x = Dense(classes, name="predictions")(x)
-    else:
-        if pooling == "avg":
-            x = GlobalAvgPool2D()(x)
-        elif pooling == "max":
-            x = GlobalMaxPool2D()(x)
+    if (include_top and pooling is None) or (bottleneck):
+        pooling = "avg"
 
-    # Ensure that the model takes into account
-    # any potential predecessors of `input_tensor`.
-    if input_tensor is not None:
-        inputs = tf.keras.utils.get_source_inputs(input_tensor)
-    else:
-        inputs = img_input
+    if pooling == "avg":
+        layers.append(GlobalAvgPool2D())
+    elif pooling == "max":
+        layers.append(GlobalMaxPool2D())
 
-    # Create model.
-    model = Model(inputs, x, name="inception_resnet_v2")
+    if bottleneck:
+        layers.append(Dropout(dropout_rate, name="Dropout"))
+        layers.append(Dense(128, use_bias=False, name="Bottleneck"))
+        layers.append(
+            BatchNormalization(axis=-1, scale=False, name="Bottleneck/BatchNorm")
+        )
+
+    # Classification block
+    if include_top:
+        layers.append(Dense(classes, name="logits"))
+
+    # Create model and call it on input to create its variables.
+    model = Sequential(layers, name=name, **kwargs)
+    model(img_input)
 
     return model
 
@@ -694,9 +745,8 @@ def MultiScaleInceptionResNetV2(
 
 if __name__ == "__main__":
     import pkg_resources
-    from tabulate import tabulate
-
     from bob.learn.tensorflow.utils import model_summary
+    from tabulate import tabulate
 
     def print_model(inputs, outputs, name=None):
         print("")
diff --git a/bob/learn/tensorflow/script/__init__.py b/bob/learn/tensorflow/scripts/__init__.py
similarity index 100%
rename from bob/learn/tensorflow/script/__init__.py
rename to bob/learn/tensorflow/scripts/__init__.py
diff --git a/bob/learn/tensorflow/script/datasets_to_tfrecords.py b/bob/learn/tensorflow/scripts/datasets_to_tfrecords.py
similarity index 88%
rename from bob/learn/tensorflow/script/datasets_to_tfrecords.py
rename to bob/learn/tensorflow/scripts/datasets_to_tfrecords.py
index a252a5f6504756c1c2d489dd51d224de116dbf02..1b4917bc14c26e51d5bd33cedb086326bc07d150 100644
--- a/bob/learn/tensorflow/script/datasets_to_tfrecords.py
+++ b/bob/learn/tensorflow/scripts/datasets_to_tfrecords.py
@@ -1,4 +1,3 @@
-#!/usr/bin/env python
 """Convert datasets to TFRecords
 """
 from __future__ import absolute_import
@@ -6,17 +5,13 @@ from __future__ import division
 from __future__ import print_function
 
 import logging
-import os
 
 import click
-import tensorflow as tf
 
 from bob.extension.scripts.click_helper import ConfigCommand
 from bob.extension.scripts.click_helper import ResourceOption
-from bob.extension.scripts.click_helper import log_parameters
 from bob.extension.scripts.click_helper import verbosity_option
-from bob.learn.tensorflow.dataset.tfrecords import dataset_to_tfrecord
-from bob.learn.tensorflow.dataset.tfrecords import tfrecord_name_and_json_name
+
 
 logger = logging.getLogger(__name__)
 
@@ -50,6 +45,11 @@ def datasets_to_tfrecords(dataset, output, force, **kwargs):
     To use this script with SGE, change your dataset and output based on the SGE_TASK_ID
     environment variable in your config file.
     """
+    from bob.extension.scripts.click_helper import log_parameters
+    import os
+    from bob.learn.tensorflow.data.tfrecords import dataset_to_tfrecord
+    from bob.learn.tensorflow.data.tfrecords import tfrecord_name_and_json_name
+
     log_parameters(logger)
 
     output, json_output = tfrecord_name_and_json_name(output)
diff --git a/bob/learn/tensorflow/script/fit.py b/bob/learn/tensorflow/scripts/fit.py
similarity index 93%
rename from bob/learn/tensorflow/script/fit.py
rename to bob/learn/tensorflow/scripts/fit.py
index a655bfcae9cfc42fc115680fd8bd7b516ed0e6c0..bb3085924b3d9c7f5ff768c06f22d03f6d3b505a 100644
--- a/bob/learn/tensorflow/script/fit.py
+++ b/bob/learn/tensorflow/scripts/fit.py
@@ -22,7 +22,7 @@ logger = logging.getLogger(__name__)
 
 @click.command(entry_point_group="bob.learn.tensorflow.config", cls=ConfigCommand)
 @click.option(
-    "--model",
+    "--model-fn",
     "-m",
     required=True,
     cls=ResourceOption,
@@ -42,7 +42,7 @@ logger = logging.getLogger(__name__)
 @click.option(
     "--epochs",
     "-e",
-    default=1,
+    default=10,
     type=click.types.INT,
     cls=ResourceOption,
     help="Number of epochs to train model. See " "tf.keras.Model.fit.",
@@ -66,13 +66,6 @@ logger = logging.getLogger(__name__)
 @click.option(
     "--class-weight", "-c", cls=ResourceOption, help="See tf.keras.Model.fit."
 )
-@click.option(
-    "--initial-epoch",
-    default=0,
-    type=click.types.INT,
-    cls=ResourceOption,
-    help="See tf.keras.Model.fit.",
-)
 @click.option(
     "--steps-per-epoch",
     type=click.types.INT,
@@ -87,14 +80,13 @@ logger = logging.getLogger(__name__)
 )
 @verbosity_option(cls=ResourceOption)
 def fit(
-    model,
+    model_fn,
     train_input_fn,
     epochs,
     verbose,
     callbacks,
     eval_input_fn,
     class_weight,
-    initial_epoch,
     steps_per_epoch,
     validation_steps,
     **kwargs
@@ -110,6 +102,8 @@ def fit(
     if save_callback:
         model_dir = save_callback[0].filepath
         logger.info("Training a model in %s", model_dir)
+    model = model_fn()
+
     history = model.fit(
         x=train_input_fn(),
         epochs=epochs,
@@ -117,7 +111,6 @@ def fit(
         callbacks=list(callbacks) if callbacks else None,
         validation_data=None if eval_input_fn is None else eval_input_fn(),
         class_weight=class_weight,
-        initial_epoch=initial_epoch,
         steps_per_epoch=steps_per_epoch,
         validation_steps=validation_steps,
     )
diff --git a/bob/learn/tensorflow/script/keras.py b/bob/learn/tensorflow/scripts/keras.py
similarity index 100%
rename from bob/learn/tensorflow/script/keras.py
rename to bob/learn/tensorflow/scripts/keras.py
diff --git a/bob/learn/tensorflow/script/tf.py b/bob/learn/tensorflow/scripts/tf.py
similarity index 100%
rename from bob/learn/tensorflow/script/tf.py
rename to bob/learn/tensorflow/scripts/tf.py
diff --git a/bob/learn/tensorflow/test/data/db_to_tfrecords_config.py b/bob/learn/tensorflow/test/data/db_to_tfrecords_config.py
deleted file mode 100644
index 27c48dc292262f4f57bf2d4718d638ff89702dd6..0000000000000000000000000000000000000000
--- a/bob/learn/tensorflow/test/data/db_to_tfrecords_config.py
+++ /dev/null
@@ -1,26 +0,0 @@
-from bob.learn.tensorflow.dataset.generator import dataset_using_generator
-
-groups = ["dev"]
-
-samples = database.all_files(groups=groups)
-
-CLIENT_IDS = (str(f.client_id) for f in database.all_files(groups=groups))
-CLIENT_IDS = list(set(CLIENT_IDS))
-CLIENT_IDS = dict(zip(CLIENT_IDS, range(len(CLIENT_IDS))))
-
-
-def file_to_label(f):
-    return CLIENT_IDS[str(f.client_id)]
-
-
-def reader(biofile):
-    data = read_original_data(
-        biofile, database.original_directory, database.original_extension
-    )
-    label = file_to_label(biofile)
-    key = str(biofile.path).encode("utf-8")
-    return (data, label, key)
-
-
-dataset = dataset_using_generator(samples, reader)
-datasets = [dataset]
diff --git a/bob/learn/tensorflow/test/data/dummy_image_database/m301_01_p01_i0_0.png b/bob/learn/tensorflow/test/data/dummy_image_database/m301_01_p01_i0_0.png
deleted file mode 100644
index 52d39487637b8a7ba460c93ecc9e1bb92e5ca42f..0000000000000000000000000000000000000000
Binary files a/bob/learn/tensorflow/test/data/dummy_image_database/m301_01_p01_i0_0.png and /dev/null differ
diff --git a/bob/learn/tensorflow/test/data/dummy_image_database/m301_01_p01_i0_0_GRAY.png b/bob/learn/tensorflow/test/data/dummy_image_database/m301_01_p01_i0_0_GRAY.png
deleted file mode 100644
index e7de9b7d4b792351e2724ada32bd88d9dc5d3ff0..0000000000000000000000000000000000000000
Binary files a/bob/learn/tensorflow/test/data/dummy_image_database/m301_01_p01_i0_0_GRAY.png and /dev/null differ
diff --git a/bob/learn/tensorflow/test/data/dummy_image_database/m301_01_p02_i0_0.png b/bob/learn/tensorflow/test/data/dummy_image_database/m301_01_p02_i0_0.png
deleted file mode 100644
index 0c7e298de460379d02de275c38ebc24840a258fa..0000000000000000000000000000000000000000
Binary files a/bob/learn/tensorflow/test/data/dummy_image_database/m301_01_p02_i0_0.png and /dev/null differ
diff --git a/bob/learn/tensorflow/test/data/dummy_image_database/m304_01_p01_i0_0.png b/bob/learn/tensorflow/test/data/dummy_image_database/m304_01_p01_i0_0.png
deleted file mode 100644
index 53c25af50711c607d2d05cb9566acfe2b140977d..0000000000000000000000000000000000000000
Binary files a/bob/learn/tensorflow/test/data/dummy_image_database/m304_01_p01_i0_0.png and /dev/null differ
diff --git a/bob/learn/tensorflow/test/data/dummy_image_database/m304_02_f12_i0_0.png b/bob/learn/tensorflow/test/data/dummy_image_database/m304_02_f12_i0_0.png
deleted file mode 100644
index 0fdf6b4d8fa118657bddd7b2b219d96085180d74..0000000000000000000000000000000000000000
Binary files a/bob/learn/tensorflow/test/data/dummy_image_database/m304_02_f12_i0_0.png and /dev/null differ
diff --git a/bob/learn/tensorflow/test/data/input_tfrecords_config.py b/bob/learn/tensorflow/test/data/input_tfrecords_config.py
deleted file mode 100644
index 86f856c231afac34a7b6204e9cad1e985acad071..0000000000000000000000000000000000000000
--- a/bob/learn/tensorflow/test/data/input_tfrecords_config.py
+++ /dev/null
@@ -1,27 +0,0 @@
-import tensorflow as tf
-
-from bob.learn.tensorflow.dataset.tfrecords import batch_data_and_labels
-from bob.learn.tensorflow.dataset.tfrecords import shuffle_data_and_labels
-
-tfrecord_filenames = ["%(tfrecord_filenames)s"]
-data_shape = (1, 112, 92)  # size of atnt images
-data_type = tf.uint8
-batch_size = 2
-epochs = 2
-
-
-def train_input_fn():
-    return shuffle_data_and_labels(
-        tfrecord_filenames, data_shape, data_type, batch_size, epochs=epochs
-    )
-
-
-def eval_input_fn():
-    return batch_data_and_labels(
-        tfrecord_filenames, data_shape, data_type, batch_size, epochs=1
-    )
-
-
-# config for train_and_evaluate
-train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=200)
-eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn)
diff --git a/bob/learn/tensorflow/test/data/mnist_estimator.py b/bob/learn/tensorflow/test/data/mnist_estimator.py
deleted file mode 100644
index 378b52628da8c6f1a1bbbdf29ff8eaacd508216b..0000000000000000000000000000000000000000
--- a/bob/learn/tensorflow/test/data/mnist_estimator.py
+++ /dev/null
@@ -1,6 +0,0 @@
-import tensorflow as tf
-
-data = tf.feature_column.numeric_column("data", shape=[784])
-estimator = tf.estimator.LinearClassifier(
-    feature_columns=[data], n_classes=10, loss_reduction=tf.keras.losses.Reduction.SUM
-)
diff --git a/bob/learn/tensorflow/test/data/mnist_input_fn.py b/bob/learn/tensorflow/test/data/mnist_input_fn.py
deleted file mode 100644
index cdc3e4333085d653c98524c03e8c33f94e448e5f..0000000000000000000000000000000000000000
--- a/bob/learn/tensorflow/test/data/mnist_input_fn.py
+++ /dev/null
@@ -1,28 +0,0 @@
-import tensorflow as tf
-
-from bob.db.mnist import Database
-
-database = Database()
-
-
-def input_fn(mode):
-    if mode == tf.estimator.ModeKeys.TRAIN:
-        groups = "train"
-        num_epochs = None
-        shuffle = True
-    else:
-        groups = "test"
-        num_epochs = 1
-        shuffle = True
-    data, labels = database.data(groups=groups)
-    return tf.compat.v1.estimator.inputs.numpy_input_fn(
-        x={"data": data.astype("float32"), "key": labels.astype("float32")},
-        y=labels.astype("int32"),
-        batch_size=128,
-        num_epochs=num_epochs,
-        shuffle=shuffle,
-    )
-
-
-train_input_fn = input_fn(tf.estimator.ModeKeys.TRAIN)
-eval_input_fn = input_fn(tf.estimator.ModeKeys.EVAL)
diff --git a/bob/learn/tensorflow/test/test_dataset.py b/bob/learn/tensorflow/test/test_dataset.py
deleted file mode 100644
index f70ee9fbfa4fcb2e82424d86d463f5f79302bfa3..0000000000000000000000000000000000000000
--- a/bob/learn/tensorflow/test/test_dataset.py
+++ /dev/null
@@ -1,78 +0,0 @@
-#!/usr/bin/env python
-# vim: set fileencoding=utf-8 :
-# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
-
-import numpy
-import pkg_resources
-import tensorflow as tf
-
-from bob.learn.tensorflow.dataset.generator import dataset_using_generator
-
-data_shape = (250, 250, 3)
-output_shape = (50, 50)
-data_type = tf.float32
-batch_size = 2
-validation_batch_size = 250
-epochs = 1
-
-# Trainer logits
-filenames = [
-    pkg_resources.resource_filename(
-        __name__, "data/dummy_image_database/m301_01_p01_i0_0.png"
-    ),
-    pkg_resources.resource_filename(
-        __name__, "data/dummy_image_database/m301_01_p02_i0_0.png"
-    ),
-    pkg_resources.resource_filename(
-        __name__, "data/dummy_image_database/m301_01_p01_i0_0.png"
-    ),
-    pkg_resources.resource_filename(
-        __name__, "data/dummy_image_database/m301_01_p02_i0_0.png"
-    ),
-    pkg_resources.resource_filename(
-        __name__, "data/dummy_image_database/m301_01_p01_i0_0.png"
-    ),
-    pkg_resources.resource_filename(
-        __name__, "data/dummy_image_database/m301_01_p02_i0_0.png"
-    ),
-    pkg_resources.resource_filename(
-        __name__, "data/dummy_image_database/m304_01_p01_i0_0.png"
-    ),
-    pkg_resources.resource_filename(
-        __name__, "data/dummy_image_database/m304_02_f12_i0_0.png"
-    ),
-    pkg_resources.resource_filename(
-        __name__, "data/dummy_image_database/m304_01_p01_i0_0.png"
-    ),
-    pkg_resources.resource_filename(
-        __name__, "data/dummy_image_database/m304_02_f12_i0_0.png"
-    ),
-    pkg_resources.resource_filename(
-        __name__, "data/dummy_image_database/m304_01_p01_i0_0.png"
-    ),
-    pkg_resources.resource_filename(
-        __name__, "data/dummy_image_database/m304_02_f12_i0_0.png"
-    ),
-]
-labels = [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
-
-
-def test_dataset_using_generator():
-    def reader(f):
-        key = 0
-        label = 0
-        yield {"data": f, "key": key}, label
-
-    shape = (2, 2, 1)
-    samples = [numpy.ones(shape, dtype="float32") * i for i in range(10)]
-
-    with tf.compat.v1.Session() as session:
-        dataset = dataset_using_generator(samples, reader, multiple_samples=True)
-        iterator = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()
-        for i in range(11):
-            try:
-                sample = session.run(iterator)
-                assert sample[0]["data"].shape == shape
-                assert numpy.allclose(sample[0]["data"], samples[i])
-            except tf.errors.OutOfRangeError:
-                break
diff --git a/bob/learn/tensorflow/test/test_datasets_to_tfrecords.py b/bob/learn/tensorflow/test/test_datasets_to_tfrecords.py
deleted file mode 100644
index 751856219cf9dec968416f4c95c590742c4d835b..0000000000000000000000000000000000000000
--- a/bob/learn/tensorflow/test/test_datasets_to_tfrecords.py
+++ /dev/null
@@ -1,63 +0,0 @@
-import os
-import shutil
-import tempfile
-
-import numpy as np
-import pkg_resources
-import tensorflow as tf
-from click.testing import CliRunner
-
-from bob.extension.config import load
-from bob.extension.scripts.click_helper import assert_click_runner_result
-from bob.io.base import create_directories_safe
-from bob.learn.tensorflow.dataset.tfrecords import dataset_from_tfrecord
-from bob.learn.tensorflow.script.db_to_tfrecords import datasets_to_tfrecords
-from bob.learn.tensorflow.utils import create_mnist_tfrecord
-from bob.learn.tensorflow.utils import load_mnist
-
-regenerate_reference = False
-
-dummy_config = pkg_resources.resource_filename(
-    "bob.learn.tensorflow", "test/data/db_to_tfrecords_config.py"
-)
-
-
-def compare_datasets(ds1, ds2, sess=None):
-    if tf.executing_eagerly():
-        for values1, values2 in zip(ds1, ds2):
-            values1 = tf.nest.flatten(values1)
-            values2 = tf.nest.flatten(values2)
-            for v1, v2 in zip(values1, values2):
-                if not tf.reduce_all(input_tensor=tf.math.equal(v1, v2)):
-                    return False
-    else:
-        ds1 = tf.compat.v1.data.make_one_shot_iterator(ds1).get_next()
-        ds2 = tf.compat.v1.data.make_one_shot_iterator(ds2).get_next()
-        while True:
-            try:
-                values1, values2 = sess.run([ds1, ds2])
-            except tf.errors.OutOfRangeError:
-                break
-            values1 = tf.nest.flatten(values1)
-            values2 = tf.nest.flatten(values2)
-            for v1, v2 in zip(values1, values2):
-                v1, v2 = np.asarray(v1), np.asarray(v2)
-                if not np.all(v1 == v2):
-                    return False
-    return True
-
-
-def test_datasets_to_tfrecords():
-    runner = CliRunner()
-    with runner.isolated_filesystem():
-        output_path = "./test"
-        args = (dummy_config, "--output", output_path)
-        result = runner.invoke(datasets_to_tfrecords, args=args, standalone_mode=False)
-        assert_click_runner_result(result)
-        # read back the tfrecod
-        with tf.compat.v1.Session() as sess:
-            dataset2 = dataset_from_tfrecord(output_path)
-            dataset1 = load(
-                [dummy_config], attribute_name="dataset", entry_point_group="bob"
-            )
-            assert compare_datasets(dataset1, dataset2, sess)
diff --git a/bob/learn/tensorflow/test/test_loss.py b/bob/learn/tensorflow/test/test_loss.py
deleted file mode 100644
index 98612ac8d3ba3b0f8c474a9eee71bf1869c93b41..0000000000000000000000000000000000000000
--- a/bob/learn/tensorflow/test/test_loss.py
+++ /dev/null
@@ -1,177 +0,0 @@
-#!/usr/bin/env python
-# vim: set fileencoding=utf-8 :
-# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
-
-import numpy
-import tensorflow as tf
-
-from bob.learn.tensorflow.loss import balanced_sigmoid_cross_entropy_loss_weights
-from bob.learn.tensorflow.loss import balanced_softmax_cross_entropy_loss_weights
-
-
-def test_balanced_softmax_cross_entropy_loss_weights():
-    labels = numpy.array(
-        [
-            [1, 0, 0],
-            [1, 0, 0],
-            [0, 0, 1],
-            [0, 1, 0],
-            [0, 0, 1],
-            [1, 0, 0],
-            [1, 0, 0],
-            [0, 0, 1],
-            [1, 0, 0],
-            [1, 0, 0],
-            [1, 0, 0],
-            [1, 0, 0],
-            [1, 0, 0],
-            [1, 0, 0],
-            [0, 1, 0],
-            [1, 0, 0],
-            [0, 1, 0],
-            [1, 0, 0],
-            [0, 0, 1],
-            [0, 0, 1],
-            [1, 0, 0],
-            [0, 0, 1],
-            [1, 0, 0],
-            [1, 0, 0],
-            [0, 1, 0],
-            [1, 0, 0],
-            [1, 0, 0],
-            [1, 0, 0],
-            [0, 1, 0],
-            [1, 0, 0],
-            [0, 0, 1],
-            [1, 0, 0],
-        ],
-        dtype="int32",
-    )
-
-    with tf.compat.v1.Session() as session:
-        weights = session.run(balanced_softmax_cross_entropy_loss_weights(labels))
-
-    expected_weights = numpy.array(
-        [
-            0.53333336,
-            0.53333336,
-            1.5238096,
-            2.1333334,
-            1.5238096,
-            0.53333336,
-            0.53333336,
-            1.5238096,
-            0.53333336,
-            0.53333336,
-            0.53333336,
-            0.53333336,
-            0.53333336,
-            0.53333336,
-            2.1333334,
-            0.53333336,
-            2.1333334,
-            0.53333336,
-            1.5238096,
-            1.5238096,
-            0.53333336,
-            1.5238096,
-            0.53333336,
-            0.53333336,
-            2.1333334,
-            0.53333336,
-            0.53333336,
-            0.53333336,
-            2.1333334,
-            0.53333336,
-            1.5238096,
-            0.53333336,
-        ],
-        dtype="float32",
-    )
-
-    assert numpy.allclose(weights, expected_weights)
-
-
-def test_balanced_sigmoid_cross_entropy_loss_weights():
-    labels = numpy.array(
-        [
-            1,
-            1,
-            0,
-            0,
-            0,
-            1,
-            1,
-            0,
-            1,
-            1,
-            1,
-            1,
-            1,
-            1,
-            0,
-            1,
-            0,
-            1,
-            0,
-            0,
-            1,
-            0,
-            1,
-            1,
-            0,
-            1,
-            1,
-            1,
-            0,
-            1,
-            0,
-            1,
-        ],
-        dtype="int32",
-    )
-
-    with tf.compat.v1.Session() as session:
-        weights = session.run(
-            balanced_sigmoid_cross_entropy_loss_weights(labels, dtype="float32")
-        )
-
-    expected_weights = numpy.array(
-        [
-            0.8,
-            0.8,
-            1.3333334,
-            1.3333334,
-            1.3333334,
-            0.8,
-            0.8,
-            1.3333334,
-            0.8,
-            0.8,
-            0.8,
-            0.8,
-            0.8,
-            0.8,
-            1.3333334,
-            0.8,
-            1.3333334,
-            0.8,
-            1.3333334,
-            1.3333334,
-            0.8,
-            1.3333334,
-            0.8,
-            0.8,
-            1.3333334,
-            0.8,
-            0.8,
-            0.8,
-            1.3333334,
-            0.8,
-            1.3333334,
-            0.8,
-        ],
-        dtype="float32",
-    )
-
-    assert numpy.allclose(weights, expected_weights)
diff --git a/bob/learn/tensorflow/test/test_utils.py b/bob/learn/tensorflow/test/test_utils.py
deleted file mode 100644
index 49ba3ca3c6d7402ab4b7c38eac560767504ef0b0..0000000000000000000000000000000000000000
--- a/bob/learn/tensorflow/test/test_utils.py
+++ /dev/null
@@ -1,61 +0,0 @@
-#!/usr/bin/env python
-# vim: set fileencoding=utf-8 :
-# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
-
-import numpy
-import tensorflow as tf
-
-from bob.learn.tensorflow.utils import compute_embedding_accuracy
-from bob.learn.tensorflow.utils import compute_embedding_accuracy_tensors
-
-
-"""
-Some unit tests for the datashuffler
-"""
-
-
-def test_embedding_accuracy():
-
-    numpy.random.seed(10)
-    samples_per_class = 5
-
-    class_a = numpy.random.normal(loc=0, scale=0.1, size=(samples_per_class, 2))
-    labels_a = numpy.zeros(samples_per_class)
-
-    class_b = numpy.random.normal(loc=10, scale=0.1, size=(samples_per_class, 2))
-    labels_b = numpy.ones(samples_per_class)
-
-    data = numpy.vstack((class_a, class_b))
-    labels = numpy.concatenate((labels_a, labels_b))
-
-    assert compute_embedding_accuracy(data, labels) == 1.0
-
-    # Adding noise
-    noise = numpy.random.normal(loc=0, scale=0.1, size=(samples_per_class, 2))
-    noise_labels = numpy.ones(samples_per_class)
-
-    data = numpy.vstack((data, noise))
-    labels = numpy.concatenate((labels, noise_labels))
-
-    assert compute_embedding_accuracy(data, labels) == 10 / 15.0
-
-
-def test_embedding_accuracy_tensors():
-
-    numpy.random.seed(10)
-    samples_per_class = 5
-
-    class_a = numpy.random.normal(loc=0, scale=0.1, size=(samples_per_class, 2))
-    labels_a = numpy.zeros(samples_per_class)
-
-    class_b = numpy.random.normal(loc=10, scale=0.1, size=(samples_per_class, 2))
-    labels_b = numpy.ones(samples_per_class)
-
-    data = numpy.vstack((class_a, class_b))
-    labels = numpy.concatenate((labels_a, labels_b))
-
-    data = tf.convert_to_tensor(value=data.astype("float32"))
-    labels = tf.convert_to_tensor(value=labels.astype("int64"))
-
-    accuracy = compute_embedding_accuracy_tensors(data, labels)
-    assert accuracy == 1.0
diff --git a/bob/learn/tensorflow/test/__init__.py b/bob/learn/tensorflow/tests/__init__.py
similarity index 100%
rename from bob/learn/tensorflow/test/__init__.py
rename to bob/learn/tensorflow/tests/__init__.py
diff --git a/bob/learn/tensorflow/tests/data/db_to_tfrecords_config.py b/bob/learn/tensorflow/tests/data/db_to_tfrecords_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..52799ddd8bf9cd2c1cbf2f52bab0419c90678d62
--- /dev/null
+++ b/bob/learn/tensorflow/tests/data/db_to_tfrecords_config.py
@@ -0,0 +1,17 @@
+import tensorflow as tf
+from bob.learn.tensorflow.data import dataset_using_generator
+
+mnist = tf.keras.datasets.mnist
+
+(x_train, y_train), (_, _) = mnist.load_data()
+samples = (tf.keras.backend.arange(len(x_train)), x_train, y_train)
+
+
+def reader(sample):
+    data = sample[1]
+    label = sample[2]
+    key = str(sample[0]).encode("utf-8")
+    return ({"data": data, "key": key}, label)
+
+
+dataset = dataset_using_generator(samples, reader)
diff --git a/bob/learn/tensorflow/tests/test_dataset.py b/bob/learn/tensorflow/tests/test_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b5d46fa91ac00ad7ac064da962bb729f106eb94
--- /dev/null
+++ b/bob/learn/tensorflow/tests/test_dataset.py
@@ -0,0 +1,18 @@
+import numpy as np
+
+from bob.learn.tensorflow.data import dataset_using_generator
+
+
+def test_dataset_using_generator():
+    def reader(f):
+        key = 0
+        label = 0
+        yield {"data": f, "key": key}, label
+
+    shape = (2, 2, 1)
+    samples = [np.ones(shape, dtype="float32") * i for i in range(10)]
+
+    dataset = dataset_using_generator(samples, reader, multiple_samples=True)
+    for i, sample in enumerate(dataset):
+        assert sample[0]["data"].shape == shape
+        assert np.allclose(sample[0]["data"], samples[i])
diff --git a/bob/learn/tensorflow/tests/test_datasets_to_tfrecords.py b/bob/learn/tensorflow/tests/test_datasets_to_tfrecords.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b6548abab4d14e5b308cb73a96fe7ddacc89d25
--- /dev/null
+++ b/bob/learn/tensorflow/tests/test_datasets_to_tfrecords.py
@@ -0,0 +1,38 @@
+import pkg_resources
+import tensorflow as tf
+from bob.extension.config import load
+from bob.extension.scripts.click_helper import assert_click_runner_result
+from bob.learn.tensorflow.data.tfrecords import dataset_from_tfrecord
+from bob.learn.tensorflow.scripts.datasets_to_tfrecords import datasets_to_tfrecords
+from click.testing import CliRunner
+
+regenerate_reference = False
+
+dummy_config = pkg_resources.resource_filename(
+    "bob.learn.tensorflow", "tests/data/db_to_tfrecords_config.py"
+)
+
+
+def compare_datasets(ds1, ds2):
+    for values1, values2 in zip(ds1, ds2):
+        values1 = tf.nest.flatten(values1)
+        values2 = tf.nest.flatten(values2)
+        for v1, v2 in zip(values1, values2):
+            if not tf.reduce_all(input_tensor=tf.math.equal(v1, v2)):
+                return False
+    return True
+
+
+def test_datasets_to_tfrecords():
+    runner = CliRunner()
+    with runner.isolated_filesystem():
+        output_path = "./test"
+        args = (dummy_config, "--output", output_path)
+        result = runner.invoke(datasets_to_tfrecords, args=args, standalone_mode=False)
+        assert_click_runner_result(result)
+        # read back the tfrecod
+        dataset2 = dataset_from_tfrecord(output_path)
+        dataset1 = load(
+            [dummy_config], attribute_name="dataset", entry_point_group="bob"
+        )
+        assert compare_datasets(dataset1, dataset2)
diff --git a/bob/learn/tensorflow/tests/test_utils.py b/bob/learn/tensorflow/tests/test_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6ae3aee142ff30fd3bb4ccde5d9aadce8542c76
--- /dev/null
+++ b/bob/learn/tensorflow/tests/test_utils.py
@@ -0,0 +1,26 @@
+import numpy
+import tensorflow as tf
+
+from bob.learn.tensorflow.metrics import EmbeddingAccuracy
+
+
+def test_embedding_accuracy_tensors():
+
+    numpy.random.seed(10)
+    samples_per_class = 5
+    m = EmbeddingAccuracy()
+
+    class_a = numpy.random.normal(loc=0, scale=0.1, size=(samples_per_class, 2))
+    labels_a = numpy.zeros(samples_per_class)
+
+    class_b = numpy.random.normal(loc=10, scale=0.1, size=(samples_per_class, 2))
+    labels_b = numpy.ones(samples_per_class)
+
+    data = numpy.vstack((class_a, class_b))
+    labels = numpy.concatenate((labels_a, labels_b))
+
+    data = tf.convert_to_tensor(value=data.astype("float32"))
+    labels = tf.convert_to_tensor(value=labels.astype("int64"))
+    m(labels, data)
+
+    assert m.result() == 1.0
diff --git a/bob/learn/tensorflow/utils/__init__.py b/bob/learn/tensorflow/utils/__init__.py
index ebfdb2b1296b54a05d35d826d98f0d14aae606ad..444a481681f2f06bb9bb06377db7c492e3f201fc 100644
--- a/bob/learn/tensorflow/utils/__init__.py
+++ b/bob/learn/tensorflow/utils/__init__.py
@@ -1,9 +1,3 @@
-from .eval import *
-from .graph import *
 from .keras import *
 from .math import *
-from .reproducible import *
-from .session import Session
-from .singleton import Singleton
-from .train import *
-from .util import *
+from .image import *
diff --git a/bob/learn/tensorflow/utils/eval.py b/bob/learn/tensorflow/utils/eval.py
deleted file mode 100644
index d647b0a7ad880ed5e665bfe87da85ee2463aa063..0000000000000000000000000000000000000000
--- a/bob/learn/tensorflow/utils/eval.py
+++ /dev/null
@@ -1,24 +0,0 @@
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import tensorflow as tf
-
-
-def get_global_step(path):
-    """Returns the global number associated with the model checkpoint path. The
-    checkpoint must have been saved with the
-    :any:`tf.train.MonitoredTrainingSession`.
-
-    Parameters
-    ----------
-    path : str
-        The path to model checkpoint, usually ckpt.model_checkpoint_path
-
-    Returns
-    -------
-    global_step : int
-        The global step number.
-    """
-    checkpoint_reader = tf.compat.v1.train.NewCheckpointReader(path)
-    return checkpoint_reader.get_tensor(tf.compat.v1.GraphKeys.GLOBAL_STEP)
diff --git a/bob/learn/tensorflow/utils/graph.py b/bob/learn/tensorflow/utils/graph.py
deleted file mode 100644
index 0a5fd5904180659bd92bf6e71ca5a7ddefd213f9..0000000000000000000000000000000000000000
--- a/bob/learn/tensorflow/utils/graph.py
+++ /dev/null
@@ -1,42 +0,0 @@
-import tensorflow as tf
-
-
-def call_on_frozen_graph(
-    graph_def_path, input, return_elements, input_name, name=None, **kwargs
-):
-    """Loads a frozen graph def file (.pb) and replaces its input with the given input
-    and return the requested output tensors.
-
-    Parameters
-    ----------
-    graph_def_path : str
-        Path to the graph definition file
-    input : object
-        Input tensor
-    return_elements : [str]
-        A list of strings which corresponds to operations in the graph.
-    input_name : str, optional
-        The name of input in the graph that will be replaced by input.
-    name : str, optional
-        The scope of the imported operations. Defaults to "import".
-    **kwargs
-        Extra arguments to be passed to tf.import_graph_def
-
-    Returns
-    -------
-    list
-        List of requested operations. Normally you would use
-        ``returned_operations[0].outputs[0]``
-    """
-    with tf.io.gfile.GFile(graph_def_path, "rb") as f:
-        graph_def = tf.compat.v1.GraphDef()
-        graph_def.ParseFromString(f.read())
-    input_map = {input_name: input}
-
-    return tf.import_graph_def(
-        graph_def,
-        input_map=input_map,
-        return_elements=return_elements,
-        name=name,
-        **kwargs
-    )
diff --git a/bob/learn/tensorflow/utils/image.py b/bob/learn/tensorflow/utils/image.py
new file mode 100644
index 0000000000000000000000000000000000000000..b05b2509e52319d09b3aa1529377fed58c8ac6d8
--- /dev/null
+++ b/bob/learn/tensorflow/utils/image.py
@@ -0,0 +1,63 @@
+import tensorflow as tf
+
+
+def to_channels_last(image):
+    """Converts the image to channel_last format. This is the same format as in
+    matplotlib, skimage, and etc.
+
+    Parameters
+    ----------
+    image : `tf.Tensor`
+        At least a 3 dimensional image. If the dimension is more than 3, the
+        last 3 dimensions are assumed to be [C, H, W].
+
+    Returns
+    -------
+    image : `tf.Tensor`
+        The image in [..., H, W, C] format.
+
+    Raises
+    ------
+    ValueError
+        If dim of image is less than 3.
+    """
+    ndim = image.ndim
+    if ndim < 3:
+        raise ValueError(
+            "The image needs to be at least 3 dimensional but it " "was {}".format(ndim)
+        )
+    axis_order = [1, 2, 0]
+    shift = ndim - 3
+    axis_order = list(range(ndim - 3)) + [n + shift for n in axis_order]
+    return tf.transpose(a=image, perm=axis_order)
+
+
+def to_channels_first(image):
+    """Converts the image to channel_first format. This is the same format as
+    in bob.io.image and bob.io.video.
+
+    Parameters
+    ----------
+    image : `tf.Tensor`
+        At least a 3 dimensional image. If the dimension is more than 3, the
+        last 3 dimensions are assumed to be [H, W, C].
+
+    Returns
+    -------
+    image : `tf.Tensor`
+        The image in [..., C, H, W] format.
+
+    Raises
+    ------
+    ValueError
+        If dim of image is less than 3.
+    """
+    ndim = image.ndim
+    if ndim < 3:
+        raise ValueError(
+            "The image needs to be at least 3 dimensional but it " "was {}".format(ndim)
+        )
+    axis_order = [2, 0, 1]
+    shift = ndim - 3
+    axis_order = list(range(ndim - 3)) + [n + shift for n in axis_order]
+    return tf.transpose(a=image, perm=axis_order)
diff --git a/bob/learn/tensorflow/utils/keras.py b/bob/learn/tensorflow/utils/keras.py
index 3a278fd6ded75ed8a7a05a5c47f98c594744ca69..6fceb5a4192266b663fe2e1957040c2571bf45c4 100644
--- a/bob/learn/tensorflow/utils/keras.py
+++ b/bob/learn/tensorflow/utils/keras.py
@@ -1,36 +1,81 @@
+import copy
 import logging
 
 import tensorflow as tf
 import tensorflow.keras.backend as K
+from tensorflow.python.keras import layers as layer_module
+from tensorflow.python.keras.utils import generic_utils
+from tensorflow.python.util import nest
+from tensorflow.python.util import tf_inspect
 
 logger = logging.getLogger(__name__)
 
+SINGLE_LAYER_OUTPUT_ERROR_MSG = (
+    "All layers in a Sequential model should have "
+    "a single output tensor. For multi-output "
+    "layers, use the functional API."
+)
 
-def is_trainable(name, trainable_variables, mode=tf.estimator.ModeKeys.TRAIN):
-    """
-    Check if a variable is trainable or not
+
+class SequentialLayer(tf.keras.layers.Layer):
+    """A Layer that does the same thing as tf.keras.Sequential but
+    its variables can be scoped.
 
     Parameters
     ----------
-
-    name: str
-       Layer name
-
-    trainable_variables: list
-       List containing the variables or scopes to be trained.
-       If None, the variable/scope is trained
+    layers : list
+        List of layers. All layers must be provided at initialization time
     """
 
-    # if mode is not training, so we shutdown
-    if mode != tf.estimator.ModeKeys.TRAIN:
-        return False
-
-    # If None, we train by default
-    if trainable_variables is None:
-        return True
-
-    # Here is my choice to shutdown the whole scope
-    return name in trainable_variables
+    def __init__(self, layers, **kwargs):
+        super().__init__(**kwargs)
+        self.sequential_layers = list(layers)
+
+    def call(self, inputs, training=None, mask=None):
+        outputs = inputs
+        for layer in self.sequential_layers:
+            # During each iteration, `inputs` are the inputs to `layer`, and `outputs`
+            # are the outputs of `layer` applied to `inputs`. At the end of each
+            # iteration `inputs` is set to `outputs` to prepare for the next layer.
+            kwargs = {}
+            argspec = tf_inspect.getfullargspec(layer.call).args
+            if "mask" in argspec:
+                kwargs["mask"] = mask
+            if "training" in argspec:
+                kwargs["training"] = training
+
+            outputs = layer(outputs, **kwargs)
+
+            if len(nest.flatten(outputs)) != 1:
+                raise ValueError(SINGLE_LAYER_OUTPUT_ERROR_MSG)
+
+            mask = getattr(outputs, "_keras_mask", None)
+
+        return outputs
+
+    def get_config(self):
+        layer_configs = []
+        for layer in self.sequential_layers:
+            layer_configs.append(generic_utils.serialize_keras_object(layer))
+        config = {"name": self.name, "layers": copy.deepcopy(layer_configs)}
+        return config
+
+    @classmethod
+    def from_config(cls, config, custom_objects=None):
+        if "name" in config:
+            name = config["name"]
+            layer_configs = config["layers"]
+        else:
+            name = None
+            layer_configs = config
+        layers = []
+        for layer_config in layer_configs:
+            layer = layer_module.deserialize(
+                layer_config, custom_objects=custom_objects
+            )
+            layers.append(layer)
+        model = cls(layers, name=name)
+        return model
 
 
 def keras_channels_index():
@@ -48,7 +93,7 @@ def keras_model_weights_as_initializers_for_variables(model):
     model : object
         A Keras model.
     """
-    sess = K.get_session()
+    sess = tf.compat.v1.keras.backend.get_session()
     n = len(model.variables)
     logger.debug("Initializing %d variables with their current weights", n)
     for variable in model.variables:
@@ -58,25 +103,6 @@ def keras_model_weights_as_initializers_for_variables(model):
         variable._initial_value = initial_value
 
 
-def apply_trainable_variables_on_keras_model(model, trainable_variables, mode):
-    """Changes the trainable status of layers in a keras model.
-    It can only turn off the trainable status of layer.
-
-    Parameters
-    ----------
-    model : object
-        A Keras model
-    trainable_variables : list or None
-        See bob.learn.tensorflow.estimators.Logits
-    mode : str
-        One of tf.estimator.ModeKeys
-    """
-    for layer in model.layers:
-        trainable = is_trainable(layer.name, trainable_variables, mode=mode)
-        if layer.trainable:
-            layer.trainable = trainable
-
-
 def _create_var_map(variables, normalizer=None):
     if normalizer is None:
 
@@ -107,11 +133,7 @@ def initialize_model_from_checkpoint(model, checkpoint, normalizer=None):
 
 
 def model_summary(model, do_print=False):
-    try:
-        from tensorflow.python.keras.utils.layer_utils import count_params
-    except ImportError:
-        from tensorflow_core.python.keras.utils.layer_utils import count_params
-    nest = tf.nest
+    from tensorflow.keras.backend import count_params
 
     if model.__class__.__name__ == "Sequential":
         sequential_like = True
diff --git a/bob/learn/tensorflow/utils/math.py b/bob/learn/tensorflow/utils/math.py
index b79b4496224958af5eeebb8151dff2e7c18f7ed4..304d3d71fef9d361b4c2080344568745263a6884 100644
--- a/bob/learn/tensorflow/utils/math.py
+++ b/bob/learn/tensorflow/utils/math.py
@@ -75,3 +75,65 @@ def upper_triangle(A):
     mask = tf.cast(mask_a - mask_b, dtype=tf.bool)
     upper_triangular_flat = tf.boolean_mask(tensor=A, mask=mask)
     return upper_triangular_flat
+
+
+def pdist(A, metric="sqeuclidean"):
+    if metric != "sqeuclidean":
+        raise NotImplementedError()
+    r = tf.reduce_sum(input_tensor=A * A, axis=1)
+    r = tf.reshape(r, [-1, 1])
+    D = r - 2 * tf.matmul(A, A, transpose_b=True) + tf.transpose(a=r)
+    return D
+
+
+def cdist(A, B, metric="sqeuclidean"):
+    if metric != "sqeuclidean":
+        raise NotImplementedError()
+    M1, M2 = tf.shape(input=A)[0], tf.shape(input=B)[0]
+    # code from https://stackoverflow.com/a/43839605/1286165
+    p1 = tf.matmul(
+        tf.expand_dims(tf.reduce_sum(input_tensor=tf.square(A), axis=1), 1),
+        tf.ones(shape=(1, M2)),
+    )
+    p2 = tf.transpose(
+        a=tf.matmul(
+            tf.reshape(tf.reduce_sum(input_tensor=tf.square(B), axis=1), shape=[-1, 1]),
+            tf.ones(shape=(M1, 1)),
+            transpose_b=True,
+        )
+    )
+
+    D = tf.add(p1, p2) - 2 * tf.matmul(A, B, transpose_b=True)
+    return D
+
+
+def random_choice_no_replacement(one_dim_input, num_indices_to_drop=3, sort=False):
+    """Similar to np.random.choice with no replacement.
+    Code from https://stackoverflow.com/a/54755281/1286165
+    """
+    input_length = tf.shape(input=one_dim_input)[0]
+
+    # create uniform distribution over the sequence
+    uniform_distribution = tf.random.uniform(
+        shape=[input_length],
+        minval=0,
+        maxval=None,
+        dtype=tf.float32,
+        seed=None,
+        name=None,
+    )
+
+    # grab the indices of the greatest num_words_to_drop values from the distibution
+    _, indices_to_keep = tf.nn.top_k(
+        uniform_distribution, input_length - num_indices_to_drop
+    )
+
+    # sort the indices
+    if sort:
+        sorted_indices_to_keep = tf.sort(indices_to_keep)
+    else:
+        sorted_indices_to_keep = indices_to_keep
+
+    # gather indices from the input array using the filtered actual array
+    result = tf.gather(one_dim_input, sorted_indices_to_keep)
+    return result
diff --git a/bob/learn/tensorflow/utils/reproducible.py b/bob/learn/tensorflow/utils/reproducible.py
deleted file mode 100644
index bcb988a4486ef8fe5df99e95e0022ba99ded306b..0000000000000000000000000000000000000000
--- a/bob/learn/tensorflow/utils/reproducible.py
+++ /dev/null
@@ -1,96 +0,0 @@
-"""Helps training reproducible networks.
-"""
-import os
-import random as rn
-
-import numpy as np
-import tensorflow as tf
-from tensorflow.core.protobuf import rewriter_config_pb2
-
-
-def set_seed(
-    seed=0,
-    python_hash_seed=0,
-    log_device_placement=False,
-    allow_soft_placement=False,
-    arithmetic_optimization=None,
-    allow_growth=None,
-    memory_optimization=None,
-):
-    """Sets the seeds in python, numpy, and tensorflow in order to help
-    training reproducible networks.
-
-    Parameters
-    ----------
-    seed : :obj:`int`, optional
-        The seed to set.
-    python_hash_seed : :obj:`int`, optional
-        https://docs.python.org/3.4/using/cmdline.html#envvar-PYTHONHASHSEED
-    log_device_placement : :obj:`bool`, optional
-        Optionally, log device placement of tensorflow variables.
-
-    Returns
-    -------
-    :any:`tf.ConfigProto`
-        Session config.
-    :any:`tf.estimator.RunConfig`
-        A run config to help training estimators.
-
-    Notes
-    -----
-        This functions return a list and its length might change. Please use
-        indices to select one of returned values. For example
-        ``sess_config, run_config = set_seed()[:2]``.
-    """
-    # reproducible networks
-    # The below is necessary in Python 3.2.3 onwards to
-    # have reproducible behavior for certain hash-based operations.
-    # See these references for further details:
-    # https://docs.python.org/3.4/using/cmdline.html#envvar-PYTHONHASHSEED
-    # https://github.com/fchollet/keras/issues/2280#issuecomment-306959926
-    os.environ["PYTHONHASHSEED"] = "{}".format(python_hash_seed)
-
-    # The below is necessary for starting Numpy generated random numbers
-    # in a well-defined initial state.
-    np.random.seed(seed)
-
-    # The below is necessary for starting core Python generated random numbers
-    # in a well-defined state.
-    rn.seed(seed)
-
-    # Force TensorFlow to use single thread.
-    # Multiple threads are a potential source of
-    # non-reproducible results.
-    # For further details, see:
-    # https://stackoverflow.com/questions/42022950/which-seeds-have-to-be-set-where-to-realize-100-reproducibility-of-training-res
-    session_config = tf.compat.v1.ConfigProto(
-        intra_op_parallelism_threads=1,
-        inter_op_parallelism_threads=1,
-        log_device_placement=log_device_placement,
-        allow_soft_placement=allow_soft_placement,
-    )
-
-    off = rewriter_config_pb2.RewriterConfig.OFF
-    if arithmetic_optimization == "off":
-        session_config.graph_options.rewrite_options.arithmetic_optimization = off
-
-    if memory_optimization == "off":
-        session_config.graph_options.rewrite_options.memory_optimization = off
-
-    if allow_growth is not None:
-        session_config.gpu_options.allow_growth = allow_growth
-        session_config.gpu_options.per_process_gpu_memory_fraction = 0.8
-
-    # The below tf.set_random_seed() will make random number generation
-    # in the TensorFlow backend have a well-defined initial state.
-    # For further details, see:
-    # https://www.tensorflow.org/api_docs/python/tf/set_random_seed
-    tf.compat.v1.set_random_seed(seed)
-    # sess = tf.Session(graph=tf.get_default_graph(), config=session_config)
-    # keras.backend.set_session(sess)
-
-    run_config = tf.estimator.RunConfig()
-    run_config = run_config.replace(session_config=session_config)
-    run_config = run_config.replace(tf_random_seed=seed)
-
-    return [session_config, run_config, None, None, None]
diff --git a/bob/learn/tensorflow/utils/session.py b/bob/learn/tensorflow/utils/session.py
deleted file mode 100644
index 9b72a6ad1ae17a08f709e110854470efdea578d0..0000000000000000000000000000000000000000
--- a/bob/learn/tensorflow/utils/session.py
+++ /dev/null
@@ -1,27 +0,0 @@
-#!/usr/bin/env python
-# vim: set fileencoding=utf-8 :
-# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
-# @date: Wed 11 May 2016 09:39:36 CEST
-
-import tensorflow as tf
-from tensorflow.python import debug as tf_debug
-
-from .singleton import Singleton
-
-
-@Singleton
-class Session(object):
-    """
-    Encapsulates a tf.session
-    """
-
-    def __init__(self, debug=False):
-        config = tf.compat.v1.ConfigProto(
-            log_device_placement=False,
-            allow_soft_placement=True,
-            gpu_options=tf.compat.v1.GPUOptions(per_process_gpu_memory_fraction=0.5),
-        )
-        config.gpu_options.allow_growth = True
-        self.session = tf.compat.v1.Session()
-        if debug:
-            self.session = tf_debug.LocalCLIDebugWrapperSession(self.session)
diff --git a/bob/learn/tensorflow/utils/singleton.py b/bob/learn/tensorflow/utils/singleton.py
deleted file mode 100644
index 4e7769d02134cee901c4d49ab29764bdf291fa0b..0000000000000000000000000000000000000000
--- a/bob/learn/tensorflow/utils/singleton.py
+++ /dev/null
@@ -1,51 +0,0 @@
-# A singleton class decorator, based on http://stackoverflow.com/a/7346105/3301902
-
-
-class Singleton(object):
-    """
-    A non-thread-safe helper class to ease implementing singletons.
-    This should be used as a **decorator** -- not a metaclass -- to the class that should be a singleton.
-
-    The decorated class can define one `__init__` function that takes an arbitrary list of parameters.
-
-    To get the singleton instance, use the :py:meth:`instance` method. Trying to use `__call__` will result in a `TypeError` being raised.
-
-    Limitations:
-
-    * The decorated class cannot be inherited from.
-    * The documentation of the decorated class is replaced with the documentation of this class.
-    """
-
-    def __init__(self, decorated):
-        self._decorated = decorated
-        # see: functools.WRAPPER_ASSIGNMENTS:
-        self.__doc__ = decorated.__doc__
-        self.__name__ = decorated.__name__
-        self.__module__ = decorated.__module__
-        self.__mro__ = decorated.__mro__
-        self.__bases__ = []
-
-        self._instance = None
-
-    def create(self, *args, **kwargs):
-        """Creates the singleton instance, by passing the given parameters to the class' constructor."""
-        # TODO: I still having problems in killing all the elements of the current session
-
-        if self._instance is not None:
-            self._instance.session.close()
-            del self._instance
-        self._instance = self._decorated(*args, **kwargs)
-
-    def instance(self, new=False):
-        """Returns the singleton instance.
-        The function :py:meth:`create` must have been called before."""
-        if self._instance is None or new:
-
-            self.create()
-        return self._instance
-
-    def __call__(self):
-        raise TypeError("Singletons must be accessed through the `instance()` method.")
-
-    def __instancecheck__(self, inst):
-        return isinstance(inst, self._decorated)
diff --git a/bob/learn/tensorflow/utils/train.py b/bob/learn/tensorflow/utils/train.py
deleted file mode 100644
index 999891047c46c7101f38eb3f7bc39c4add37ea7d..0000000000000000000000000000000000000000
--- a/bob/learn/tensorflow/utils/train.py
+++ /dev/null
@@ -1,46 +0,0 @@
-import tensorflow as tf
-
-
-def check_features(features):
-    if "data" not in features or "key" not in features:
-        raise ValueError(
-            "The input function needs to contain a dictionary with the keys `data` and `key` "
-        )
-    return True
-
-
-def get_trainable_variables(extra_checkpoint, mode=tf.estimator.ModeKeys.TRAIN):
-    """
-    Given the extra_checkpoint dictionary provided to the estimator,
-    extract the content of "trainable_variables".
-
-    If trainable_variables is not provided, all end points are trainable by
-    default.
-    If trainable_variables==[], all end points are NOT trainable.
-    If trainable_variables contains some end_points, ONLY these endpoints will
-    be trainable.
-
-    Attributes
-    ----------
-
-    extra_checkpoint: dict
-      The extra_checkpoint dictionary provided to the estimator
-
-    mode:
-        The estimator mode. TRAIN, EVAL, and PREDICT. If not TRAIN, None is
-        returned.
-
-    Returns
-    -------
-
-    Returns `None` if **trainable_variables** is not in extra_checkpoint;
-    otherwise returns the content of extra_checkpoint .
-    """
-    if mode != tf.estimator.ModeKeys.TRAIN:
-        return None
-
-    # If you don't set anything, everything is trainable
-    if extra_checkpoint is None or "trainable_variables" not in extra_checkpoint:
-        return None
-
-    return extra_checkpoint["trainable_variables"]
diff --git a/bob/learn/tensorflow/utils/util.py b/bob/learn/tensorflow/utils/util.py
deleted file mode 100644
index 1338f1fd6f658daccdf5f9284d784ece288e96e6..0000000000000000000000000000000000000000
--- a/bob/learn/tensorflow/utils/util.py
+++ /dev/null
@@ -1,493 +0,0 @@
-#!/usr/bin/env python
-# vim: set fileencoding=utf-8 :
-# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
-# @date: Wed 11 May 2016 09:39:36 CEST
-
-import inspect
-import logging
-
-import numpy
-import tensorflow as tf
-from tensorflow.python.client import device_lib
-from tensorflow.python.framework import function
-
-logger = logging.getLogger(__name__)
-
-
-@function.Defun(tf.float32, tf.float32)
-def norm_grad(x, dy):
-    return tf.expand_dims(dy, -1) * (
-        x / (tf.expand_dims(tf.norm(tensor=x, ord=2, axis=-1), -1) + 1.0e-19)
-    )
-
-
-@function.Defun(tf.float32, grad_func=norm_grad)
-def norm(x):
-    return tf.norm(tensor=x, ord=2, axis=-1)
-
-
-def compute_euclidean_distance(x, y):
-    """
-    Computes the euclidean distance between two tensorflow variables
-    """
-
-    with tf.compat.v1.name_scope("euclidean_distance"):
-        # d = tf.sqrt(tf.reduce_sum(tf.square(tf.subtract(x, y)), 1))
-        d = norm(tf.subtract(x, y))
-        return d
-
-
-def pdist_safe(A, metric="sqeuclidean"):
-    if metric != "sqeuclidean":
-        raise NotImplementedError()
-    r = tf.reduce_sum(input_tensor=A * A, axis=1)
-    r = tf.reshape(r, [-1, 1])
-    D = r - 2 * tf.matmul(A, A, transpose_b=True) + tf.transpose(a=r)
-    return D
-
-
-def cdist(A, B, metric="sqeuclidean"):
-    if metric != "sqeuclidean":
-        raise NotImplementedError()
-    M1, M2 = tf.shape(input=A)[0], tf.shape(input=B)[0]
-    # code from https://stackoverflow.com/a/43839605/1286165
-    p1 = tf.matmul(
-        tf.expand_dims(tf.reduce_sum(input_tensor=tf.square(A), axis=1), 1),
-        tf.ones(shape=(1, M2)),
-    )
-    p2 = tf.transpose(
-        a=tf.matmul(
-            tf.reshape(tf.reduce_sum(input_tensor=tf.square(B), axis=1), shape=[-1, 1]),
-            tf.ones(shape=(M1, 1)),
-            transpose_b=True,
-        )
-    )
-
-    D = tf.add(p1, p2) - 2 * tf.matmul(A, B, transpose_b=True)
-    return D
-
-
-def load_mnist(perc_train=0.9):
-    numpy.random.seed(0)
-    import bob.db.mnist
-
-    db = bob.db.mnist.Database()
-    raw_data = db.data()
-
-    # data  = raw_data[0].astype(numpy.float64)
-    data = raw_data[0]
-    labels = raw_data[1]
-
-    # Shuffling
-    total_samples = data.shape[0]
-    indexes = numpy.array(range(total_samples))
-    numpy.random.shuffle(indexes)
-
-    # Spliting train and validation
-    n_train = int(perc_train * indexes.shape[0])
-    n_validation = total_samples - n_train
-
-    train_data = data[0:n_train, :].astype("float32") * 0.00390625
-    train_labels = labels[0:n_train]
-
-    validation_data = (
-        data[n_train : n_train + n_validation, :].astype("float32") * 0.00390625
-    )
-    validation_labels = labels[n_train : n_train + n_validation]
-
-    return train_data, train_labels, validation_data, validation_labels
-
-
-def create_mnist_tfrecord(tfrecords_filename, data, labels, n_samples=6000):
-    def _bytes_feature(value):
-        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
-
-    def _int64_feature(value):
-        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
-
-    writer = tf.io.TFRecordWriter(tfrecords_filename)
-
-    for i in range(n_samples):
-        img = data[i]
-        img_raw = img.tostring()
-        feature = {
-            "data": _bytes_feature(img_raw),
-            "label": _int64_feature(labels[i]),
-            "key": _bytes_feature(b"-"),
-        }
-
-        example = tf.train.Example(features=tf.train.Features(feature=feature))
-        writer.write(example.SerializeToString())
-    writer.close()
-
-
-def compute_eer(
-    data_train, labels_train, data_validation, labels_validation, n_classes
-):
-    from scipy.spatial.distance import cosine
-
-    import bob.measure
-
-    # Creating client models
-    models = []
-    for i in range(n_classes):
-        indexes = labels_train == i
-        models.append(numpy.mean(data_train[indexes, :], axis=0))
-
-    # Probing
-    positive_scores = numpy.zeros(shape=0)
-    negative_scores = numpy.zeros(shape=0)
-
-    for i in range(n_classes):
-        # Positive scoring
-        indexes = labels_validation == i
-        positive_data = data_validation[indexes, :]
-        p = [cosine(models[i], positive_data[j]) for j in range(positive_data.shape[0])]
-        positive_scores = numpy.hstack((positive_scores, p))
-
-        # negative scoring
-        indexes = labels_validation != i
-        negative_data = data_validation[indexes, :]
-        n = [cosine(models[i], negative_data[j]) for j in range(negative_data.shape[0])]
-        negative_scores = numpy.hstack((negative_scores, n))
-
-    # Computing performance based on EER
-    negative_scores = (-1) * negative_scores
-    positive_scores = (-1) * positive_scores
-
-    threshold = bob.measure.eer_threshold(negative_scores, positive_scores)
-    far, frr = bob.measure.farfrr(negative_scores, positive_scores, threshold)
-    eer = (far + frr) / 2.0
-
-    return eer
-
-
-def compute_accuracy(
-    data_train, labels_train, data_validation, labels_validation, n_classes
-):
-    from scipy.spatial.distance import cosine
-
-    # Creating client models
-    models = []
-    for i in range(n_classes):
-        indexes = labels_train == i
-        models.append(numpy.mean(data_train[indexes, :], axis=0))
-
-    # Probing
-    tp = 0
-    for i in range(data_validation.shape[0]):
-
-        d = data_validation[i, :]
-        l = labels_validation[i]
-
-        scores = [cosine(m, d) for m in models]
-        predict = numpy.argmax(scores)
-
-        if predict == l:
-            tp += 1
-
-    return (float(tp) / data_validation.shape[0]) * 100
-
-
-def debug_embbeding(image, architecture, embbeding_dim=2, feature_layer="fc3"):
-    """"""
-    import tensorflow as tf
-
-    from bob.learn.tensorflow.utils.session import Session
-
-    session = Session.instance(new=False).session
-    inference_graph = architecture.compute_graph(
-        architecture.inference_placeholder, feature_layer=feature_layer, training=False
-    )
-
-    embeddings = numpy.zeros(shape=(image.shape[0], embbeding_dim))
-    for i in range(image.shape[0]):
-        feed_dict = {architecture.inference_placeholder: image[i : i + 1, :, :, :]}
-        embedding = session.run(
-            [tf.nn.l2_normalize(inference_graph, 1, 1e-10)], feed_dict=feed_dict
-        )[0]
-        embedding = numpy.reshape(embedding, numpy.prod(embedding.shape[1:]))
-        embeddings[i] = embedding
-
-    return embeddings
-
-
-def pdist(A):
-    """
-    Compute a pairwise euclidean distance in the same fashion
-    as in scipy.spation.distance.pdist
-    """
-    with tf.compat.v1.name_scope("Pairwisedistance"):
-        ones_1 = tf.reshape(tf.cast(tf.ones_like(A), tf.float32)[:, 0], [1, -1])
-        p1 = tf.matmul(
-            tf.expand_dims(tf.reduce_sum(input_tensor=tf.square(A), axis=1), 1), ones_1
-        )
-
-        ones_2 = tf.reshape(tf.cast(tf.ones_like(A), tf.float32)[:, 0], [-1, 1])
-        p2 = tf.transpose(
-            a=tf.matmul(
-                tf.reshape(
-                    tf.reduce_sum(input_tensor=tf.square(A), axis=1), shape=[-1, 1]
-                ),
-                ones_2,
-                transpose_b=True,
-            )
-        )
-
-        return tf.sqrt(tf.add(p1, p2) - 2 * tf.matmul(A, A, transpose_b=True))
-
-
-def predict_using_tensors(embedding, labels, num=None):
-    """
-    Compute the predictions through exhaustive comparisons between
-    embeddings using tensors
-    """
-
-    # Fitting the main diagonal with infs (removing comparisons with the same
-    # sample)
-    inf = tf.cast(tf.ones_like(labels), tf.float32) * numpy.inf
-
-    distances = pdist(embedding)
-    distances = tf.linalg.set_diag(distances, inf)
-    indexes = tf.argmin(input=distances, axis=1)
-    return [labels[i] for i in tf.unstack(indexes, num=num)]
-
-
-def compute_embedding_accuracy_tensors(embedding, labels, num=None):
-    """
-    Compute the accuracy in a closed-set
-
-    **Parameters**
-
-    embeddings: `tf.Tensor`
-      Set of embeddings
-
-    labels: `tf.Tensor`
-      Correspondent labels
-    """
-
-    # Fitting the main diagonal with infs (removing comparisons with the same
-    # sample)
-    predictions = predict_using_tensors(embedding, labels, num=num)
-    matching = [
-        tf.equal(p, l)
-        for p, l in zip(tf.unstack(predictions, num=num), tf.unstack(labels, num=num))
-    ]
-
-    return tf.reduce_sum(input_tensor=tf.cast(matching, tf.uint8)) / len(predictions)
-
-
-def compute_embedding_accuracy(embedding, labels):
-    """
-    Compute the accuracy in a closed-set
-
-    **Parameters**
-
-    embeddings: :any:`numpy.array`
-      Set of embeddings
-
-    labels: :any:`numpy.array`
-      Correspondent labels
-    """
-
-    from scipy.spatial.distance import pdist
-    from scipy.spatial.distance import squareform
-
-    distances = squareform(pdist(embedding))
-
-    n_samples = embedding.shape[0]
-
-    # Fitting the main diagonal with infs (removing comparisons with the same
-    # sample)
-    numpy.fill_diagonal(distances, numpy.inf)
-
-    indexes = distances.argmin(axis=1)
-
-    # Computing the argmin excluding comparisons with the same samples
-    # Basically, we are excluding the main diagonal
-
-    # valid_indexes = distances[distances>0].reshape(n_samples, n_samples-1).argmin(axis=1)
-
-    # Getting the original positions of the indexes in the 1-axis
-    # corrected_indexes = [ i if i<j else i+1 for i, j in zip(valid_indexes, range(n_samples))]
-
-    matching = [labels[i] == labels[j] for i, j in zip(range(n_samples), indexes)]
-    accuracy = sum(matching) / float(n_samples)
-
-    return accuracy
-
-
-def get_available_gpus():
-    """Returns the number of GPU devices that are available.
-
-    Returns
-    -------
-    [str]
-        The names of available GPU devices.
-    """
-    local_device_protos = device_lib.list_local_devices()
-    return [x.name for x in local_device_protos if x.device_type == "GPU"]
-
-
-def to_channels_last(image):
-    """Converts the image to channel_last format. This is the same format as in
-    matplotlib, skimage, and etc.
-
-    Parameters
-    ----------
-    image : `tf.Tensor`
-        At least a 3 dimensional image. If the dimension is more than 3, the
-        last 3 dimensions are assumed to be [C, H, W].
-
-    Returns
-    -------
-    image : `tf.Tensor`
-        The image in [..., H, W, C] format.
-
-    Raises
-    ------
-    ValueError
-        If dim of image is less than 3.
-    """
-    ndim = len(image.shape)
-    if ndim < 3:
-        raise ValueError(
-            "The image needs to be at least 3 dimensional but it " "was {}".format(ndim)
-        )
-    axis_order = [1, 2, 0]
-    shift = ndim - 3
-    axis_order = list(range(ndim - 3)) + [n + shift for n in axis_order]
-    return tf.transpose(a=image, perm=axis_order)
-
-
-def to_channels_first(image):
-    """Converts the image to channel_first format. This is the same format as
-    in bob.io.image and bob.io.video.
-
-    Parameters
-    ----------
-    image : `tf.Tensor`
-        At least a 3 dimensional image. If the dimension is more than 3, the
-        last 3 dimensions are assumed to be [H, W, C].
-
-    Returns
-    -------
-    image : `tf.Tensor`
-        The image in [..., C, H, W] format.
-
-    Raises
-    ------
-    ValueError
-        If dim of image is less than 3.
-    """
-    ndim = len(image.shape)
-    if ndim < 3:
-        raise ValueError(
-            "The image needs to be at least 3 dimensional but it " "was {}".format(ndim)
-        )
-    axis_order = [2, 0, 1]
-    shift = ndim - 3
-    axis_order = list(range(ndim - 3)) + [n + shift for n in axis_order]
-    return tf.transpose(a=image, perm=axis_order)
-
-
-to_skimage = to_matplotlib = to_channels_last
-to_bob = to_channels_first
-
-
-def bytes2human(n, format="%(value).1f %(symbol)s", symbols="customary"):
-    """Convert n bytes into a human readable string based on format.
-    From: https://code.activestate.com/recipes/578019-bytes-to-human-human-to-
-    bytes-converter/
-    Author: Giampaolo Rodola' <g.rodola [AT] gmail [DOT] com>
-    License: MIT
-    symbols can be either "customary", "customary_ext", "iec" or "iec_ext",
-    see: http://goo.gl/kTQMs
-    """
-    SYMBOLS = {
-        "customary": ("B", "K", "M", "G", "T", "P", "E", "Z", "Y"),
-        "customary_ext": (
-            "byte",
-            "kilo",
-            "mega",
-            "giga",
-            "tera",
-            "peta",
-            "exa",
-            "zetta",
-            "iotta",
-        ),
-        "iec": ("Bi", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi", "Yi"),
-        "iec_ext": (
-            "byte",
-            "kibi",
-            "mebi",
-            "gibi",
-            "tebi",
-            "pebi",
-            "exbi",
-            "zebi",
-            "yobi",
-        ),
-    }
-    n = int(n)
-    if n < 0:
-        raise ValueError("n < 0")
-    symbols = SYMBOLS[symbols]
-    prefix = {}
-    for i, s in enumerate(symbols[1:]):
-        prefix[s] = 1 << (i + 1) * 10
-    for symbol in reversed(symbols[1:]):
-        if n >= prefix[symbol]:
-            value = float(n) / prefix[symbol]
-            return format % locals()
-    return format % dict(symbol=symbols[0], value=n)
-
-
-def random_choice_no_replacement(one_dim_input, num_indices_to_drop=3, sort=False):
-    """Similar to np.random.choice with no replacement.
-    Code from https://stackoverflow.com/a/54755281/1286165
-    """
-    input_length = tf.shape(input=one_dim_input)[0]
-
-    # create uniform distribution over the sequence
-    uniform_distribution = tf.random.uniform(
-        shape=[input_length],
-        minval=0,
-        maxval=None,
-        dtype=tf.float32,
-        seed=None,
-        name=None,
-    )
-
-    # grab the indices of the greatest num_words_to_drop values from the distibution
-    _, indices_to_keep = tf.nn.top_k(
-        uniform_distribution, input_length - num_indices_to_drop
-    )
-
-    # sort the indices
-    if sort:
-        sorted_indices_to_keep = tf.sort(indices_to_keep)
-    else:
-        sorted_indices_to_keep = indices_to_keep
-
-    # gather indices from the input array using the filtered actual array
-    result = tf.gather(one_dim_input, sorted_indices_to_keep)
-    return result
-
-
-def is_argument_available(argument, method):
-    """
-    Check if an argument (or keyword argument) is available in a method
-
-    Attributes
-    ----------
-    argument: str
-        The name of the argument (or keyword argument).
-
-    method:
-        Pointer to the method
-
-    """
-
-    return argument in inspect.signature(method).parameters.keys()