diff --git a/bob/learn/tensorflow/dataset/__init__.py b/bob/learn/tensorflow/dataset/__init__.py
index 93eb960fa720b1cfb93b96f9b0bfefc8cf6fc6f2..ab34eb789855f2fcdbf8cda83d356d80b1165530 100644
--- a/bob/learn/tensorflow/dataset/__init__.py
+++ b/bob/learn/tensorflow/dataset/__init__.py
@@ -4,9 +4,9 @@ import os
 import bob.io.base
 
 DEFAULT_FEATURE = {
-    "data": tf.FixedLenFeature([], tf.string),
-    "label": tf.FixedLenFeature([], tf.int64),
-    "key": tf.FixedLenFeature([], tf.string),
+    "data": tf.io.FixedLenFeature([], tf.string),
+    "label": tf.io.FixedLenFeature([], tf.int64),
+    "key": tf.io.FixedLenFeature([], tf.string),
 }
 
 
@@ -32,110 +32,110 @@ def from_filename_to_tensor(filename, extension=None):
     """
 
     if extension == "hdf5":
-        return tf.py_func(from_hdf5file_to_tensor, [filename], [tf.float32])
+        return tf.compat.v1.py_func(from_hdf5file_to_tensor, [filename], [tf.float32])
     else:
-        return tf.cast(tf.image.decode_image(tf.read_file(filename)), tf.float32)
-
-
-def append_image_augmentation(
-    image,
-    gray_scale=False,
-    output_shape=None,
-    random_flip=False,
-    random_brightness=False,
-    random_contrast=False,
-    random_saturation=False,
-    random_rotate=False,
-    per_image_normalization=True,
-    random_gamma=False,
-    random_crop=False,
-):
-    """
-    Append to the current tensor some random image augmentation operation
-
-    **Parameters**
-       gray_scale:
-          Convert to gray scale?
-
-       output_shape:
-          If set, will randomly crop the image given the output shape
-
-       random_flip:
-          Randomly flip an image horizontally  (https://www.tensorflow.org/api_docs/python/tf/image/random_flip_left_right)
-
-       random_brightness:
-           Adjust the brightness of an RGB image by a random factor (https://www.tensorflow.org/api_docs/python/tf/image/random_brightness)
-
-       random_contrast:
-           Adjust the contrast of an RGB image by a random factor (https://www.tensorflow.org/api_docs/python/tf/image/random_contrast)
-
-       random_saturation:
-           Adjust the saturation of an RGB image by a random factor (https://www.tensorflow.org/api_docs/python/tf/image/random_saturation)
-
-       random_rotate:
-           Randomly rotate face images between -5 and 5 degrees
-
-       per_image_normalization:
-           Linearly scales image to have zero mean and unit norm.
-
-    """
-
-    # Changing the range from 0-255 to 0-1
-    image = tf.cast(image, tf.float32) / 255
-    # FORCING A SEED FOR THE RANDOM OPERATIONS
-    tf.set_random_seed(0)
-
-    if output_shape is not None:
-        assert len(output_shape) == 2
-        if random_crop:
-            image = tf.random_crop(image, size=list(output_shape) + [3])
-        else:
-            image = tf.image.resize_image_with_crop_or_pad(
-                image, output_shape[0], output_shape[1]
-            )
-
-    if random_flip:
-        image = tf.image.random_flip_left_right(image)
-
-    if random_brightness:
-        image = tf.image.random_brightness(image, max_delta=0.15)
-        image = tf.clip_by_value(image, 0, 1)
-
-    if random_contrast:
-        image = tf.image.random_contrast(image, lower=0.85, upper=1.15)
-        image = tf.clip_by_value(image, 0, 1)
-
-    if random_saturation:
-        image = tf.image.random_saturation(image, lower=0.85, upper=1.15)
-        image = tf.clip_by_value(image, 0, 1)
-
-    if random_gamma:
-        image = tf.image.adjust_gamma(
-            image, gamma=tf.random.uniform(shape=[], minval=0.85, maxval=1.15)
-        )
-        image = tf.clip_by_value(image, 0, 1)
-
-    if random_rotate:
-        # from https://stackoverflow.com/a/53855704/1286165
-        degree = 0.08726646259971647  # math.pi * 5 /180
-        random_angles = tf.random.uniform(shape=(1,), minval=-degree, maxval=degree)
-        image = tf.contrib.image.transform(
-            image,
-            tf.contrib.image.angles_to_projective_transforms(
-                random_angles,
-                tf.cast(tf.shape(image)[-3], tf.float32),
-                tf.cast(tf.shape(image)[-2], tf.float32),
-            ),
-        )
-
-    if gray_scale:
-        image = tf.image.rgb_to_grayscale(image, name="rgb_to_gray")
-
-    # normalizing data
-    if per_image_normalization:
-        image = tf.image.per_image_standardization(image)
-
-    return image
+        return tf.cast(tf.image.decode_image(tf.io.read_file(filename)), tf.float32)
+
+
+# def append_image_augmentation(
+#     image,
+#     gray_scale=False,
+#     output_shape=None,
+#     random_flip=False,
+#     random_brightness=False,
+#     random_contrast=False,
+#     random_saturation=False,
+#     random_rotate=False,
+#     per_image_normalization=True,
+#     random_gamma=False,
+#     random_crop=False,
+# ):
+#     """
+#     Append to the current tensor some random image augmentation operation
+
+#     **Parameters**
+#        gray_scale:
+#           Convert to gray scale?
+
+#        output_shape:
+#           If set, will randomly crop the image given the output shape
+
+#        random_flip:
+#           Randomly flip an image horizontally  (https://www.tensorflow.org/api_docs/python/tf/image/random_flip_left_right)
+
+#        random_brightness:
+#            Adjust the brightness of an RGB image by a random factor (https://www.tensorflow.org/api_docs/python/tf/image/random_brightness)
+
+#        random_contrast:
+#            Adjust the contrast of an RGB image by a random factor (https://www.tensorflow.org/api_docs/python/tf/image/random_contrast)
+
+#        random_saturation:
+#            Adjust the saturation of an RGB image by a random factor (https://www.tensorflow.org/api_docs/python/tf/image/random_saturation)
+
+#        random_rotate:
+#            Randomly rotate face images between -5 and 5 degrees
+
+#        per_image_normalization:
+#            Linearly scales image to have zero mean and unit norm.
+
+#     """
+
+#     # Changing the range from 0-255 to 0-1
+#     image = tf.cast(image, tf.float32) / 255
+#     # FORCING A SEED FOR THE RANDOM OPERATIONS
+#     tf.compat.v1.set_random_seed(0)
+
+#     if output_shape is not None:
+#         assert len(output_shape) == 2
+#         if random_crop:
+#             image = tf.image.random_crop(image, size=list(output_shape) + [3])
+#         else:
+#             image = tf.image.resize_with_crop_or_pad(
+#                 image, output_shape[0], output_shape[1]
+#             )
+
+#     if random_flip:
+#         image = tf.image.random_flip_left_right(image)
+
+#     if random_brightness:
+#         image = tf.image.random_brightness(image, max_delta=0.15)
+#         image = tf.clip_by_value(image, 0, 1)
+
+#     if random_contrast:
+#         image = tf.image.random_contrast(image, lower=0.85, upper=1.15)
+#         image = tf.clip_by_value(image, 0, 1)
+
+#     if random_saturation:
+#         image = tf.image.random_saturation(image, lower=0.85, upper=1.15)
+#         image = tf.clip_by_value(image, 0, 1)
+
+#     if random_gamma:
+#         image = tf.image.adjust_gamma(
+#             image, gamma=tf.random.uniform(shape=[], minval=0.85, maxval=1.15)
+#         )
+#         image = tf.clip_by_value(image, 0, 1)
+
+#     if random_rotate:
+#         # from https://stackoverflow.com/a/53855704/1286165
+#         degree = 0.08726646259971647  # math.pi * 5 /180
+#         random_angles = tf.random.uniform(shape=(1,), minval=-degree, maxval=degree)
+#         image = tf.contrib.image.transform(
+#             image,
+#             tf.contrib.image.angles_to_projective_transforms(
+#                 random_angles,
+#                 tf.cast(tf.shape(input=image)[-3], tf.float32),
+#                 tf.cast(tf.shape(input=image)[-2], tf.float32),
+#             ),
+#         )
+
+#     if gray_scale:
+#         image = tf.image.rgb_to_grayscale(image, name="rgb_to_gray")
+
+#     # normalizing data
+#     if per_image_normalization:
+#         image = tf.image.per_image_standardization(image)
+
+#     return image
 
 
 def arrange_indexes_by_label(input_labels, possible_labels):
@@ -343,7 +343,7 @@ def blocks_tensorflow(images, block_size):
     output_size = list(block_size)
     output_size[0] = -1
     output_size[-1] = images.shape[-1]
-    blocks = tf.extract_image_patches(
+    blocks = tf.image.extract_patches(
         images, block_size, block_size, [1, 1, 1, 1], "VALID"
     )
     n_blocks = int(numpy.prod(blocks.shape[1:3]))
@@ -366,11 +366,11 @@ def tf_repeat(tensor, repeats):
     A Tensor. Has the same type as input. Has the shape of tensor.shape *
     repeats
     """
-    with tf.variable_scope("repeat"):
+    with tf.compat.v1.variable_scope("repeat"):
         expanded_tensor = tf.expand_dims(tensor, -1)
         multiples = [1] + repeats
         tiled_tensor = tf.tile(expanded_tensor, multiples=multiples)
-        repeated_tesnor = tf.reshape(tiled_tensor, tf.shape(tensor) * repeats)
+        repeated_tesnor = tf.reshape(tiled_tensor, tf.shape(input=tensor) * repeats)
     return repeated_tesnor
 
 
diff --git a/bob/learn/tensorflow/dataset/image.py b/bob/learn/tensorflow/dataset/image.py
index b731e55260cac42d6e56edec37f4f6cfb65627f7..bc8c315e2ca8729100406e5c7f33af26d838ecb0 100644
--- a/bob/learn/tensorflow/dataset/image.py
+++ b/bob/learn/tensorflow/dataset/image.py
@@ -99,7 +99,7 @@ def shuffle_data_and_labels_image_augmentation(filenames,
 
     dataset = dataset.shuffle(buffer_size).batch(batch_size).repeat(epochs)
 
-    data, labels = dataset.make_one_shot_iterator().get_next()
+    data, labels = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()
     return data, labels
 
 
@@ -215,7 +215,7 @@ def load_pngs(img_path, img_shape):
     object
         The loaded png file
     """
-    img_raw = tf.read_file(img_path)
+    img_raw = tf.io.read_file(img_path)
     img_tensor = tf.image.decode_png(img_raw, channels=img_shape[-1])
     img_final = tf.reshape(img_tensor, img_shape)
     return img_final
diff --git a/bob/learn/tensorflow/dataset/siamese_image.py b/bob/learn/tensorflow/dataset/siamese_image.py
index 51a56a64b5290b3d853a3be849cefa261e4e4266..cacad02c971ca3997f823a15957aeb2634bf01e7 100644
--- a/bob/learn/tensorflow/dataset/siamese_image.py
+++ b/bob/learn/tensorflow/dataset/siamese_image.py
@@ -103,7 +103,7 @@ def shuffle_data_and_labels_image_augmentation(filenames,
         extension=extension)
 
     dataset = dataset.shuffle(buffer_size).batch(batch_size).repeat(epochs)
-    data, labels = dataset.make_one_shot_iterator().get_next()
+    data, labels = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()
     return data, labels
 
 
diff --git a/bob/learn/tensorflow/dataset/tfrecords.py b/bob/learn/tensorflow/dataset/tfrecords.py
index 45201b88c0b732de17d87c8510f158cf653c4af9..543b7c7cfac57d8ac5fe30d7a751968971a3619b 100644
--- a/bob/learn/tensorflow/dataset/tfrecords.py
+++ b/bob/learn/tensorflow/dataset/tfrecords.py
@@ -77,9 +77,9 @@ def dataset_to_tfrecord(dataset, output):
         return example_proto.SerializeToString()
 
     def tf_serialize_example(*args):
-        args = tf.contrib.framework.nest.flatten(args)
-        args = [tf.serialize_tensor(f) for f in args]
-        tf_string = tf.py_func(serialize_example_pyfunction, args, tf.string)
+        args = tf.nest.flatten(args)
+        args = [tf.io.serialize_tensor(f) for f in args]
+        tf_string = tf.compat.v1.py_func(serialize_example_pyfunction, args, tf.string)
         return tf.reshape(tf_string, ())  # The result is a scalar
 
     dataset = dataset.map(tf_serialize_example)
@@ -122,20 +122,20 @@ def dataset_from_tfrecord(tfrecord, num_parallel_reads=None):
         meta = json.load(f)
     for k, v in meta.items():
         meta[k] = eval(v)
-    output_types = tf.contrib.framework.nest.flatten(meta["output_types"])
-    output_shapes = tf.contrib.framework.nest.flatten(meta["output_shapes"])
+    output_types = tf.nest.flatten(meta["output_types"])
+    output_shapes = tf.nest.flatten(meta["output_shapes"])
     feature_description = {}
     for i in range(len(output_types)):
         key = f"feature{i}"
-        feature_description[key] = tf.FixedLenFeature([], tf.string)
+        feature_description[key] = tf.io.FixedLenFeature([], tf.string)
 
     def _parse_function(example_proto):
         # Parse the input tf.Example proto using the dictionary above.
-        args = tf.parse_single_example(example_proto, feature_description)
-        args = tf.contrib.framework.nest.flatten(args)
-        args = [tf.parse_tensor(v, t) for v, t in zip(args, output_types)]
+        args = tf.io.parse_single_example(serialized=example_proto, features=feature_description)
+        args = tf.nest.flatten(args)
+        args = [tf.io.parse_tensor(v, t) for v, t in zip(args, output_types)]
         args = [tf.reshape(v, s) for v, s in zip(args, output_shapes)]
-        return tf.contrib.framework.nest.pack_sequence_as(meta["output_types"], args)
+        return tf.nest.pack_sequence_as(meta["output_types"], args)
 
     return raw_dataset.map(_parse_function)
 
@@ -161,9 +161,9 @@ def example_parser(serialized_example, feature, data_shape, data_type):
 
   """
     # Decode the record read by the reader
-    features = tf.parse_single_example(serialized_example, features=feature)
+    features = tf.io.parse_single_example(serialized=serialized_example, features=feature)
     # Convert the image data from string back to the numbers
-    image = tf.decode_raw(features["data"], data_type)
+    image = tf.io.decode_raw(features["data"], data_type)
     # Cast label data into int64
     label = tf.cast(features["label"], tf.int64)
     # Reshape image data into the original shape
@@ -193,9 +193,9 @@ def image_augmentation_parser(
 
   """
     # Decode the record read by the reader
-    features = tf.parse_single_example(serialized_example, features=feature)
+    features = tf.io.parse_single_example(serialized=serialized_example, features=feature)
     # Convert the image data from string back to the numbers
-    image = tf.decode_raw(features["data"], data_type)
+    image = tf.io.decode_raw(features["data"], data_type)
 
     # Reshape image data into the original shape
     image = tf.reshape(image, data_shape)
@@ -231,7 +231,7 @@ def read_and_decode(filename_queue, data_shape, data_type=tf.float32, feature=No
     if feature is None:
         feature = DEFAULT_FEATURE
     # Define a reader and read the next record
-    reader = tf.TFRecordReader()
+    reader = tf.compat.v1.TFRecordReader()
     _, serialized_example = reader.read(filename_queue)
     return example_parser(serialized_example, feature, data_shape, data_type)
 
@@ -459,7 +459,7 @@ def shuffle_data_and_labels(
     dataset = create_dataset_from_records(tfrecord_filenames, data_shape, data_type)
     dataset = dataset.shuffle(buffer_size).batch(batch_size).repeat(epochs)
 
-    data, labels, key = dataset.make_one_shot_iterator().get_next()
+    data, labels, key = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()
     features = dict()
     features["data"] = data
     features["key"] = key
@@ -495,7 +495,7 @@ def batch_data_and_labels(
     dataset = create_dataset_from_records(tfrecord_filenames, data_shape, data_type)
     dataset = dataset.batch(batch_size).repeat(epochs)
 
-    data, labels, key = dataset.make_one_shot_iterator().get_next()
+    data, labels, key = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()
     features = dict()
     features["data"] = data
     features["key"] = key
@@ -565,7 +565,7 @@ def batch_data_and_labels_image_augmentation(
     dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
     dataset = dataset.repeat(epochs)
 
-    data, labels, key = dataset.make_one_shot_iterator().get_next()
+    data, labels, key = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()
     features = dict()
     features["data"] = data
     features["key"] = key
@@ -602,26 +602,26 @@ def describe_tf_record(tf_record_path, shape, batch_size=1):
   """
 
     tf_records = [os.path.join(tf_record_path, f) for f in os.listdir(tf_record_path)]
-    filename_queue = tf.train.string_input_producer(
+    filename_queue = tf.compat.v1.train.string_input_producer(
         tf_records, num_epochs=1, name="input"
     )
 
     feature = {
-        "data": tf.FixedLenFeature([], tf.string),
-        "label": tf.FixedLenFeature([], tf.int64),
-        "key": tf.FixedLenFeature([], tf.string),
+        "data": tf.io.FixedLenFeature([], tf.string),
+        "label": tf.io.FixedLenFeature([], tf.int64),
+        "key": tf.io.FixedLenFeature([], tf.string),
     }
 
     # Define a reader and read the next record
-    reader = tf.TFRecordReader()
+    reader = tf.compat.v1.TFRecordReader()
 
     _, serialized_example = reader.read(filename_queue)
 
     # Decode the record read by the reader
-    features = tf.parse_single_example(serialized_example, features=feature)
+    features = tf.io.parse_single_example(serialized=serialized_example, features=feature)
 
     # Convert the image data from string back to the numbers
-    image = tf.decode_raw(features["data"], tf.uint8)
+    image = tf.io.decode_raw(features["data"], tf.uint8)
 
     # Cast label data into int32
     label = tf.cast(features["label"], tf.int64)
@@ -631,7 +631,7 @@ def describe_tf_record(tf_record_path, shape, batch_size=1):
     image = tf.reshape(image, shape)
 
     # Getting the batches in order
-    data_ph, label_ph, img_name_ph = tf.train.batch(
+    data_ph, label_ph, img_name_ph = tf.compat.v1.train.batch(
         [image, label, img_name],
         batch_size=batch_size,
         capacity=1000,
@@ -640,13 +640,13 @@ def describe_tf_record(tf_record_path, shape, batch_size=1):
     )
 
     # Start the reading
-    session = tf.Session()
-    tf.local_variables_initializer().run(session=session)
-    tf.global_variables_initializer().run(session=session)
+    session = tf.compat.v1.Session()
+    tf.compat.v1.local_variables_initializer().run(session=session)
+    tf.compat.v1.global_variables_initializer().run(session=session)
 
     # Preparing the batches
     thread_pool = tf.train.Coordinator()
-    threads = tf.train.start_queue_runners(coord=thread_pool, sess=session)
+    threads = tf.compat.v1.train.start_queue_runners(coord=thread_pool, sess=session)
 
     logger.info("Counting in %s", tf_record_path)
     labels = set()
diff --git a/bob/learn/tensorflow/dataset/triplet_image.py b/bob/learn/tensorflow/dataset/triplet_image.py
index 944641100a8a2da2791b363e957beea0dac65a51..159d0a7c07f86997b495566eb225f792cddea6e4 100644
--- a/bob/learn/tensorflow/dataset/triplet_image.py
+++ b/bob/learn/tensorflow/dataset/triplet_image.py
@@ -104,7 +104,7 @@ def shuffle_data_and_labels_image_augmentation(filenames,
     dataset = dataset.shuffle(buffer_size).batch(batch_size).repeat(epochs)
     #dataset = dataset.batch(buffer_size).batch(batch_size).repeat(epochs)
 
-    data = dataset.make_one_shot_iterator().get_next()
+    data = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()
     return data
 
 
diff --git a/bob/learn/tensorflow/gan/losses.py b/bob/learn/tensorflow/gan/losses.py
index ec378245953d60c96d4634eb797d6d1c30e2dec0..46be1eaa8c76b3a253968d609a1948fc08a013cd 100644
--- a/bob/learn/tensorflow/gan/losses.py
+++ b/bob/learn/tensorflow/gan/losses.py
@@ -8,8 +8,8 @@ def relativistic_discriminator_loss(
     real_weights=1.0,
     generated_weights=1.0,
     scope=None,
-    loss_collection=tf.GraphKeys.LOSSES,
-    reduction=tf.losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
+    loss_collection=tf.compat.v1.GraphKeys.LOSSES,
+    reduction=tf.compat.v1.losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
     add_summaries=False,
 ):
     """Relativistic (average) loss
@@ -34,7 +34,7 @@ def relativistic_discriminator_loss(
   Returns:
     A loss Tensor. The shape depends on `reduction`.
   """
-    with tf.name_scope(
+    with tf.compat.v1.name_scope(
         scope,
         "discriminator_relativistic_loss",
         (
@@ -47,13 +47,13 @@ def relativistic_discriminator_loss(
     ) as scope:
 
         real_logit = discriminator_real_outputs - tf.reduce_mean(
-            discriminator_gen_outputs
+            input_tensor=discriminator_gen_outputs
         )
         fake_logit = discriminator_gen_outputs - tf.reduce_mean(
-            discriminator_real_outputs
+            input_tensor=discriminator_real_outputs
         )
 
-        loss_on_real = tf.losses.sigmoid_cross_entropy(
+        loss_on_real = tf.compat.v1.losses.sigmoid_cross_entropy(
             tf.ones_like(real_logit),
             real_logit,
             real_weights,
@@ -62,7 +62,7 @@ def relativistic_discriminator_loss(
             loss_collection=None,
             reduction=reduction,
         )
-        loss_on_generated = tf.losses.sigmoid_cross_entropy(
+        loss_on_generated = tf.compat.v1.losses.sigmoid_cross_entropy(
             tf.zeros_like(fake_logit),
             fake_logit,
             generated_weights,
@@ -72,12 +72,12 @@ def relativistic_discriminator_loss(
         )
 
         loss = loss_on_real + loss_on_generated
-        tf.losses.add_loss(loss, loss_collection)
+        tf.compat.v1.losses.add_loss(loss, loss_collection)
 
         if add_summaries:
-            tf.summary.scalar("discriminator_gen_relativistic_loss", loss_on_generated)
-            tf.summary.scalar("discriminator_real_relativistic_loss", loss_on_real)
-            tf.summary.scalar("discriminator_relativistic_loss", loss)
+            tf.compat.v1.summary.scalar("discriminator_gen_relativistic_loss", loss_on_generated)
+            tf.compat.v1.summary.scalar("discriminator_real_relativistic_loss", loss_on_real)
+            tf.compat.v1.summary.scalar("discriminator_relativistic_loss", loss)
 
     return loss
 
@@ -89,8 +89,8 @@ def relativistic_generator_loss(
     real_weights=1.0,
     generated_weights=1.0,
     scope=None,
-    loss_collection=tf.GraphKeys.LOSSES,
-    reduction=tf.losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
+    loss_collection=tf.compat.v1.GraphKeys.LOSSES,
+    reduction=tf.compat.v1.losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
     add_summaries=False,
     confusion_labels=False,
 ):
@@ -116,7 +116,7 @@ def relativistic_generator_loss(
   Returns:
     A loss Tensor. The shape depends on `reduction`.
   """
-    with tf.name_scope(
+    with tf.compat.v1.name_scope(
         scope,
         "generator_relativistic_loss",
         (
@@ -129,10 +129,10 @@ def relativistic_generator_loss(
     ) as scope:
 
         real_logit = discriminator_real_outputs - tf.reduce_mean(
-            discriminator_gen_outputs
+            input_tensor=discriminator_gen_outputs
         )
         fake_logit = discriminator_gen_outputs - tf.reduce_mean(
-            discriminator_real_outputs
+            input_tensor=discriminator_real_outputs
         )
 
         if confusion_labels:
@@ -142,7 +142,7 @@ def relativistic_generator_loss(
             real_labels = tf.zeros_like(real_logit)
             fake_labels = tf.ones_like(fake_logit)
 
-        loss_on_real = tf.losses.sigmoid_cross_entropy(
+        loss_on_real = tf.compat.v1.losses.sigmoid_cross_entropy(
             real_labels,
             real_logit,
             real_weights,
@@ -151,7 +151,7 @@ def relativistic_generator_loss(
             loss_collection=None,
             reduction=reduction,
         )
-        loss_on_generated = tf.losses.sigmoid_cross_entropy(
+        loss_on_generated = tf.compat.v1.losses.sigmoid_cross_entropy(
             fake_labels,
             fake_logit,
             generated_weights,
@@ -161,11 +161,11 @@ def relativistic_generator_loss(
         )
 
         loss = loss_on_real + loss_on_generated
-        tf.losses.add_loss(loss, loss_collection)
+        tf.compat.v1.losses.add_loss(loss, loss_collection)
 
         if add_summaries:
-            tf.summary.scalar("generator_gen_relativistic_loss", loss_on_generated)
-            tf.summary.scalar("generator_real_relativistic_loss", loss_on_real)
-            tf.summary.scalar("generator_relativistic_loss", loss)
+            tf.compat.v1.summary.scalar("generator_gen_relativistic_loss", loss_on_generated)
+            tf.compat.v1.summary.scalar("generator_real_relativistic_loss", loss_on_real)
+            tf.compat.v1.summary.scalar("generator_relativistic_loss", loss)
 
     return loss
diff --git a/bob/learn/tensorflow/image/filter.py b/bob/learn/tensorflow/image/filter.py
index 3ac149db3113cf5166d46fe5c4ca80ed11052c2c..f77fbcd18896c070e10af8daff020a2e34337bab 100644
--- a/bob/learn/tensorflow/image/filter.py
+++ b/bob/learn/tensorflow/image/filter.py
@@ -5,13 +5,13 @@ def gaussian_kernel(size: int, mean: float, std: float):
     """Makes 2D gaussian Kernel for convolution.
     Code adapted from: https://stackoverflow.com/a/52012658/1286165"""
 
-    d = tf.distributions.Normal(mean, std)
+    d = tf.compat.v1.distributions.Normal(mean, std)
 
     vals = d.prob(tf.range(start=-size, limit=size + 1, dtype=tf.float32))
 
     gauss_kernel = tf.einsum("i,j->ij", vals, vals)
 
-    return gauss_kernel / tf.reduce_sum(gauss_kernel)
+    return gauss_kernel / tf.reduce_sum(input_tensor=gauss_kernel)
 
 
 class GaussianFilter:
@@ -25,13 +25,13 @@ class GaussianFilter:
         self.gauss_kernel = gaussian_kernel(size, mean, std)[:, :, None, None]
 
     def __call__(self, image):
-        shape = tf.shape(image)
+        shape = tf.shape(input=image)
         image = tf.reshape(image, [-1, shape[-3], shape[-2], shape[-1]])
         input_channels = shape[-1]
         gauss_kernel = tf.tile(self.gauss_kernel, [1, 1, input_channels, 1])
         return tf.nn.depthwise_conv2d(
-            image,
-            gauss_kernel,
+            input=image,
+            filter=gauss_kernel,
             strides=[1, 1, 1, 1],
             padding="SAME",
             data_format="NHWC",
diff --git a/bob/learn/tensorflow/loss/BaseLoss.py b/bob/learn/tensorflow/loss/BaseLoss.py
index ede28207485a00ad6a0b3d978e9a050d246b420f..7cf4193aed854c3e62e669afaef86ccdf431a6f7 100644
--- a/bob/learn/tensorflow/loss/BaseLoss.py
+++ b/bob/learn/tensorflow/loss/BaseLoss.py
@@ -19,18 +19,18 @@ def mean_cross_entropy_loss(logits, labels, add_regularization_losses=True):
 
     """
 
-    with tf.variable_scope('cross_entropy_loss'):
+    with tf.compat.v1.variable_scope('cross_entropy_loss'):
         cross_loss = tf.reduce_mean(
-            tf.nn.sparse_softmax_cross_entropy_with_logits(
+            input_tensor=tf.nn.sparse_softmax_cross_entropy_with_logits(
                 logits=logits, labels=labels),
             name="cross_entropy_loss")
 
-        tf.summary.scalar('cross_entropy_loss', cross_loss)
-        tf.add_to_collection(tf.GraphKeys.LOSSES, cross_loss)
+        tf.compat.v1.summary.scalar('cross_entropy_loss', cross_loss)
+        tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.LOSSES, cross_loss)
 
         if add_regularization_losses:
-            regularization_losses = tf.get_collection(
-                tf.GraphKeys.REGULARIZATION_LOSSES)
+            regularization_losses = tf.compat.v1.get_collection(
+                tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES)
 
             total_loss = tf.add_n(
                 [cross_loss] + regularization_losses, name="total_loss")
@@ -59,41 +59,41 @@ def mean_cross_entropy_center_loss(logits,
 
     """
     # Cross entropy
-    with tf.variable_scope('cross_entropy_loss'):
+    with tf.compat.v1.variable_scope('cross_entropy_loss'):
         cross_loss = tf.reduce_mean(
-            tf.nn.sparse_softmax_cross_entropy_with_logits(
+            input_tensor=tf.nn.sparse_softmax_cross_entropy_with_logits(
                 logits=logits, labels=labels),
             name="cross_entropy_loss")
-        tf.add_to_collection(tf.GraphKeys.LOSSES, cross_loss)
-        tf.summary.scalar('loss_cross_entropy', cross_loss)
+        tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.LOSSES, cross_loss)
+        tf.compat.v1.summary.scalar('loss_cross_entropy', cross_loss)
 
     # Appending center loss
-    with tf.variable_scope('center_loss'):
+    with tf.compat.v1.variable_scope('center_loss'):
         n_features = prelogits.get_shape()[1]
 
-        centers = tf.get_variable(
+        centers = tf.compat.v1.get_variable(
             'centers', [n_classes, n_features],
             dtype=tf.float32,
-            initializer=tf.constant_initializer(0),
+            initializer=tf.compat.v1.constant_initializer(0),
             trainable=False)
 
         # label = tf.reshape(labels, [-1])
         centers_batch = tf.gather(centers, labels)
         diff = (1 - alpha) * (centers_batch - prelogits)
-        centers = tf.scatter_sub(centers, labels, diff)
-        center_loss = tf.reduce_mean(tf.square(prelogits - centers_batch))
-        tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES,
+        centers = tf.compat.v1.scatter_sub(centers, labels, diff)
+        center_loss = tf.reduce_mean(input_tensor=tf.square(prelogits - centers_batch))
+        tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES,
                              center_loss * factor)
-        tf.summary.scalar('loss_center', center_loss)
+        tf.compat.v1.summary.scalar('loss_center', center_loss)
 
     # Adding the regularizers in the loss
-    with tf.variable_scope('total_loss'):
-        regularization_losses = tf.get_collection(
-            tf.GraphKeys.REGULARIZATION_LOSSES)
+    with tf.compat.v1.variable_scope('total_loss'):
+        regularization_losses = tf.compat.v1.get_collection(
+            tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES)
         total_loss = tf.add_n(
             [cross_loss] + regularization_losses, name="total_loss")
-        tf.add_to_collection(tf.GraphKeys.LOSSES, total_loss)
-        tf.summary.scalar('loss_total', total_loss)
+        tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.LOSSES, total_loss)
+        tf.compat.v1.summary.scalar('loss_total', total_loss)
 
     loss = dict()
     loss['loss'] = total_loss
diff --git a/bob/learn/tensorflow/loss/ContrastiveLoss.py b/bob/learn/tensorflow/loss/ContrastiveLoss.py
index 6fa29f1aefe140d6e09de5c876958735f5e5508b..d484b025bb7658403ae106df15444729016595cc 100644
--- a/bob/learn/tensorflow/loss/ContrastiveLoss.py
+++ b/bob/learn/tensorflow/loss/ContrastiveLoss.py
@@ -36,35 +36,35 @@ def contrastive_loss(left_embedding, right_embedding, labels, contrastive_margin
 
     """
 
-    with tf.name_scope("contrastive_loss"):
-        labels = tf.to_float(labels)
+    with tf.compat.v1.name_scope("contrastive_loss"):
+        labels = tf.cast(labels, dtype=tf.float32)
 
         left_embedding = tf.nn.l2_normalize(left_embedding, 1)
         right_embedding = tf.nn.l2_normalize(right_embedding, 1)
 
         d = compute_euclidean_distance(left_embedding, right_embedding)
 
-        with tf.name_scope("within_class"):
+        with tf.compat.v1.name_scope("within_class"):
             one = tf.constant(1.0)
             within_class = tf.multiply(one - labels, tf.square(d))  # (1-Y)*(d^2)
-            within_class_loss = tf.reduce_mean(within_class, name="within_class")
-            tf.add_to_collection(tf.GraphKeys.LOSSES, within_class_loss)
+            within_class_loss = tf.reduce_mean(input_tensor=within_class, name="within_class")
+            tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.LOSSES, within_class_loss)
 
-        with tf.name_scope("between_class"):
+        with tf.compat.v1.name_scope("between_class"):
             max_part = tf.square(tf.maximum(contrastive_margin - d, 0))
             between_class = tf.multiply(
                 labels, max_part
             )  # (Y) * max((margin - d)^2, 0)
-            between_class_loss = tf.reduce_mean(between_class, name="between_class")
-            tf.add_to_collection(tf.GraphKeys.LOSSES, between_class_loss)
+            between_class_loss = tf.reduce_mean(input_tensor=between_class, name="between_class")
+            tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.LOSSES, between_class_loss)
 
-        with tf.name_scope("total_loss"):
+        with tf.compat.v1.name_scope("total_loss"):
             loss = 0.5 * (within_class + between_class)
-            loss = tf.reduce_mean(loss, name="contrastive_loss")
-            tf.add_to_collection(tf.GraphKeys.LOSSES, loss)
+            loss = tf.reduce_mean(input_tensor=loss, name="contrastive_loss")
+            tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.LOSSES, loss)
 
-        tf.summary.scalar("contrastive_loss", loss)
-        tf.summary.scalar("between_class", between_class_loss)
-        tf.summary.scalar("within_class", within_class_loss)
+        tf.compat.v1.summary.scalar("contrastive_loss", loss)
+        tf.compat.v1.summary.scalar("between_class", between_class_loss)
+        tf.compat.v1.summary.scalar("within_class", within_class_loss)
 
         return loss
diff --git a/bob/learn/tensorflow/loss/TripletLoss.py b/bob/learn/tensorflow/loss/TripletLoss.py
index d2616d6aa0394f1ea024b65b9199292221951ca4..14dc4e98e0e16db2f4d67d862df7bc89993dab39 100644
--- a/bob/learn/tensorflow/loss/TripletLoss.py
+++ b/bob/learn/tensorflow/loss/TripletLoss.py
@@ -38,7 +38,7 @@ def triplet_loss(anchor_embedding,
 
     """
 
-    with tf.name_scope("triplet_loss"):
+    with tf.compat.v1.name_scope("triplet_loss"):
         # Normalize
         anchor_embedding = tf.nn.l2_normalize(
             anchor_embedding, 1, 1e-10, name="anchor")
@@ -48,28 +48,28 @@ def triplet_loss(anchor_embedding,
             negative_embedding, 1, 1e-10, name="negative")
 
         d_positive = tf.reduce_sum(
-            tf.square(tf.subtract(anchor_embedding, positive_embedding)), 1)
+            input_tensor=tf.square(tf.subtract(anchor_embedding, positive_embedding)), axis=1)
         d_negative = tf.reduce_sum(
-            tf.square(tf.subtract(anchor_embedding, negative_embedding)), 1)
+            input_tensor=tf.square(tf.subtract(anchor_embedding, negative_embedding)), axis=1)
 
         basic_loss = tf.add(tf.subtract(d_positive, d_negative), margin)
 
-        with tf.name_scope("TripletLoss"):
+        with tf.compat.v1.name_scope("TripletLoss"):
             # Between
-            between_class_loss = tf.reduce_mean(d_negative)
-            tf.summary.scalar('loss_between_class', between_class_loss)
-            tf.add_to_collection(tf.GraphKeys.LOSSES, between_class_loss)
+            between_class_loss = tf.reduce_mean(input_tensor=d_negative)
+            tf.compat.v1.summary.scalar('loss_between_class', between_class_loss)
+            tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.LOSSES, between_class_loss)
 
             # Within
-            within_class_loss = tf.reduce_mean(d_positive)
-            tf.summary.scalar('loss_within_class', within_class_loss)
-            tf.add_to_collection(tf.GraphKeys.LOSSES, within_class_loss)
+            within_class_loss = tf.reduce_mean(input_tensor=d_positive)
+            tf.compat.v1.summary.scalar('loss_within_class', within_class_loss)
+            tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.LOSSES, within_class_loss)
 
             # Total loss
             loss = tf.reduce_mean(
-                tf.maximum(basic_loss, 0.0), 0, name="total_loss")
-            tf.add_to_collection(tf.GraphKeys.LOSSES, loss)
-            tf.summary.scalar('loss_triplet', loss)
+                input_tensor=tf.maximum(basic_loss, 0.0), axis=0, name="total_loss")
+            tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.LOSSES, loss)
+            tf.compat.v1.summary.scalar('loss_triplet', loss)
 
         return loss
 
@@ -77,7 +77,7 @@ def triplet_loss(anchor_embedding,
 def triplet_fisher_loss(anchor_embedding, positive_embedding,
                         negative_embedding):
 
-    with tf.name_scope("triplet_loss"):
+    with tf.compat.v1.name_scope("triplet_loss"):
         # Normalize
         anchor_embedding = tf.nn.l2_normalize(
             anchor_embedding, 1, 1e-10, name="anchor")
@@ -86,9 +86,9 @@ def triplet_fisher_loss(anchor_embedding, positive_embedding,
         negative_embedding = tf.nn.l2_normalize(
             negative_embedding, 1, 1e-10, name="negative")
 
-        average_class = tf.reduce_mean(anchor_embedding, 0)
-        average_total = tf.div(tf.add(tf.reduce_mean(anchor_embedding, axis=0),\
-                        tf.reduce_mean(negative_embedding, axis=0)), 2)
+        average_class = tf.reduce_mean(input_tensor=anchor_embedding, axis=0)
+        average_total = tf.compat.v1.div(tf.add(tf.reduce_mean(input_tensor=anchor_embedding, axis=0),\
+                        tf.reduce_mean(input_tensor=negative_embedding, axis=0)), 2)
 
         length = anchor_embedding.get_shape().as_list()[0]
         dim = anchor_embedding.get_shape().as_list()[1]
@@ -121,9 +121,9 @@ def triplet_fisher_loss(anchor_embedding, positive_embedding,
         # Sw = tf.trace(Sw)
         # Sb = tf.trace(Sb)
         #loss = tf.trace(tf.div(Sb, Sw))
-        loss = tf.trace(tf.div(Sw, Sb), name=tf.GraphKeys.LOSSES)
+        loss = tf.linalg.trace(tf.compat.v1.div(Sw, Sb), name=tf.compat.v1.GraphKeys.LOSSES)
 
-        return loss, tf.trace(Sb), tf.trace(Sw)
+        return loss, tf.linalg.trace(Sb), tf.linalg.trace(Sw)
 
 
 def triplet_average_loss(anchor_embedding,
@@ -155,7 +155,7 @@ def triplet_average_loss(anchor_embedding,
 
     """
 
-    with tf.name_scope("triplet_loss"):
+    with tf.compat.v1.name_scope("triplet_loss"):
         # Normalize
         anchor_embedding = tf.nn.l2_normalize(
             anchor_embedding, 1, 1e-10, name="anchor")
@@ -164,17 +164,17 @@ def triplet_average_loss(anchor_embedding,
         negative_embedding = tf.nn.l2_normalize(
             negative_embedding, 1, 1e-10, name="negative")
 
-        anchor_mean = tf.reduce_mean(anchor_embedding, 0)
+        anchor_mean = tf.reduce_mean(input_tensor=anchor_embedding, axis=0)
 
         d_positive = tf.reduce_sum(
-            tf.square(tf.subtract(anchor_mean, positive_embedding)), 1)
+            input_tensor=tf.square(tf.subtract(anchor_mean, positive_embedding)), axis=1)
         d_negative = tf.reduce_sum(
-            tf.square(tf.subtract(anchor_mean, negative_embedding)), 1)
+            input_tensor=tf.square(tf.subtract(anchor_mean, negative_embedding)), axis=1)
 
         basic_loss = tf.add(tf.subtract(d_positive, d_negative), margin)
         loss = tf.reduce_mean(
-            tf.maximum(basic_loss, 0.0), 0, name=tf.GraphKeys.LOSSES)
+            input_tensor=tf.maximum(basic_loss, 0.0), axis=0, name=tf.compat.v1.GraphKeys.LOSSES)
 
-        return loss, tf.reduce_mean(d_negative), tf.reduce_mean(d_positive)
+        return loss, tf.reduce_mean(input_tensor=d_negative), tf.reduce_mean(input_tensor=d_positive)
 
 
diff --git a/bob/learn/tensorflow/loss/center_loss.py b/bob/learn/tensorflow/loss/center_loss.py
index 00494387c11fb4c13bacc0b4d43e374f34ae7b01..553f01e95926e28a9c44935ee56200eaa0a915d9 100644
--- a/bob/learn/tensorflow/loss/center_loss.py
+++ b/bob/learn/tensorflow/loss/center_loss.py
@@ -13,25 +13,25 @@ class CenterLoss:
         self.n_features = n_features
         self.alpha = alpha
         self.name = name
-        with tf.variable_scope(self.name):
-            self.centers = tf.get_variable(
+        with tf.compat.v1.variable_scope(self.name):
+            self.centers = tf.compat.v1.get_variable(
                 "centers",
                 [n_classes, n_features],
                 dtype=tf.float32,
-                initializer=tf.constant_initializer(0.),
+                initializer=tf.compat.v1.constant_initializer(0.),
                 trainable=False,
             )
 
     def __call__(self, sparse_labels, prelogits):
-        with tf.name_scope(self.name):
+        with tf.compat.v1.name_scope(self.name):
             centers_batch = tf.gather(self.centers, sparse_labels)
             diff = (1 - self.alpha) * (centers_batch - prelogits)
-            self.centers_update_op = tf.scatter_sub(self.centers, sparse_labels, diff)
-            center_loss = tf.reduce_mean(tf.square(prelogits - centers_batch))
-        tf.summary.scalar("loss_center", center_loss)
+            self.centers_update_op = tf.compat.v1.scatter_sub(self.centers, sparse_labels, diff)
+            center_loss = tf.reduce_mean(input_tensor=tf.square(prelogits - centers_batch))
+        tf.compat.v1.summary.scalar("loss_center", center_loss)
         # Add histogram for all centers
         for i in range(self.n_classes):
-            tf.summary.histogram(f"center_{i}", self.centers[i])
+            tf.compat.v1.summary.histogram(f"center_{i}", self.centers[i])
         return center_loss
 
     @property
diff --git a/bob/learn/tensorflow/loss/epsc.py b/bob/learn/tensorflow/loss/epsc.py
index cfadb012ca54b73d92ffbab3ab63faaf4345db65..05627d65b21e60287b20b8bc5e6878f87d6f1604 100644
--- a/bob/learn/tensorflow/loss/epsc.py
+++ b/bob/learn/tensorflow/loss/epsc.py
@@ -10,76 +10,76 @@ def logits_loss(
     bio_logits, pad_logits, bio_labels, pad_labels, bio_loss, pad_loss, alpha=0.5
 ):
 
-    with tf.name_scope("Bio_loss"):
+    with tf.compat.v1.name_scope("Bio_loss"):
         bio_loss_ = bio_loss(logits=bio_logits, labels=bio_labels)
 
-    with tf.name_scope("PAD_loss"):
+    with tf.compat.v1.name_scope("PAD_loss"):
         pad_loss_ = pad_loss(
             logits=pad_logits, labels=tf.cast(pad_labels, dtype="int32")
         )
 
-    with tf.name_scope("EPSC_loss"):
+    with tf.compat.v1.name_scope("EPSC_loss"):
         total_loss = (1 - alpha) * bio_loss_ + alpha * pad_loss_
 
-    tf.add_to_collection(tf.GraphKeys.LOSSES, bio_loss_)
-    tf.add_to_collection(tf.GraphKeys.LOSSES, pad_loss_)
-    tf.add_to_collection(tf.GraphKeys.LOSSES, total_loss)
+    tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.LOSSES, bio_loss_)
+    tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.LOSSES, pad_loss_)
+    tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.LOSSES, total_loss)
 
-    tf.summary.scalar("bio_loss", bio_loss_)
-    tf.summary.scalar("pad_loss", pad_loss_)
-    tf.summary.scalar("epsc_loss", total_loss)
+    tf.compat.v1.summary.scalar("bio_loss", bio_loss_)
+    tf.compat.v1.summary.scalar("pad_loss", pad_loss_)
+    tf.compat.v1.summary.scalar("epsc_loss", total_loss)
 
     return total_loss
 
 
 def embedding_norm_loss(prelogits_left, prelogits_right, b, c, margin=10.0):
-    with tf.name_scope("embedding_norm_loss"):
+    with tf.compat.v1.name_scope("embedding_norm_loss"):
         prelogits_left = norm(prelogits_left)
         prelogits_right = norm(prelogits_right)
 
         loss = tf.add_n(
             [
-                tf.reduce_mean(b * (tf.maximum(prelogits_left - margin, 0))),
-                tf.reduce_mean((1 - b) * (tf.maximum(2 * margin - prelogits_left, 0))),
-                tf.reduce_mean(c * (tf.maximum(prelogits_right - margin, 0))),
-                tf.reduce_mean((1 - c) * (tf.maximum(2 * margin - prelogits_right, 0))),
+                tf.reduce_mean(input_tensor=b * (tf.maximum(prelogits_left - margin, 0))),
+                tf.reduce_mean(input_tensor=(1 - b) * (tf.maximum(2 * margin - prelogits_left, 0))),
+                tf.reduce_mean(input_tensor=c * (tf.maximum(prelogits_right - margin, 0))),
+                tf.reduce_mean(input_tensor=(1 - c) * (tf.maximum(2 * margin - prelogits_right, 0))),
             ],
             name="embedding_norm_loss",
         )
-        tf.add_to_collection(tf.GraphKeys.LOSSES, loss)
-        tf.summary.scalar("embedding_norm_loss", loss)
+        tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.LOSSES, loss)
+        tf.compat.v1.summary.scalar("embedding_norm_loss", loss)
         # log norm of embeddings for BF and PA separately to see how their norm
         # evolves over time
         bf_norm = tf.concat(
             [
-                tf.gather(prelogits_left, tf.where(b > 0.5)),
-                tf.gather(prelogits_right, tf.where(c > 0.5)),
+                tf.gather(prelogits_left, tf.compat.v1.where(b > 0.5)),
+                tf.gather(prelogits_right, tf.compat.v1.where(c > 0.5)),
             ],
             axis=0,
         )
         pa_norm = tf.concat(
             [
-                tf.gather(prelogits_left, tf.where(b < 0.5)),
-                tf.gather(prelogits_right, tf.where(c < 0.5)),
+                tf.gather(prelogits_left, tf.compat.v1.where(b < 0.5)),
+                tf.gather(prelogits_right, tf.compat.v1.where(c < 0.5)),
             ],
             axis=0,
         )
-        tf.summary.histogram("BF_embeddings_norm", bf_norm)
-        tf.summary.histogram("PA_embeddings_norm", pa_norm)
+        tf.compat.v1.summary.histogram("BF_embeddings_norm", bf_norm)
+        tf.compat.v1.summary.histogram("PA_embeddings_norm", pa_norm)
     return loss
 
 
 def siamese_loss(bio_logits, pad_logits, bio_labels, pad_labels, alpha=0.1):
     # prepare a, b, c
-    with tf.name_scope("epsc_labels"):
-        a = tf.to_float(
-            tf.math.equal(bio_labels["left"], bio_labels["right"]), name="a"
+    with tf.compat.v1.name_scope("epsc_labels"):
+        a = tf.cast(
+            tf.math.equal(bio_labels["left"], bio_labels["right"]), dtype=tf.float32, name="a"
         )
-        b = tf.to_float(tf.math.equal(pad_labels["left"], True), name="b")
-        c = tf.to_float(tf.math.equal(pad_labels["right"], True), name="c")
-        tf.summary.scalar("Mean_a", tf.reduce_mean(a))
-        tf.summary.scalar("Mean_b", tf.reduce_mean(b))
-        tf.summary.scalar("Mean_c", tf.reduce_mean(c))
+        b = tf.cast(tf.math.equal(pad_labels["left"], True), name="b", dtype=tf.float32)
+        c = tf.cast(tf.math.equal(pad_labels["right"], True), name="c", dtype=tf.float32)
+        tf.compat.v1.summary.scalar("Mean_a", tf.reduce_mean(input_tensor=a))
+        tf.compat.v1.summary.scalar("Mean_b", tf.reduce_mean(input_tensor=b))
+        tf.compat.v1.summary.scalar("Mean_c", tf.reduce_mean(input_tensor=c))
 
     prelogits_left = bio_logits["left"]
     prelogits_right = bio_logits["right"]
@@ -88,11 +88,11 @@ def siamese_loss(bio_logits, pad_logits, bio_labels, pad_labels, alpha=0.1):
 
     pad_loss = alpha * embedding_norm_loss(prelogits_left, prelogits_right, b, c)
 
-    with tf.name_scope("epsc_loss"):
+    with tf.compat.v1.name_scope("epsc_loss"):
         epsc_loss = (1 - alpha) * bio_loss + alpha * pad_loss
-        tf.add_to_collection(tf.GraphKeys.LOSSES, epsc_loss)
+        tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.LOSSES, epsc_loss)
 
-    tf.summary.scalar("epsc_loss", epsc_loss)
+    tf.compat.v1.summary.scalar("epsc_loss", epsc_loss)
 
     return epsc_loss
 
@@ -106,7 +106,7 @@ def py_eer(negatives, positives):
     negatives = tf.reshape(tf.cast(negatives, "float64"), [-1])
     positives = tf.reshape(tf.cast(positives, "float64"), [-1])
 
-    eer = tf.py_func(_eer, [negatives, positives], tf.float64, name="py_eer")
+    eer = tf.compat.v1.py_func(_eer, [negatives, positives], tf.float64, name="py_eer")
 
     return tf.cast(eer, "float32")
 
@@ -121,7 +121,7 @@ def epsc_metric(
 ):
     # math.exp(-2.0) = 0.1353352832366127
     # math.exp(-15.0) = 3.059023205018258e-07
-    with tf.name_scope("epsc_metrics"):
+    with tf.compat.v1.name_scope("epsc_metrics"):
         bio_predictions_op = predict_using_tensors(
             bio_embeddings, bio_labels, num=batch_size
         )
@@ -153,24 +153,24 @@ def epsc_metric(
         # update_op = tf.assign_add(pad_accuracy, tf.cast(acc, tf.float32))
         # update_op = tf.group([update_op] + print_ops)
 
-        tp = tf.metrics.true_positives_at_thresholds(
+        tp = tf.compat.v1.metrics.true_positives_at_thresholds(
             pad_labels, pad_probabilities, [pad_threshold]
         )
-        fp = tf.metrics.false_positives_at_thresholds(
+        fp = tf.compat.v1.metrics.false_positives_at_thresholds(
             pad_labels, pad_probabilities, [pad_threshold]
         )
-        tn = tf.metrics.true_negatives_at_thresholds(
+        tn = tf.compat.v1.metrics.true_negatives_at_thresholds(
             pad_labels, pad_probabilities, [pad_threshold]
         )
-        fn = tf.metrics.false_negatives_at_thresholds(
+        fn = tf.compat.v1.metrics.false_negatives_at_thresholds(
             pad_labels, pad_probabilities, [pad_threshold]
         )
         pad_accuracy = (tp[0] + tn[0]) / (tp[0] + tn[0] + fp[0] + fn[0])
-        pad_accuracy = tf.reduce_mean(pad_accuracy)
+        pad_accuracy = tf.reduce_mean(input_tensor=pad_accuracy)
         pad_update_ops = tf.group([x[1] for x in (tp, tn, fp, fn)])
 
         eval_metric_ops = {
-            "bio_accuracy": tf.metrics.accuracy(
+            "bio_accuracy": tf.compat.v1.metrics.accuracy(
                 labels=bio_labels, predictions=bio_predictions_op
             ),
             "pad_accuracy": (pad_accuracy, pad_update_ops),
diff --git a/bob/learn/tensorflow/loss/mmd.py b/bob/learn/tensorflow/loss/mmd.py
index 2933d7b1d3cd32b7533c2fa5213e38eeb7192965..2a0efff5fcb68cbda78d8c5e8b1a6d5c4159bf0f 100644
--- a/bob/learn/tensorflow/loss/mmd.py
+++ b/bob/learn/tensorflow/loss/mmd.py
@@ -4,9 +4,9 @@ import tensorflow as tf
 def compute_kernel(x, y):
     """Gaussian kernel.
     """
-    x_size = tf.shape(x)[0]
-    y_size = tf.shape(y)[0]
-    dim = tf.shape(x)[1]
+    x_size = tf.shape(input=x)[0]
+    y_size = tf.shape(input=y)[0]
+    dim = tf.shape(input=x)[1]
     tiled_x = tf.tile(
         tf.reshape(x, tf.stack([x_size, 1, dim])), tf.stack([1, y_size, 1])
     )
@@ -14,7 +14,7 @@ def compute_kernel(x, y):
         tf.reshape(y, tf.stack([1, y_size, dim])), tf.stack([x_size, 1, 1])
     )
     return tf.exp(
-        -tf.reduce_mean(tf.square(tiled_x - tiled_y), axis=2) / tf.cast(dim, tf.float32)
+        -tf.reduce_mean(input_tensor=tf.square(tiled_x - tiled_y), axis=2) / tf.cast(dim, tf.float32)
     )
 
 
@@ -26,7 +26,7 @@ def mmd(x, y):
     y_kernel = compute_kernel(y, y)
     xy_kernel = compute_kernel(x, y)
     return (
-        tf.reduce_mean(x_kernel)
-        + tf.reduce_mean(y_kernel)
-        - 2 * tf.reduce_mean(xy_kernel)
+        tf.reduce_mean(input_tensor=x_kernel)
+        + tf.reduce_mean(input_tensor=y_kernel)
+        - 2 * tf.reduce_mean(input_tensor=xy_kernel)
     )
diff --git a/bob/learn/tensorflow/loss/pairwise_confusion.py b/bob/learn/tensorflow/loss/pairwise_confusion.py
index 155b1a299625283ade232af1c84d13e146cceff1..b4c319711733e0631cece2bf25a7e75dcbac687e 100644
--- a/bob/learn/tensorflow/loss/pairwise_confusion.py
+++ b/bob/learn/tensorflow/loss/pairwise_confusion.py
@@ -8,9 +8,9 @@ def total_pairwise_confusion(prelogits, name=None):
         Representations for Face Anti-Spoofing,” arXiv preprint arXiv:1901.05602, 2019.
     """
     # compute L2 norm between all prelogits and sum them.
-    with tf.name_scope(name, default_name="total_pairwise_confusion"):
-        prelogits = tf.reshape(prelogits, (tf.shape(prelogits)[0], -1))
-        loss_tpc = tf.reduce_mean(upper_triangle(pdist_safe(prelogits)))
+    with tf.compat.v1.name_scope(name, default_name="total_pairwise_confusion"):
+        prelogits = tf.reshape(prelogits, (tf.shape(input=prelogits)[0], -1))
+        loss_tpc = tf.reduce_mean(input_tensor=upper_triangle(pdist_safe(prelogits)))
 
-    tf.summary.scalar("loss_tpc", loss_tpc)
+    tf.compat.v1.summary.scalar("loss_tpc", loss_tpc)
     return loss_tpc
diff --git a/bob/learn/tensorflow/loss/pixel_wise.py b/bob/learn/tensorflow/loss/pixel_wise.py
index b34695045c20273bdc0063a928ceb723324eca6d..4a11e65630abc7c2113853ef3d733595dff33875 100644
--- a/bob/learn/tensorflow/loss/pixel_wise.py
+++ b/bob/learn/tensorflow/loss/pixel_wise.py
@@ -18,7 +18,7 @@ class PixelWise:
         self.label_smoothing = label_smoothing
 
     def __call__(self, labels, logits):
-        with tf.name_scope("PixelWiseLoss"):
+        with tf.compat.v1.name_scope("PixelWiseLoss"):
             flatten = tf.keras.layers.Flatten()
             logits = flatten(logits)
             n_pixels = logits.get_shape()[-1]
@@ -45,19 +45,19 @@ class PixelWise:
                 # reshape logits too as softmax_cross_entropy is buggy and cannot really
                 # handle higher dimensions
                 logits = tf.reshape(logits, (-1, self.n_one_hot_labels))
-                loss_fn = tf.losses.softmax_cross_entropy
+                loss_fn = tf.compat.v1.losses.softmax_cross_entropy
             else:
                 labels = tf.reshape(labels, (-1, 1))
                 labels = tf_repeat(labels, [n_pixels, 1])
                 labels = tf.reshape(labels, (-1, n_pixels))
-                loss_fn = tf.losses.sigmoid_cross_entropy
+                loss_fn = tf.compat.v1.losses.sigmoid_cross_entropy
 
             loss_pixel_wise = loss_fn(
                 labels,
                 logits=logits,
                 weights=weights,
                 label_smoothing=self.label_smoothing,
-                reduction=tf.losses.Reduction.MEAN,
+                reduction=tf.compat.v1.losses.Reduction.MEAN,
             )
-        tf.summary.scalar("loss_pixel_wise", loss_pixel_wise)
+        tf.compat.v1.summary.scalar("loss_pixel_wise", loss_pixel_wise)
         return loss_pixel_wise
diff --git a/bob/learn/tensorflow/loss/utils.py b/bob/learn/tensorflow/loss/utils.py
index 8c0e9eeaa7d1f059216f441017469c2701531426..aad477ed5747cfe102015e77742ce93220073cf4 100644
--- a/bob/learn/tensorflow/loss/utils.py
+++ b/bob/learn/tensorflow/loss/utils.py
@@ -82,11 +82,11 @@ def balanced_softmax_cross_entropy_loss_weights(labels, dtype="float32"):
     >>> #weights = balanced_softmax_cross_entropy_loss_weights(labels, dtype=logits.dtype)
     >>> #loss = tf.losses.softmax_cross_entropy(logits=logits, labels=labels, weights=weights)
     """
-    shape = tf.cast(tf.shape(labels), dtype=dtype)
+    shape = tf.cast(tf.shape(input=labels), dtype=dtype)
     batch_size, n_classes = shape[0], shape[1]
-    weights = tf.cast(tf.reduce_sum(labels, axis=0), dtype=dtype)
+    weights = tf.cast(tf.reduce_sum(input_tensor=labels, axis=0), dtype=dtype)
     weights = batch_size / weights / n_classes
-    weights = tf.gather(weights, tf.argmax(labels, axis=1))
+    weights = tf.gather(weights, tf.argmax(input=labels, axis=1))
     return weights
 
 
@@ -136,9 +136,9 @@ def balanced_sigmoid_cross_entropy_loss_weights(labels, dtype="float32"):
     >>> #loss = tf.losses.sigmoid_cross_entropy(logits=logits, labels=labels, weights=weights)
     """
     labels = tf.cast(labels, dtype='int32')
-    batch_size = tf.cast(tf.shape(labels)[0], dtype=dtype)
-    weights = tf.cast(tf.reduce_sum(labels), dtype=dtype)
-    weights = tf.convert_to_tensor([batch_size - weights, weights])
+    batch_size = tf.cast(tf.shape(input=labels)[0], dtype=dtype)
+    weights = tf.cast(tf.reduce_sum(input_tensor=labels), dtype=dtype)
+    weights = tf.convert_to_tensor(value=[batch_size - weights, weights])
     weights = batch_size / weights / 2
     weights = tf.gather(weights, labels)
     return weights
diff --git a/bob/learn/tensorflow/loss/vat.py b/bob/learn/tensorflow/loss/vat.py
index b48f4f8918287a68e71f1647d22cbe1a4a0d2c52..d77ecf605bce000841734b2d5dae8e15da7febf4 100644
--- a/bob/learn/tensorflow/loss/vat.py
+++ b/bob/learn/tensorflow/loss/vat.py
@@ -28,27 +28,27 @@ from functools import partial
 
 
 def get_normalized_vector(d):
-    d /= (1e-12 + tf.reduce_max(tf.abs(d), list(range(1, len(d.get_shape()))), keepdims=True))
-    d /= tf.sqrt(1e-6 + tf.reduce_sum(tf.pow(d, 2.0), list(range(1, len(d.get_shape()))), keepdims=True))
+    d /= (1e-12 + tf.reduce_max(input_tensor=tf.abs(d), axis=list(range(1, len(d.get_shape()))), keepdims=True))
+    d /= tf.sqrt(1e-6 + tf.reduce_sum(input_tensor=tf.pow(d, 2.0), axis=list(range(1, len(d.get_shape()))), keepdims=True))
     return d
 
 
 def logsoftmax(x):
-    xdev = x - tf.reduce_max(x, 1, keepdims=True)
-    lsm = xdev - tf.log(tf.reduce_sum(tf.exp(xdev), 1, keepdims=True))
+    xdev = x - tf.reduce_max(input_tensor=x, axis=1, keepdims=True)
+    lsm = xdev - tf.math.log(tf.reduce_sum(input_tensor=tf.exp(xdev), axis=1, keepdims=True))
     return lsm
 
 
 def kl_divergence_with_logit(q_logit, p_logit):
     q = tf.nn.softmax(q_logit)
-    qlogq = tf.reduce_mean(tf.reduce_sum(q * logsoftmax(q_logit), 1))
-    qlogp = tf.reduce_mean(tf.reduce_sum(q * logsoftmax(p_logit), 1))
+    qlogq = tf.reduce_mean(input_tensor=tf.reduce_sum(input_tensor=q * logsoftmax(q_logit), axis=1))
+    qlogp = tf.reduce_mean(input_tensor=tf.reduce_sum(input_tensor=q * logsoftmax(p_logit), axis=1))
     return qlogq - qlogp
 
 
 def entropy_y_x(logit):
     p = tf.nn.softmax(logit)
-    return -tf.reduce_mean(tf.reduce_sum(p * logsoftmax(logit), 1))
+    return -tf.reduce_mean(input_tensor=tf.reduce_sum(input_tensor=p * logsoftmax(logit), axis=1))
 
 
 class VATLoss:
@@ -106,16 +106,16 @@ class VATLoss:
         if mode != tf.estimator.ModeKeys.TRAIN:
             return 0.
         architecture = partial(architecture, reuse=True)
-        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
+        with tf.compat.v1.variable_scope(tf.compat.v1.get_variable_scope(), reuse=True):
             vat_loss = self.virtual_adversarial_loss(features, logits, architecture, mode)
-            tf.summary.scalar("loss_VAT", vat_loss)
-            tf.add_to_collection(tf.GraphKeys.LOSSES, vat_loss)
+            tf.compat.v1.summary.scalar("loss_VAT", vat_loss)
+            tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.LOSSES, vat_loss)
             if self.method == 'vat':
                 loss = vat_loss
             elif self.method == 'vatent':
                 ent_loss = entropy_y_x(logits)
-                tf.summary.scalar("loss_entropy", ent_loss)
-                tf.add_to_collection(tf.GraphKeys.LOSSES, ent_loss)
+                tf.compat.v1.summary.scalar("loss_entropy", ent_loss)
+                tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.LOSSES, ent_loss)
                 loss = vat_loss + ent_loss
             else:
                 raise ValueError
@@ -125,20 +125,20 @@ class VATLoss:
         r_vadv = self.generate_virtual_adversarial_perturbation(features, logits, architecture, mode)
         logit_p = tf.stop_gradient(logits)
         adversarial_input = features + r_vadv
-        tf.summary.image("Adversarial_Image", adversarial_input)
+        tf.compat.v1.summary.image("Adversarial_Image", adversarial_input)
         logit_m = architecture(adversarial_input, mode=mode)[0]
         loss = kl_divergence_with_logit(logit_p, logit_m)
         return tf.identity(loss, name=name)
 
     def generate_virtual_adversarial_perturbation(self, features, logits, architecture, mode):
-        d = tf.random_normal(shape=tf.shape(features))
+        d = tf.random.normal(shape=tf.shape(input=features))
 
         for _ in range(self.num_power_iterations):
             d = self.xi * get_normalized_vector(d)
             logit_p = logits
             logit_m = architecture(features + d, mode=mode)[0]
             dist = kl_divergence_with_logit(logit_p, logit_m)
-            grad = tf.gradients(dist, [d], aggregation_method=2)[0]
+            grad = tf.gradients(ys=dist, xs=[d], aggregation_method=2)[0]
             d = tf.stop_gradient(grad)
 
         return self.epsilon * get_normalized_vector(d)
diff --git a/bob/learn/tensorflow/models/inception_resnet_v2.py b/bob/learn/tensorflow/models/inception_resnet_v2.py
index 79b1a66d24f59a3788742fe843e08466e3a67bbd..e5711e27ff6e075b271e7f97d7239ea75ac4809b 100644
--- a/bob/learn/tensorflow/models/inception_resnet_v2.py
+++ b/bob/learn/tensorflow/models/inception_resnet_v2.py
@@ -574,7 +574,7 @@ def MultiScaleInceptionResNetV2(
     padding = "SAME" if align_feature_maps else "VALID"
     name = name or "InceptionResnetV2"
 
-    with tf.name_scope(name, "InceptionResnetV2", [img_input]):
+    with tf.compat.v1.name_scope(name, "InceptionResnetV2", [img_input]):
         # convert colors from RGB to a learned color space and batch norm inputs
         # 224, 224, 4
         net = Conv2D_BN(
diff --git a/bob/learn/tensorflow/script/cache_dataset.py b/bob/learn/tensorflow/script/cache_dataset.py
index 7f8c04f39dceaf9a312ed655d1a17882431058e2..e38d7c8a19068efe57a166b22894dcd57ef92123 100644
--- a/bob/learn/tensorflow/script/cache_dataset.py
+++ b/bob/learn/tensorflow/script/cache_dataset.py
@@ -46,13 +46,13 @@ def cache_dataset(input_fn, mode, **kwargs):
         logger.info("cache_only as True will be passed to input_fn.")
 
     # call the input function manually
-    with tf.Session() as sess:
+    with tf.compat.v1.Session() as sess:
         data = input_fn(mode, **kwargs)
         if isinstance(data, tf.data.Dataset):
-            iterator = data.make_initializable_iterator()
+            iterator = tf.compat.v1.data.make_initializable_iterator(data)
             data = iterator.get_next()
             sess.run(iterator.initializer)
-        sess.run(tf.initializers.global_variables())
+        sess.run(tf.compat.v1.initializers.global_variables())
         try:
             while True:
                 sess.run(data)
diff --git a/bob/learn/tensorflow/script/db_to_tfrecords.py b/bob/learn/tensorflow/script/db_to_tfrecords.py
index 7460f7c664ec2726bdeae2bec5ca4bcfc0e11d4a..3c01c3421381890924087cca059a440d20c13247 100644
--- a/bob/learn/tensorflow/script/db_to_tfrecords.py
+++ b/bob/learn/tensorflow/script/db_to_tfrecords.py
@@ -154,7 +154,7 @@ def db_to_tfrecords(
 
     n_samples = len(samples)
     sample_count = 0
-    with tf.python_io.TFRecordWriter(output) as writer:
+    with tf.io.TFRecordWriter(output) as writer:
         if shuffle:
             logger.info("Shuffling the samples before writing ...")
             random.shuffle(samples)
@@ -258,7 +258,7 @@ def datasets_to_tfrecords(dataset, output, force, **kwargs):
         return
 
     click.echo("Writing tfrecod to: {}".format(output))
-    with tf.Session() as sess:
+    with tf.compat.v1.Session() as sess:
         os.makedirs(os.path.dirname(output), exist_ok=True)
         try:
             sess.run(dataset_to_tfrecord(dataset, output))
@@ -297,9 +297,9 @@ def dataset_to_hdf5(dataset, output, mean, **kwargs):
     """
     log_parameters(logger)
 
-    data, label, key = dataset.make_one_shot_iterator().get_next()
+    data, label, key = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()
 
-    sess = tf.Session()
+    sess = tf.compat.v1.Session()
 
     extension = ".hdf5"
 
diff --git a/bob/learn/tensorflow/script/trim.py b/bob/learn/tensorflow/script/trim.py
index d89bb292a75e75b191efc7e9542a083d4e9d9c45..ed29b4c28847499a283d9f9908ba427f8604d8b6 100644
--- a/bob/learn/tensorflow/script/trim.py
+++ b/bob/learn/tensorflow/script/trim.py
@@ -39,7 +39,7 @@ def delete_extra_checkpoints(directory, keep_last_n, dry_run):
     all_paths = filter(_existing, ckpt.all_model_checkpoint_paths)
     all_paths = list(map(os.path.basename, all_paths))
     model_checkpoint_path = os.path.basename(ckpt.model_checkpoint_path)
-    tf.train.update_checkpoint_state(
+    tf.compat.v1.train.update_checkpoint_state(
         directory, model_checkpoint_path, all_paths)
 
 
diff --git a/bob/learn/tensorflow/script/utils.py b/bob/learn/tensorflow/script/utils.py
index 12b360c00dff65ca02d8825ce95fcd68d5b14da3..ca5ba4ef219b093b4706729a37bbc55b71bba508 100644
--- a/bob/learn/tensorflow/script/utils.py
+++ b/bob/learn/tensorflow/script/utils.py
@@ -14,7 +14,7 @@ def eager_execution_option(**kwargs):
             if not value or ctx.resilient_parsing:
                 return
             import tensorflow as tf
-            tf.enable_eager_execution()
+            tf.compat.v1.enable_eager_execution()
             if not tf.executing_eagerly():
                 raise click.ClickException(
                     "Could not enable tensorflow eager execution mode!")
diff --git a/bob/learn/tensorflow/test/data/input_biogenerator_config.py b/bob/learn/tensorflow/test/data/input_biogenerator_config.py
index e4516be98af48040233f26bf42f61c11deec3899..2aca7ccd4cda182bc882f1ca512b65b837a18810 100644
--- a/bob/learn/tensorflow/test/data/input_biogenerator_config.py
+++ b/bob/learn/tensorflow/test/data/input_biogenerator_config.py
@@ -50,7 +50,7 @@ def input_fn(mode):
         dataset = dataset.repeat(epochs)
     dataset = dataset.batch(batch_size)
 
-    data, label, key = dataset.make_one_shot_iterator().get_next()
+    data, label, key = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()
     return {'data': data, 'key': key}, label
 
 
diff --git a/bob/learn/tensorflow/test/data/input_predict_bio_config.py b/bob/learn/tensorflow/test/data/input_predict_bio_config.py
index 2b8687f77bde2f024b213df0a1b49bc2231d5867..d355768b5a84064730878326f879d8f02dd77f2d 100644
--- a/bob/learn/tensorflow/test/data/input_predict_bio_config.py
+++ b/bob/learn/tensorflow/test/data/input_predict_bio_config.py
@@ -11,7 +11,7 @@ def bio_predict_input_fn(generator, output_types, output_shapes):
         # even further if you want.
         dataset = dataset.prefetch(1)
         dataset = dataset.batch(10**3)
-        images, labels, keys = dataset.make_one_shot_iterator().get_next()
+        images, labels, keys = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()
 
         return {'data': images, 'key': keys}, labels
     return input_fn
diff --git a/bob/learn/tensorflow/test/data/mnist_estimator.py b/bob/learn/tensorflow/test/data/mnist_estimator.py
index 224e974c2d8dc9627fd38e0a8fead61318f3b425..9957657681b9c46594647c68302eb592ceadadb4 100644
--- a/bob/learn/tensorflow/test/data/mnist_estimator.py
+++ b/bob/learn/tensorflow/test/data/mnist_estimator.py
@@ -1,3 +1,3 @@
 import tensorflow as tf
 data = tf.feature_column.numeric_column('data', shape=[784])
-estimator = tf.estimator.LinearClassifier(feature_columns=[data], n_classes=10)
+estimator = tf.estimator.LinearClassifier(feature_columns=[data], n_classes=10, loss_reduction=tf.keras.losses.Reduction.SUM)
diff --git a/bob/learn/tensorflow/test/data/mnist_input_fn.py b/bob/learn/tensorflow/test/data/mnist_input_fn.py
index e5bf1f4a058a96b529613a879e396c66a410afa6..0274f5075d5dede7e5d7a98538e05613089733bc 100644
--- a/bob/learn/tensorflow/test/data/mnist_input_fn.py
+++ b/bob/learn/tensorflow/test/data/mnist_input_fn.py
@@ -14,7 +14,7 @@ def input_fn(mode):
         num_epochs = 1
         shuffle = True
     data, labels = database.data(groups=groups)
-    return tf.estimator.inputs.numpy_input_fn(
+    return tf.compat.v1.estimator.inputs.numpy_input_fn(
         x={
             "data": data.astype('float32'),
             'key': labels.astype('float32')
diff --git a/bob/learn/tensorflow/test/data/train_scripts/siamese.py b/bob/learn/tensorflow/test/data/train_scripts/siamese.py
index 3658e71f5103b4a4f7e8f7603dcf51413165edca..57d60a4cf1ab52b0c9a7d0446f2bff645638f395 100644
--- a/bob/learn/tensorflow/test/data/train_scripts/siamese.py
+++ b/bob/learn/tensorflow/test/data/train_scripts/siamese.py
@@ -33,7 +33,7 @@ loss = ContrastiveLoss(contrastive_margin=4.)
 learning_rate = constant(base_learning_rate=0.01)
 
 ### SOLVER ###
-optimizer = tf.train.GradientDescentOptimizer(learning_rate)
+optimizer = tf.compat.v1.train.GradientDescentOptimizer(learning_rate)
 
 ### Trainer ###
 trainer = Trainer
diff --git a/bob/learn/tensorflow/test/data/train_scripts/softmax.py b/bob/learn/tensorflow/test/data/train_scripts/softmax.py
index ed48ececfe66b19d7d5d83dbed6166de0d75f90e..4701f0496d9210b0ab96cd923980ace6cdc1f738 100644
--- a/bob/learn/tensorflow/test/data/train_scripts/softmax.py
+++ b/bob/learn/tensorflow/test/data/train_scripts/softmax.py
@@ -32,7 +32,7 @@ loss = MeanSoftMaxLoss()
 learning_rate = constant(base_learning_rate=0.01)
 
 ### SOLVER ###
-optimizer = tf.train.GradientDescentOptimizer(learning_rate)
+optimizer = tf.compat.v1.train.GradientDescentOptimizer(learning_rate)
 
 ### Trainer ###
 trainer = Trainer
diff --git a/bob/learn/tensorflow/test/data/train_scripts/triplet.py b/bob/learn/tensorflow/test/data/train_scripts/triplet.py
index 35908c8e305eb64359f183d3fd74afee541a8a01..3fcef85d0dffb576c429196d69c24413bad2593d 100644
--- a/bob/learn/tensorflow/test/data/train_scripts/triplet.py
+++ b/bob/learn/tensorflow/test/data/train_scripts/triplet.py
@@ -29,7 +29,7 @@ loss = TripletLoss(margin=4.)
 learning_rate = constant(base_learning_rate=0.01)
 
 ### SOLVER ###
-optimizer = tf.train.GradientDescentOptimizer(learning_rate)
+optimizer = tf.compat.v1.train.GradientDescentOptimizer(learning_rate)
 
 ### Trainer ###
 trainer = Trainer
diff --git a/bob/learn/tensorflow/test/test_dataset.py b/bob/learn/tensorflow/test/test_dataset.py
index 0c5fe8cf4458669e7be26d26b74d70e45cd778ee..348fc1eee93ddaf831f85c829b340c52f5a7a014 100644
--- a/bob/learn/tensorflow/test/test_dataset.py
+++ b/bob/learn/tensorflow/test/test_dataset.py
@@ -56,7 +56,7 @@ def test_siamese_dataset():
         per_image_normalization=False,
         output_shape=output_shape)
 
-    with tf.Session() as session:
+    with tf.compat.v1.Session() as session:
         d, l = session.run([data, label])
         assert len(l) == 2
         assert d['left'].shape == (2, 50, 50, 3)
@@ -72,7 +72,7 @@ def test_triplet_dataset():
         2,
         per_image_normalization=False,
         output_shape=output_shape)
-    with tf.Session() as session:
+    with tf.compat.v1.Session() as session:
         d = session.run([data])[0]
         assert len(d.keys()) == 3
         assert d['anchor'].shape == (2, 50, 50, 3)
@@ -90,11 +90,11 @@ def test_dataset_using_generator():
     shape = (2, 2, 1)    
     samples = [numpy.ones(shape, dtype="float32")*i for i in range(10)]
     
-    with tf.Session() as session: 
+    with tf.compat.v1.Session() as session: 
         dataset = dataset_using_generator(samples,\
                                           reader,\
                                           multiple_samples=True)
-        iterator = dataset.make_one_shot_iterator().get_next()
+        iterator = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()
         for i in range(11):
             try:
                 sample = session.run(iterator)                
diff --git a/bob/learn/tensorflow/test/test_db_to_tfrecords.py b/bob/learn/tensorflow/test/test_db_to_tfrecords.py
index db5491d08066aa5e5138b447d7e32677576e1900..f78ddc1c3c2bcb679f78ba04513ced4f6cab7dea 100644
--- a/bob/learn/tensorflow/test/test_db_to_tfrecords.py
+++ b/bob/learn/tensorflow/test/test_db_to_tfrecords.py
@@ -22,21 +22,21 @@ dummy_config = pkg_resources.resource_filename(
 def compare_datasets(ds1, ds2, sess=None):
     if tf.executing_eagerly():
         for values1, values2 in zip(ds1, ds2):
-            values1 = tf.contrib.framework.nest.flatten(values1)
-            values2 = tf.contrib.framework.nest.flatten(values2)
+            values1 = tf.nest.flatten(values1)
+            values2 = tf.nest.flatten(values2)
             for v1, v2 in zip(values1, values2):
-                if not tf.reduce_all(tf.math.equal(v1, v2)):
+                if not tf.reduce_all(input_tensor=tf.math.equal(v1, v2)):
                     return False
     else:
-        ds1 = ds1.make_one_shot_iterator().get_next()
-        ds2 = ds2.make_one_shot_iterator().get_next()
+        ds1 = tf.compat.v1.data.make_one_shot_iterator(ds1).get_next()
+        ds2 = tf.compat.v1.data.make_one_shot_iterator(ds2).get_next()
         while True:
             try:
                 values1, values2 = sess.run([ds1, ds2])
             except tf.errors.OutOfRangeError:
                 break
-            values1 = tf.contrib.framework.nest.flatten(values1)
-            values2 = tf.contrib.framework.nest.flatten(values2)
+            values1 = tf.nest.flatten(values1)
+            values2 = tf.nest.flatten(values2)
             for v1, v2 in zip(values1, values2):
                 v1, v2 = np.asarray(v1), np.asarray(v2)
                 if not np.all(v1 == v2):
@@ -112,7 +112,7 @@ def test_datasets_to_tfrecords():
             datasets_to_tfrecords, args=args, standalone_mode=False)
         assert_click_runner_result(result)
         # read back the tfrecod
-        with tf.Session() as sess:
+        with tf.compat.v1.Session() as sess:
             dataset2 = dataset_from_tfrecord(output_path)
             dataset1 = load(
                 [dummy_config], attribute_name='dataset', entry_point_group='bob')
diff --git a/bob/learn/tensorflow/test/test_loss.py b/bob/learn/tensorflow/test/test_loss.py
index 92d94c8f1812512b23167be086807a0ffe65a0ce..f54d49e700129856d2517516d5758f6f45b27a70 100644
--- a/bob/learn/tensorflow/test/test_loss.py
+++ b/bob/learn/tensorflow/test/test_loss.py
@@ -42,7 +42,7 @@ def test_balanced_softmax_cross_entropy_loss_weights():
                           [0, 0, 1],
                           [1, 0, 0]], dtype="int32")
 
-    with tf.Session() as session:
+    with tf.compat.v1.Session() as session:
         weights = session.run(balanced_softmax_cross_entropy_loss_weights(labels))
  
     expected_weights = numpy.array([0.53333336, 0.53333336, 1.5238096 , 2.1333334,\
@@ -62,7 +62,7 @@ def test_balanced_sigmoid_cross_entropy_loss_weights():
     labels = numpy.array([1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0,
                           1, 1, 0, 1, 1, 1, 0, 1, 0, 1], dtype="int32")
     
-    with tf.Session() as session:
+    with tf.compat.v1.Session() as session:
         weights = session.run(balanced_sigmoid_cross_entropy_loss_weights(labels, dtype='float32'))
         
     expected_weights = numpy.array([0.8, 0.8, 1.3333334, 1.3333334, 1.3333334, 0.8,
diff --git a/bob/learn/tensorflow/test/test_utils.py b/bob/learn/tensorflow/test/test_utils.py
index 01af76e14b4ae866fdb5c2fb1a299e8ce3ff49d1..6edfe4519f74b02b160b343f52943e0d2e608670 100644
--- a/bob/learn/tensorflow/test/test_utils.py
+++ b/bob/learn/tensorflow/test/test_utils.py
@@ -56,8 +56,8 @@ def test_embedding_accuracy_tensors():
     data = numpy.vstack((class_a, class_b))
     labels = numpy.concatenate((labels_a, labels_b))
 
-    data = tf.convert_to_tensor(data.astype("float32"))
-    labels = tf.convert_to_tensor(labels.astype("int64"))
+    data = tf.convert_to_tensor(value=data.astype("float32"))
+    labels = tf.convert_to_tensor(value=labels.astype("int64"))
 
     accuracy = compute_embedding_accuracy_tensors(data, labels)
     assert accuracy == 1.
diff --git a/bob/learn/tensorflow/utils/eval.py b/bob/learn/tensorflow/utils/eval.py
index cf836f6e0d1f6742b10fee3d5fa9390d255f1dab..3efa80a1ee05446ed669680b374ce07e1232fe50 100644
--- a/bob/learn/tensorflow/utils/eval.py
+++ b/bob/learn/tensorflow/utils/eval.py
@@ -19,5 +19,5 @@ def get_global_step(path):
     global_step : int
         The global step number.
     """
-    checkpoint_reader = tf.train.NewCheckpointReader(path)
-    return checkpoint_reader.get_tensor(tf.GraphKeys.GLOBAL_STEP)
+    checkpoint_reader = tf.compat.v1.train.NewCheckpointReader(path)
+    return checkpoint_reader.get_tensor(tf.compat.v1.GraphKeys.GLOBAL_STEP)
diff --git a/bob/learn/tensorflow/utils/graph.py b/bob/learn/tensorflow/utils/graph.py
index e434289d12a4fad5cfa5cfab41d4555fa39fc485..3f551dd7d230d12856773ec451812eda1ded52a2 100644
--- a/bob/learn/tensorflow/utils/graph.py
+++ b/bob/learn/tensorflow/utils/graph.py
@@ -33,8 +33,8 @@ def call_on_frozen_graph(
         List of requested operations. Normally you would use
         ``returned_operations[0].outputs[0]``
     """
-    with tf.gfile.GFile(graph_def_path, "rb") as f:
-        graph_def = tf.GraphDef()
+    with tf.io.gfile.GFile(graph_def_path, "rb") as f:
+        graph_def = tf.compat.v1.GraphDef()
         graph_def.ParseFromString(f.read())
     input_map = {input_name: input}
 
diff --git a/bob/learn/tensorflow/utils/keras.py b/bob/learn/tensorflow/utils/keras.py
index 785c2fb8b9da1f95411077cc784f757053f7027f..e5e1a8f49b3c896a051f56d986adbe89c7f2c30a 100644
--- a/bob/learn/tensorflow/utils/keras.py
+++ b/bob/learn/tensorflow/utils/keras.py
@@ -91,10 +91,10 @@ def restore_model_variables_from_checkpoint(
     model, checkpoint, session=None, normalizer=None
 ):
     if session is None:
-        session = tf.keras.backend.get_session()
+        session = tf.compat.v1.keras.backend.get_session()
 
     var_list = _create_var_map(model.variables, normalizer=normalizer)
-    saver = tf.train.Saver(var_list=var_list)
+    saver = tf.compat.v1.train.Saver(var_list=var_list)
     ckpt_state = tf.train.get_checkpoint_state(checkpoint)
     logger.info("Loading checkpoint %s", ckpt_state.model_checkpoint_path)
     saver.restore(session, ckpt_state.model_checkpoint_path)
@@ -102,7 +102,7 @@ def restore_model_variables_from_checkpoint(
 
 def initialize_model_from_checkpoint(model, checkpoint, normalizer=None):
     assignment_map = _create_var_map(model.variables, normalizer=normalizer)
-    tf.train.init_from_checkpoint(checkpoint, assignment_map=assignment_map)
+    tf.compat.v1.train.init_from_checkpoint(checkpoint, assignment_map=assignment_map)
 
 
 def model_summary(model, do_print=False):
diff --git a/bob/learn/tensorflow/utils/math.py b/bob/learn/tensorflow/utils/math.py
index 64ed7349d7dce745bf3c3fe5ad84bd347a9d493c..b79b4496224958af5eeebb8151dff2e7c18f7ed4 100644
--- a/bob/learn/tensorflow/utils/math.py
+++ b/bob/learn/tensorflow/utils/math.py
@@ -28,7 +28,7 @@ def gram_matrix(input_tensor):
             [0., 0., 0., ..., 0., 0., 0.]],
     """
     result = tf.linalg.einsum("bijc,bijd->bcd", input_tensor, input_tensor)
-    input_shape = tf.shape(input_tensor)
+    input_shape = tf.shape(input=input_tensor)
     num_locations = tf.cast(input_shape[1] * input_shape[2], tf.float32)
     return result / (num_locations)
 
@@ -61,17 +61,17 @@ def upper_triangle_and_diagonal(A):
     """
     ones = tf.ones_like(A)
     # Upper triangular matrix of 0s and 1s (including diagonal)
-    mask = tf.matrix_band_part(ones, 0, -1)
-    upper_triangular_flat = tf.boolean_mask(A, mask)
+    mask = tf.linalg.band_part(ones, 0, -1)
+    upper_triangular_flat = tf.boolean_mask(tensor=A, mask=mask)
     return upper_triangular_flat
 
 
 def upper_triangle(A):
     ones = tf.ones_like(A)
     # Upper triangular matrix of 0s and 1s (including diagonal)
-    mask_a = tf.matrix_band_part(ones, 0, -1)
+    mask_a = tf.linalg.band_part(ones, 0, -1)
     # Diagonal
-    mask_b = tf.matrix_band_part(ones, 0, 0)
+    mask_b = tf.linalg.band_part(ones, 0, 0)
     mask = tf.cast(mask_a - mask_b, dtype=tf.bool)
-    upper_triangular_flat = tf.boolean_mask(A, mask)
+    upper_triangular_flat = tf.boolean_mask(tensor=A, mask=mask)
     return upper_triangular_flat
diff --git a/bob/learn/tensorflow/utils/reproducible.py b/bob/learn/tensorflow/utils/reproducible.py
index 677994b9beae33667c43dcd2a0cf5069a512ce77..f1a2b499ce5f12a7c21aa7b28d82df7e5cf898ff 100644
--- a/bob/learn/tensorflow/utils/reproducible.py
+++ b/bob/learn/tensorflow/utils/reproducible.py
@@ -62,7 +62,7 @@ def set_seed(
     # non-reproducible results.
     # For further details, see:
     # https://stackoverflow.com/questions/42022950/which-seeds-have-to-be-set-where-to-realize-100-reproducibility-of-training-res
-    session_config = tf.ConfigProto(
+    session_config = tf.compat.v1.ConfigProto(
         intra_op_parallelism_threads=1,
         inter_op_parallelism_threads=1,
         log_device_placement=log_device_placement,
@@ -84,7 +84,7 @@ def set_seed(
     # in the TensorFlow backend have a well-defined initial state.
     # For further details, see:
     # https://www.tensorflow.org/api_docs/python/tf/set_random_seed
-    tf.set_random_seed(seed)
+    tf.compat.v1.set_random_seed(seed)
     # sess = tf.Session(graph=tf.get_default_graph(), config=session_config)
     # keras.backend.set_session(sess)
 
diff --git a/bob/learn/tensorflow/utils/session.py b/bob/learn/tensorflow/utils/session.py
index 3976f9f0446f14a74f0d5b1e237dd28a7937fbb9..67f4856d3047c713d98b4af8ee38c853e914fc7e 100644
--- a/bob/learn/tensorflow/utils/session.py
+++ b/bob/learn/tensorflow/utils/session.py
@@ -15,11 +15,11 @@ class Session(object):
     """
 
     def __init__(self, debug=False):
-        config = tf.ConfigProto(
+        config = tf.compat.v1.ConfigProto(
             log_device_placement=False,
             allow_soft_placement=True,
-            gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=0.5))
+            gpu_options=tf.compat.v1.GPUOptions(per_process_gpu_memory_fraction=0.5))
         config.gpu_options.allow_growth = True
-        self.session = tf.Session()
+        self.session = tf.compat.v1.Session()
         if debug:
             self.session = tf_debug.LocalCLIDebugWrapperSession(self.session)
diff --git a/bob/learn/tensorflow/utils/util.py b/bob/learn/tensorflow/utils/util.py
index 263a9ec00970b03786e6e98db35b65617fa323eb..3a09c4551d60199ae6deac24014fa966b63c280e 100644
--- a/bob/learn/tensorflow/utils/util.py
+++ b/bob/learn/tensorflow/utils/util.py
@@ -15,13 +15,13 @@ logger = logging.getLogger(__name__)
 @function.Defun(tf.float32, tf.float32)
 def norm_grad(x, dy):
     return tf.expand_dims(dy, -1) * (
-        x / (tf.expand_dims(tf.norm(x, ord=2, axis=-1), -1) + 1.0e-19)
+        x / (tf.expand_dims(tf.norm(tensor=x, ord=2, axis=-1), -1) + 1.0e-19)
     )
 
 
 @function.Defun(tf.float32, grad_func=norm_grad)
 def norm(x):
-    return tf.norm(x, ord=2, axis=-1)
+    return tf.norm(tensor=x, ord=2, axis=-1)
 
 
 def compute_euclidean_distance(x, y):
@@ -29,7 +29,7 @@ def compute_euclidean_distance(x, y):
     Computes the euclidean distance between two tensorflow variables
     """
 
-    with tf.name_scope("euclidean_distance"):
+    with tf.compat.v1.name_scope("euclidean_distance"):
         # d = tf.sqrt(tf.reduce_sum(tf.square(tf.subtract(x, y)), 1))
         d = norm(tf.subtract(x, y))
         return d
@@ -38,23 +38,23 @@ def compute_euclidean_distance(x, y):
 def pdist_safe(A, metric="sqeuclidean"):
     if metric != "sqeuclidean":
         raise NotImplementedError()
-    r = tf.reduce_sum(A * A, 1)
+    r = tf.reduce_sum(input_tensor=A * A, axis=1)
     r = tf.reshape(r, [-1, 1])
-    D = r - 2 * tf.matmul(A, A, transpose_b=True) + tf.transpose(r)
+    D = r - 2 * tf.matmul(A, A, transpose_b=True) + tf.transpose(a=r)
     return D
 
 
 def cdist(A, B, metric="sqeuclidean"):
     if metric != "sqeuclidean":
         raise NotImplementedError()
-    M1, M2 = tf.shape(A)[0], tf.shape(B)[0]
+    M1, M2 = tf.shape(input=A)[0], tf.shape(input=B)[0]
     # code from https://stackoverflow.com/a/43839605/1286165
     p1 = tf.matmul(
-        tf.expand_dims(tf.reduce_sum(tf.square(A), 1), 1), tf.ones(shape=(1, M2))
+        tf.expand_dims(tf.reduce_sum(input_tensor=tf.square(A), axis=1), 1), tf.ones(shape=(1, M2))
     )
     p2 = tf.transpose(
-        tf.matmul(
-            tf.reshape(tf.reduce_sum(tf.square(B), 1), shape=[-1, 1]),
+        a=tf.matmul(
+            tf.reshape(tf.reduce_sum(input_tensor=tf.square(B), axis=1), shape=[-1, 1]),
             tf.ones(shape=(M1, 1)),
             transpose_b=True,
         )
@@ -102,7 +102,7 @@ def create_mnist_tfrecord(tfrecords_filename, data, labels, n_samples=6000):
     def _int64_feature(value):
         return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
 
-    writer = tf.python_io.TFRecordWriter(tfrecords_filename)
+    writer = tf.io.TFRecordWriter(tfrecords_filename)
 
     for i in range(n_samples):
         img = data[i]
@@ -213,14 +213,14 @@ def pdist(A):
     Compute a pairwise euclidean distance in the same fashion
     as in scipy.spation.distance.pdist
     """
-    with tf.name_scope("Pairwisedistance"):
+    with tf.compat.v1.name_scope("Pairwisedistance"):
         ones_1 = tf.reshape(tf.cast(tf.ones_like(A), tf.float32)[:, 0], [1, -1])
-        p1 = tf.matmul(tf.expand_dims(tf.reduce_sum(tf.square(A), 1), 1), ones_1)
+        p1 = tf.matmul(tf.expand_dims(tf.reduce_sum(input_tensor=tf.square(A), axis=1), 1), ones_1)
 
         ones_2 = tf.reshape(tf.cast(tf.ones_like(A), tf.float32)[:, 0], [-1, 1])
         p2 = tf.transpose(
-            tf.matmul(
-                tf.reshape(tf.reduce_sum(tf.square(A), 1), shape=[-1, 1]),
+            a=tf.matmul(
+                tf.reshape(tf.reduce_sum(input_tensor=tf.square(A), axis=1), shape=[-1, 1]),
                 ones_2,
                 transpose_b=True,
             )
@@ -240,8 +240,8 @@ def predict_using_tensors(embedding, labels, num=None):
     inf = tf.cast(tf.ones_like(labels), tf.float32) * numpy.inf
 
     distances = pdist(embedding)
-    distances = tf.matrix_set_diag(distances, inf)
-    indexes = tf.argmin(distances, axis=1)
+    distances = tf.linalg.set_diag(distances, inf)
+    indexes = tf.argmin(input=distances, axis=1)
     return [labels[i] for i in tf.unstack(indexes, num=num)]
 
 
@@ -266,7 +266,7 @@ def compute_embedding_accuracy_tensors(embedding, labels, num=None):
         for p, l in zip(tf.unstack(predictions, num=num), tf.unstack(labels, num=num))
     ]
 
-    return tf.reduce_sum(tf.cast(matching, tf.uint8)) / len(predictions)
+    return tf.reduce_sum(input_tensor=tf.cast(matching, tf.uint8)) / len(predictions)
 
 
 def compute_embedding_accuracy(embedding, labels):
@@ -348,7 +348,7 @@ def to_channels_last(image):
     axis_order = [1, 2, 0]
     shift = ndim - 3
     axis_order = list(range(ndim - 3)) + [n + shift for n in axis_order]
-    return tf.transpose(image, axis_order)
+    return tf.transpose(a=image, perm=axis_order)
 
 
 def to_channels_first(image):
@@ -379,7 +379,7 @@ def to_channels_first(image):
     axis_order = [2, 0, 1]
     shift = ndim - 3
     axis_order = list(range(ndim - 3)) + [n + shift for n in axis_order]
-    return tf.transpose(image, axis_order)
+    return tf.transpose(a=image, perm=axis_order)
 
 
 to_skimage = to_matplotlib = to_channels_last
@@ -439,7 +439,7 @@ def random_choice_no_replacement(one_dim_input, num_indices_to_drop=3, sort=Fals
     """Similar to np.random.choice with no replacement.
     Code from https://stackoverflow.com/a/54755281/1286165
     """
-    input_length = tf.shape(one_dim_input)[0]
+    input_length = tf.shape(input=one_dim_input)[0]
 
     # create uniform distribution over the sequence
     uniform_distribution = tf.random.uniform(