diff --git a/bob/learn/tensorflow/layers.py b/bob/learn/tensorflow/layers.py
index e15d36e5ffaedc4f321296486350b6248c18d252..337858788de27daa927b7db4311f7c33f1fa6b3f 100644
--- a/bob/learn/tensorflow/layers.py
+++ b/bob/learn/tensorflow/layers.py
@@ -5,7 +5,7 @@ import tensorflow as tf
 from tensorflow.keras.layers import BatchNormalization
 from tensorflow.keras.layers import Dense
 from tensorflow.keras.layers import Dropout
-from tensorflow.keras.layers import GlobalAvgPool2D
+from tensorflow.keras.layers import Flatten
 
 
 def _check_input(
@@ -260,7 +260,9 @@ class ModifiedSoftMaxLayer(tf.keras.layers.Layer):
         return logits
 
 
-def add_bottleneck(model, bottleneck_size=128, dropout_rate=0.2):
+def add_bottleneck(
+    model, bottleneck_size=128, dropout_rate=0.2, w_decay=5e-4, use_bias=True
+):
     """
     Amend a bottleneck layer to a Keras Model
 
@@ -276,15 +278,38 @@ def add_bottleneck(model, bottleneck_size=128, dropout_rate=0.2):
       dropout_rate: float
          Dropout rate
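+
+      w_decay: float
+         Coefficient of the L2 regularizer applied to the bottleneck Dense kernel.
+         Set to ``None`` to disable weight decay.
+
+      use_bias: bool
+         If True, the bottleneck Dense layer uses a bias term.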
     """
+
     if not isinstance(model, tf.keras.models.Sequential):
         new_model = tf.keras.models.Sequential(model, name="bottleneck")
     else:
         new_model = model
 
-    new_model.add(GlobalAvgPool2D())
+    new_model.add(BatchNormalization())
     new_model.add(Dropout(dropout_rate, name="Dropout"))
-    new_model.add(Dense(bottleneck_size, use_bias=False, name="embeddings"))
-    new_model.add(BatchNormalization(axis=-1, scale=False, name="embeddings/BatchNorm"))
+    new_model.add(Flatten())
+
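+    # Optionally regularize the bottleneck Dense kernel with L2 weight decay.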
+    if w_decay is None:
+        regularizer = None
+    else:
+        regularizer = tf.keras.regularizers.l2(w_decay)
+
+    new_model.add(
+        Dense(
+            bottleneck_size,
+            use_bias=use_bias,
+            kernel_regularizer=regularizer,
+        )
+    )
+
+    new_model.add(BatchNormalization(axis=-1, name="embeddings"))
 
     return new_model
 
diff --git a/bob/learn/tensorflow/models/__init__.py b/bob/learn/tensorflow/models/__init__.py
index 4265c99d7b58fe0bf723afb04883d48a09633df1..333ab1f4f7446e182a5ce7348eb1f1eb3a747cac 100644
--- a/bob/learn/tensorflow/models/__init__.py
+++ b/bob/learn/tensorflow/models/__init__.py
@@ -7,6 +7,8 @@ from .densenet import DenseNet
 from .densenet import densenet161  # noqa: F401
 from .embedding_validation import EmbeddingValidation
 from .mine import MineModel
+from .resnet50_modified import resnet50_modified  # noqa: F401
+from .resnet50_modified import resnet101_modified  # noqa: F401
 
 
 # gets sphinx autodoc done right - don't remove it
diff --git a/bob/learn/tensorflow/models/embedding_validation.py b/bob/learn/tensorflow/models/embedding_validation.py
index be342d1f5c22404b0145f78c8cfc71e33e18c792..beb4498bc3e1465c496f367bc1370b95f6a36168 100644
--- a/bob/learn/tensorflow/models/embedding_validation.py
+++ b/bob/learn/tensorflow/models/embedding_validation.py
@@ -12,6 +12,7 @@ class EmbeddingValidation(tf.keras.Model):
 
     def compile(
         self,
+        single_precision=False,
         **kwargs,
     ):
         """
@@ -27,14 +28,19 @@ class EmbeddingValidation(tf.keras.Model):
         """
 
         X, y = data
+
         with tf.GradientTape() as tape:
             logits, _ = self(X, training=True)
             loss = self.loss(y, logits)
 
         self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
 
         self.compiled_metrics.update_state(y, logits, sample_weight=None)
         self.train_loss(loss)
+
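+        # Log the raw training loss to TensorBoard, keyed by the model's train step counter.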
+        tf.summary.scalar("training_loss", data=loss, step=self._train_counter)
+
         return {m.name: m.result() for m in self.metrics + [self.train_loss]}
 
         # self.optimizer.apply_gradients(zip(gradients, trainable_vars))
diff --git a/bob/learn/tensorflow/models/resnet50_modified.py b/bob/learn/tensorflow/models/resnet50_modified.py
new file mode 100644
index 0000000000000000000000000000000000000000..6725301c03bfdaf3e2db15c287471ff0543c82b4
--- /dev/null
+++ b/bob/learn/tensorflow/models/resnet50_modified.py
@@ -0,0 +1,372 @@
+# -*- coding: utf-8 -*-
+"""
+The ResNet-50 from `tf.keras.applications.ResNet50` has a problem with its convolutional layers:
+it adds bias terms to convolutions that are immediately followed by batch normalization, which is redundant.
+
+https://github.com/tensorflow/tensorflow/issues/37365
+
+This ResNet-50 implementation provides a cleaner version: every convolution uses ``use_bias=False`` and is followed by a ``BatchNormalization`` layer.
+"""
+
+import tensorflow as tf
+from tensorflow.keras.layers import Activation
+from tensorflow.keras.layers import BatchNormalization
+from tensorflow.keras.layers import Conv2D
+from tensorflow.keras.layers import MaxPooling2D
+from tensorflow.keras.regularizers import l2
+
+# Default L2 weight decay used for the convolution kernel regularizers below.
+weight_decay = 1e-4
+
+
+class IdentityBlock(tf.keras.layers.Layer):
+    def __init__(
+        self, kernel_size, filters, stage, block, weight_decay=1e-4, name=None, **kwargs
+    ):
+
+        """Block that has no convolutianal layer as skip connection
+
+        Parameters
+        ----------
+            kernel_size:
+               The kernel size of middle conv layer at main path
+
+            filters:
+                list of integers, the filterss of 3 conv layer at main path
+            stage:
+              Current stage label, used for generating layer names
+
+            block:
+                'a','b'..., current block label, used for generating layer names
+
+        """
+        super().__init__(name=name, **kwargs)
+
+        filters1, filters2, filters3 = filters
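+        # Channels-last inputs: batch normalization acts on the last (channel) axis.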
+        bn_axis = 3
+
+        conv_name_1 = "conv" + str(stage) + "_" + str(block) + "_1x1_reduce"
+        bn_name_1 = "conv" + str(stage) + "_" + str(block) + "_1x1_reduce/bn"
+        layers = [
+            Conv2D(
+                filters1,
+                (1, 1),
+                kernel_initializer="orthogonal",
+                use_bias=False,
+                kernel_regularizer=l2(weight_decay),
+                name=conv_name_1,
+            )
+        ]
+
+        layers += [BatchNormalization(axis=bn_axis, name=bn_name_1)]
+        layers += [Activation("relu")]
+
+        conv_name_2 = "conv" + str(stage) + "_" + str(block) + "_3x3"
+        bn_name_2 = "conv" + str(stage) + "_" + str(block) + "_3x3/bn"
+        layers += [
+            Conv2D(
+                filters2,
+                kernel_size,
+                padding="same",
+                kernel_initializer="orthogonal",
+                use_bias=False,
+                kernel_regularizer=l2(weight_decay),
+                name=conv_name_2,
+            )
+        ]
+        layers += [BatchNormalization(axis=bn_axis, name=bn_name_2)]
+        layers += [Activation("relu")]
+
+        conv_name_3 = "conv" + str(stage) + "_" + str(block) + "_1x1_increase"
+        bn_name_3 = "conv" + str(stage) + "_" + str(block) + "_1x1_increase/bn"
+        layers += [
+            Conv2D(
+                filters3,
+                (1, 1),
+                kernel_initializer="orthogonal",
+                use_bias=False,
+                kernel_regularizer=l2(weight_decay),
+                name=conv_name_3,
+            )
+        ]
+        layers += [BatchNormalization(axis=bn_axis, name=bn_name_3)]
+        self.layers = layers
+
+    def call(self, input_tensor, training=None):
+
+        x = input_tensor
+        for lay in self.layers:
+            x = lay(x, training=training)
+
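+        # Identity shortcut: add the block input back to the main path, then apply ReLU.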
+        x = tf.keras.layers.add([x, input_tensor])
+        x = Activation("relu")(x)
+
+        return x
+
+
+class ConvBlock(tf.keras.layers.Layer):
+    def __init__(
+        self,
+        kernel_size,
+        filters,
+        stage,
+        block,
+        strides=(2, 2),
+        weight_decay=1e-4,
+        name=None,
+        **kwargs,
+    ):
+        """Block that has a conv layer AS shortcut.
+        Parameters
+        ----------
+            kernel_size:
+               The kernel size of middle conv layer at main path
+
+            filters:
+                list of integers, the filterss of 3 conv layer at main path
+            stage:
+              Current stage label, used for generating layer names
+
+            block:
+                'a','b'..., current block label, used for generating layer names
+        """
+        super().__init__(name=name, **kwargs)
+
+        filters1, filters2, filters3 = filters
+        bn_axis = 3
+
+        conv_name_1 = "conv" + str(stage) + "_" + str(block) + "_1x1_reduce"
+        bn_name_1 = "conv" + str(stage) + "_" + str(block) + "_1x1_reduce/bn"
+        layers = [
+            Conv2D(
+                filters1,
+                (1, 1),
+                strides=strides,
+                kernel_initializer="orthogonal",
+                use_bias=False,
+                kernel_regularizer=l2(weight_decay),
+                name=conv_name_1,
+            )
+        ]
+        layers += [BatchNormalization(axis=bn_axis, name=bn_name_1)]
+        layers += [Activation("relu")]
+
+        conv_name_2 = "conv" + str(stage) + "_" + str(block) + "_3x3"
+        bn_name_2 = "conv" + str(stage) + "_" + str(block) + "_3x3/bn"
+        layers += [
+            Conv2D(
+                filters2,
+                kernel_size,
+                padding="same",
+                kernel_initializer="orthogonal",
+                use_bias=False,
+                kernel_regularizer=l2(weight_decay),
+                name=conv_name_2,
+            )
+        ]
+        layers += [BatchNormalization(axis=bn_axis, name=bn_name_2)]
+        layers += [Activation("relu")]
+
+        conv_name_3 = "conv" + str(stage) + "_" + str(block) + "_1x1_increase"
+        bn_name_3 = "conv" + str(stage) + "_" + str(block) + "_1x1_increase/bn"
+        layers += [
+            Conv2D(
+                filters3,
+                (1, 1),
+                kernel_initializer="orthogonal",
+                use_bias=False,
+                kernel_regularizer=l2(weight_decay),
+                name=conv_name_3,
+            )
+        ]
+        layers += [BatchNormalization(axis=bn_axis, name=bn_name_3)]
+
+        conv_name_4 = "conv" + str(stage) + "_" + str(block) + "_1x1_proj"
+        bn_name_4 = "conv" + str(stage) + "_" + str(block) + "_1x1_proj/bn"
+        shortcut = [
+            Conv2D(
+                filters3,
+                (1, 1),
+                strides=strides,
+                kernel_initializer="orthogonal",
+                use_bias=False,
+                kernel_regularizer=l2(weight_decay),
+                name=conv_name_4,
+            )
+        ]
+        shortcut += [BatchNormalization(axis=bn_axis, name=bn_name_4)]
+
+        self.layers = layers
+        self.shortcut = shortcut
+
+    def call(self, input_tensor, training=None):
+        x = input_tensor
+        for lay in self.layers:
+            x = lay(x, training=training)
+
+        x_s = input_tensor
+        for lay in self.shortcut:
+            x_s = lay(x_s, training=training)
+
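+        # Projection shortcut: add the convolved (and possibly strided) input to the main path, then apply ReLU.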
+        x = tf.keras.layers.add([x, x_s])
+        x = Activation("relu")(x)
+        return x
+
+
+def resnet50_modified(input_tensor=None, input_shape=None, **kwargs):
+    """
+    The ResNet-50 from `tf.keras.applications.ResNet50` has a problem with its convolutional layers:
+    it adds bias terms to convolutions that are immediately followed by batch normalization, which is redundant.
+
+    https://github.com/tensorflow/tensorflow/issues/37365
+
+    This ResNet-50 implementation provides a cleaner version.
+
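+    Parameters
+    ----------
+        input_tensor:
+            Optional Keras tensor to use as the input of the model
+
+        input_shape:
+            Shape of the input images, e.g. ``(224, 224, 3)``, used when
+            ``input_tensor`` is not provided
+
+    Example
+    -------
+    A minimal usage sketch (the 112x112x3 input shape is only illustrative)::
+
+        model = resnet50_modified(input_shape=(112, 112, 3))
+        model.summary()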
+    """
+    if input_tensor is None:
+        input_tensor = tf.keras.Input(shape=input_shape)
+    else:
+        if not tf.keras.backend.is_keras_tensor(input_tensor):
+            input_tensor = tf.keras.Input(tensor=input_tensor, shape=input_shape)
+
+    bn_axis = 3
+    # inputs are of size 224 x 224 x 3
+    layers = [input_tensor]
+    layers += [
+        Conv2D(
+            64,
+            (7, 7),
+            strides=(2, 2),
+            kernel_initializer="orthogonal",
+            use_bias=False,
+            trainable=True,
+            kernel_regularizer=l2(weight_decay),
+            padding="same",
+            name="conv1/7x7_s2",
+        )
+    ]
+
+    # inputs are of size 112 x 112 x 64
+    layers += [BatchNormalization(axis=bn_axis, name="conv1/7x7_s2/bn")]
+    layers += [Activation("relu")]
+    layers += [MaxPooling2D((3, 3), strides=(2, 2))]
+
+    # inputs are of size 56 x 56
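+    # The first block of stage 2 keeps stride 1 because the preceding max pooling already halved the spatial size.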
+    layers += [ConvBlock(3, [64, 64, 256], stage=2, block=1, strides=(1, 1))]
+    layers += [IdentityBlock(3, [64, 64, 256], stage=2, block=2)]
+    layers += [IdentityBlock(3, [64, 64, 256], stage=2, block=3)]
+
+    # inputs are of size 28 x 28
+    layers += [ConvBlock(3, [128, 128, 512], stage=3, block=1)]
+    layers += [IdentityBlock(3, [128, 128, 512], stage=3, block=2)]
+    layers += [IdentityBlock(3, [128, 128, 512], stage=3, block=3)]
+    layers += [IdentityBlock(3, [128, 128, 512], stage=3, block=4)]
+
+    # inputs are of size 14 x 14
+    layers += [ConvBlock(3, [256, 256, 1024], stage=4, block=1)]
+    layers += [IdentityBlock(3, [256, 256, 1024], stage=4, block=2)]
+    layers += [IdentityBlock(3, [256, 256, 1024], stage=4, block=3)]
+    layers += [IdentityBlock(3, [256, 256, 1024], stage=4, block=4)]
+    layers += [IdentityBlock(3, [256, 256, 1024], stage=4, block=5)]
+    layers += [IdentityBlock(3, [256, 256, 1024], stage=4, block=6)]
+
+    # inputs are of size 7 x 7
+    layers += [ConvBlock(3, [512, 512, 2048], stage=5, block=1)]
+    layers += [IdentityBlock(3, [512, 512, 2048], stage=5, block=2)]
+    layers += [IdentityBlock(3, [512, 512, 2048], stage=5, block=3)]
+
+    return tf.keras.Sequential(layers)
+
+
+def resnet101_modified(input_tensor=None, input_shape=None, **kwargs):
+    """
+    The ResNet-101 from `tf.keras.applications.ResNet101` has a problem with its convolutional layers:
+    it adds bias terms to convolutions that are immediately followed by batch normalization, which is redundant.
+
+    https://github.com/tensorflow/tensorflow/issues/37365
+
+    This ResNet-101 implementation provides a cleaner version.
+
+    """
+
+    if input_tensor is None:
+        input_tensor = tf.keras.Input(shape=input_shape)
+    else:
+        if not tf.keras.backend.is_keras_tensor(input_tensor):
+            input_tensor = tf.keras.Input(tensor=input_tensor, shape=input_shape)
+
+    bn_axis = 3
+    # inputs are of size 224 x 224 x 3
+    layers = [input_tensor]
+    layers += [
+        Conv2D(
+            64,
+            (7, 7),
+            strides=(2, 2),
+            kernel_initializer="orthogonal",
+            use_bias=False,
+            trainable=True,
+            kernel_regularizer=l2(weight_decay),
+            padding="same",
+            name="conv1/7x7_s2",
+        )
+    ]
+
+    # inputs are of size 112 x 112 x 64
+    layers += [BatchNormalization(axis=bn_axis, name="conv1/7x7_s2/bn")]
+    layers += [Activation("relu")]
+    layers += [MaxPooling2D((3, 3), strides=(2, 2))]
+
+    # inputs are of size 56 x 56
+    layers += [ConvBlock(3, [64, 64, 256], stage=2, block=1, strides=(1, 1))]
+    layers += [IdentityBlock(3, [64, 64, 256], stage=2, block=2)]
+    layers += [IdentityBlock(3, [64, 64, 256], stage=2, block=3)]
+
+    # inputs are of size 28 x 28
+    layers += [ConvBlock(3, [128, 128, 512], stage=3, block=1)]
+    layers += [IdentityBlock(3, [128, 128, 512], stage=3, block=2)]
+    layers += [IdentityBlock(3, [128, 128, 512], stage=3, block=3)]
+    layers += [IdentityBlock(3, [128, 128, 512], stage=3, block=4)]
+
+    # inputs are of size 14 x 14
+    # 23 blocks in this stage (1 conv block + 22 identity blocks).
+    # That's the only difference from ResNet-50.
+    layers += [ConvBlock(3, [256, 256, 1024], stage=4, block=1)]
+    for i in range(2, 24):
+        layers += [IdentityBlock(3, [256, 256, 1024], stage=4, block=i)]
+
+    # inputs are of size 7 x 7
+    layers += [ConvBlock(3, [512, 512, 2048], stage=5, block=1)]
+    layers += [IdentityBlock(3, [512, 512, 2048], stage=5, block=2)]
+    layers += [IdentityBlock(3, [512, 512, 2048], stage=5, block=3)]
+
+    return tf.keras.Sequential(layers)
+
+
+if __name__ == "__main__":
+    # Minimal smoke test: build the network for 112x112 RGB inputs and inspect it.
+    model = resnet50_modified(input_shape=(112, 112, 3))
+
+    print(len(model.variables))
+    model.summary()