diff --git a/bob/learn/tensorflow/layers.py b/bob/learn/tensorflow/layers.py
index 046175f03c5d8de3a41ac8762557d52841f9d6ec..ba63ca9521aaab3c0c8f4a0dc92e60a9745d93bf 100644
--- a/bob/learn/tensorflow/layers.py
+++ b/bob/learn/tensorflow/layers.py
@@ -1,6 +1,7 @@
 import numbers
 
 import tensorflow as tf
+import math
 
 
 def _check_input(
@@ -160,3 +161,144 @@ def Normalize(mean, std=1.0, **kwargs):
     return tf.keras.layers.experimental.preprocessing.Rescaling(
         scale=scale, offset=offset, **kwargs
     )
+
+
+class SphereFaceLayer(tf.keras.layers.Layer):
+    """
+    Implements the SphereFace loss from equation (7) of `SphereFace: Deep Hypersphere Embedding for Face Recognition <https://arxiv.org/abs/1704.08063>`_
+
+    If the parameter `original` is set to `True` it will computes exactly what's written in eq (7): :math:`\\text{soft}(x_i) = \\frac{exp(||x_i||\\text{cos}(\\psi(\\theta_{yi})))}{exp(||x_i||\\text{cos}(\\psi(\\theta_{yi}))) + \sum_{j;j\\neq yi}  exp(||x_i||\\text{cos}(\psi(\\theta_{j}))) }`.
+    Where :math:`\\psi(\\theta) = -1^k \\text{cos}(m\\theta)-2k`.
+
+    Parameters
+    ----------
+
+      n_classes: int
+        Number of classes
+
+      m: float
+         Margin
+            
+    """
+
+    def __init__(self, n_classes=10, m=0.5):
+        super(SphereFaceLayer, self).__init__(name="sphere_face_logits")
+        self.n_classes = n_classes
+        self.m = m
+
+    def build(self, input_shape):
+        super(SphereFaceLayer, self).build(input_shape[0])
+        shape = [input_shape[-1], self.n_classes]
+
+        self.W = self.add_variable("W", shape=shape)
+        self.pi = tf.constant(math.pi)
+
+    def call(self, X, training=None):
+
+        # normalize feature
+        X = tf.nn.l2_normalize(X, axis=1)
+        W = tf.nn.l2_normalize(self.W, axis=0)
+
+        # cos between X and W
+        cos_yi = tf.matmul(X, W)
+        cos_yi = tf.clip_by_value(cos_yi, -1, 1)
+
+        # cos(m \theta)
+        theta = tf.math.acos(cos_yi)
+        cos_theta_m = tf.math.cos(self.m * theta)
+
+        # ||x||
+        x_norm = tf.norm(X, axis=-1, keepdims=True)
+
+        # phi = -1**k * cos(m \theta) - 2k
+        k = self.m * (theta / self.pi)
+        phi = ((-(1 ** k)) * cos_theta_m) - 2 * k
+
+        logits = x_norm * phi
+
+        return logits
+
+
+class ModifiedSoftMaxLayer(tf.keras.layers.Layer):
+    """
+    Implements the modified logit from equation (5) of `SphereFace: Deep Hypersphere Embedding for Face Recognition <https://arxiv.org/pdf/1704.08063.pdf>`_
+
+    It basically transforms the regular logit function to :math:`||x_i||cos(\\theta_{yi})`, where :math:`\\theta_{yi}=||x_i||_2^2||W||_2^2`
+
+    Parameters
+    ----------
+    
+    n_classes: int
+        Number of classes for the new logit function
+    """
+
+    def __init__(self, n_classes=10):
+
+        super(ModifiedSoftMaxLayer, self).__init__(name="modified_softmax_logits")
+        self.n_classes = n_classes
+
+    def build(self, input_shape):
+        super(ModifiedSoftMaxLayer, self).build(input_shape[0])
+        shape = [input_shape[-1], self.n_classes]
+
+        self.W = self.add_variable("W", shape=shape)
+
+    def call(self, X, training=None):
+
+        # normalize feature
+        W = tf.nn.l2_normalize(self.W, axis=0)
+
+        # cos between X and W
+        cos_yi = tf.nn.l2_normalize(X, axis=1) @ W
+
+        logits = tf.norm(X) * cos_yi
+
+        return logits
+
+
+from tensorflow.keras.layers import (
+    BatchNormalization,
+    Dropout,
+    Dense,
+    Concatenate,
+    GlobalAvgPool2D,
+)
+
+
+def add_bottleneck(model, bottleneck_size=128, dropout_rate=0.2):
+    """
+    Amend a bottleneck layer to a Keras Model
+
+    Parameters
+    ----------
+
+      model:
+        Keras model
+
+      bottleneck_size: int
+         Size of the bottleneck
+
+      dropout_rate: float
+         Dropout rate
+    """
+    if not isinstance(model, tf.keras.models.Sequential):
+        new_model = tf.keras.models.Sequential(model, name="bottleneck")
+    else:
+        new_model = model
+
+    new_model.add(GlobalAvgPool2D())
+    new_model.add(Dropout(dropout_rate, name="Dropout"))
+    new_model.add(Dense(bottleneck_size, use_bias=False, name="embeddings"))
+    new_model.add(BatchNormalization(axis=-1, scale=False, name="embeddings/BatchNorm"))
+
+    return new_model
+
+
+def add_top(model, n_classes):
+    if not isinstance(model, tf.keras.models.Sequential):
+        new_model = tf.keras.models.Sequential(model, name="logits")
+    else:
+        new_model = model
+
+    new_model.add(Dense(n_classes, name="logits"))
+    return new_model
diff --git a/bob/learn/tensorflow/models/__init__.py b/bob/learn/tensorflow/models/__init__.py
index 18a84faeb4c2f66e7929ccb4c39488d23da5a5ff..b5cd87bf26c41a6ad1418c78bbefa1ac221640b6 100644
--- a/bob/learn/tensorflow/models/__init__.py
+++ b/bob/learn/tensorflow/models/__init__.py
@@ -3,7 +3,8 @@ from .densenet import DeepPixBiS
 from .densenet import DenseNet
 from .densenet import densenet161  # noqa: F401
 from .mine import MineModel
-
+from .embedding_validation import EmbeddingValidation
+from .arcface import ArcFaceLayer, ArcFaceLayer3Penalties, ArcFaceModel
 
 # gets sphinx autodoc done right - don't remove it
 def __appropriate__(*args):
@@ -21,5 +22,14 @@ def __appropriate__(*args):
         obj.__module__ = __name__
 
 
-__appropriate__(AlexNet_simplified, DenseNet, DeepPixBiS, MineModel)
+__appropriate__(
+    AlexNet_simplified,
+    DenseNet,
+    DeepPixBiS,
+    MineModel,
+    ArcFaceLayer,
+    ArcFaceLayer3Penalties,
+    ArcFaceModel,
+    EmbeddingValidation,
+)
 __all__ = [_ for _ in dir() if not _.startswith("_")]
diff --git a/bob/learn/tensorflow/models/arcface.py b/bob/learn/tensorflow/models/arcface.py
new file mode 100644
index 0000000000000000000000000000000000000000..59f8ff68a1396598fa14bf8c6958e225e92a7592
--- /dev/null
+++ b/bob/learn/tensorflow/models/arcface.py
@@ -0,0 +1,151 @@
+import tensorflow as tf
+from .embedding_validation import EmbeddingValidation
+from bob.learn.tensorflow.metrics.embedding_accuracy import accuracy_from_embeddings
+import math
+
+
+class ArcFaceModel(EmbeddingValidation):
+    def train_step(self, data):
+        X, y = data
+
+        with tf.GradientTape() as tape:
+
+            logits, _ = self((X, y), training=True)
+            loss = self.compiled_loss(
+                y, logits, sample_weight=None, regularization_losses=self.losses
+            )
+        self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
+        self.compiled_metrics.update_state(y, logits, sample_weight=None)
+
+        self.train_loss(loss)
+        return {m.name: m.result() for m in self.metrics + [self.train_loss]}
+
+    def test_step(self, data):
+        """
+        Test Step
+        """
+
+        images, labels = data
+
+        # No worries, labels not used in validation
+        _, embeddings = self((images, labels), training=False)
+        self.validation_acc(accuracy_from_embeddings(labels, embeddings))
+        return {m.name: m.result() for m in [self.validation_acc]}
+
+
+class ArcFaceLayer(tf.keras.layers.Layer):
+    """
+    Implements the ArcFace from equation (3) of `ArcFace: Additive Angular Margin Loss for Deep Face Recognition <https://arxiv.org/abs/1801.07698>`_
+
+    Defined as:
+
+    :math:`s(cos(\\theta_i) + m`
+
+    Parameters
+    ----------
+
+      n_classes: int
+        Number of classes
+
+      m: float
+         Margin  
+
+      s: int
+         Scale  
+    """
+
+    def __init__(self, n_classes=10, s=30, m=0.5):
+        super(ArcFaceLayer, self).__init__(name="arc_face_logits")
+        self.n_classes = n_classes
+        self.s = s
+        self.m = m
+
+    def build(self, input_shape):
+        super(ArcFaceLayer, self).build(input_shape[0])
+        shape = [input_shape[-1], self.n_classes]
+
+        self.W = self.add_variable("W", shape=shape)
+
+        self.cos_m = tf.identity(math.cos(self.m), name="cos_m")
+        self.sin_m = tf.identity(math.sin(self.m), name="sin_m")
+        self.th = tf.identity(math.cos(math.pi - self.m), name="th")
+        self.mm = tf.identity(math.sin(math.pi - self.m) * self.m)
+
+    def call(self, X, y, training=None):
+
+        # normalize feature
+        X = tf.nn.l2_normalize(X, axis=1)
+        W = tf.nn.l2_normalize(self.W, axis=0)
+
+        # cos between X and W
+        cos_yi = tf.matmul(X, W)
+
+        # sin_yi = tf.math.sqrt(1-cos_yi**2)
+        sin_yi = tf.clip_by_value(tf.math.sqrt(1 - cos_yi ** 2), 0, 1)
+
+        # cos(x+m) = cos(x)*cos(m) - sin(x)*sin(m)
+        cos_yi_m = cos_yi * self.cos_m - sin_yi * self.sin_m
+
+        cos_yi_m = tf.where(cos_yi > self.th, cos_yi_m, cos_yi - self.mm)
+
+        # Preparing the hot-output
+        one_hot = tf.one_hot(
+            tf.cast(y, tf.int32), depth=self.n_classes, name="one_hot_mask"
+        )
+
+        logits = (one_hot * cos_yi_m) + ((1.0 - one_hot) * cos_yi)
+        logits = self.s * logits
+
+        return logits
+
+
+class ArcFaceLayer3Penalties(tf.keras.layers.Layer):
+    """
+    Implements the ArcFace loss from equation (4) of `ArcFace: Additive Angular Margin Loss for Deep Face Recognition <https://arxiv.org/abs/1801.07698>`_
+    
+    Defined as:
+    
+      :math:`s(cos(m_1\\theta_i + m_2) -m_3`
+    """
+
+    def __init__(self, n_classes=10, s=30, m1=0.5, m2=0.5, m3=0.5):
+        super(ArcFaceLayer3Penalties, self).__init__(name="arc_face_logits")
+        self.n_classes = n_classes
+        self.s = s
+        self.m1 = m1
+        self.m2 = m2
+        self.m3 = m3
+
+    def build(self, input_shape):
+        super(ArcFaceLayer3Penalties, self).build(input_shape[0])
+        shape = [input_shape[-1], self.n_classes]
+
+        self.W = self.add_variable("W", shape=shape)
+
+    def call(self, X, y, training=None):
+
+        # normalize feature
+        X = tf.nn.l2_normalize(X, axis=1)
+        W = tf.nn.l2_normalize(self.W, axis=0)
+
+        # cos between X and W
+        cos_yi = tf.matmul(X, W)
+
+        # Getting the angle
+        theta = tf.math.acos(cos_yi)
+
+        cos_yi_m = tf.math.cos(self.m1 * theta + self.m2) - self.m3
+
+        # logits = self.s*cos_theta_m
+
+        # Preparing the hot-output
+        one_hot = tf.one_hot(
+            tf.cast(y, tf.int32), depth=self.n_classes, name="one_hot_mask"
+        )
+
+        logits = (one_hot * cos_yi_m) + ((1.0 - one_hot) * cos_yi)
+
+        logits = self.s * logits
+
+        return logits
+
diff --git a/bob/learn/tensorflow/models/embedding_validation.py b/bob/learn/tensorflow/models/embedding_validation.py
new file mode 100644
index 0000000000000000000000000000000000000000..701f3084e74267bdfb0e300b57a723b3beef4ac3
--- /dev/null
+++ b/bob/learn/tensorflow/models/embedding_validation.py
@@ -0,0 +1,51 @@
+import tensorflow as tf
+from bob.learn.tensorflow.metrics.embedding_accuracy import accuracy_from_embeddings
+
+
+class EmbeddingValidation(tf.keras.Model):
+    """
+    Use this model if the validation step should validate the accuracy with respect to embeddings.
+    
+    In this model, the `test_step` runs the function `bob.learn.tensorflow.metrics.embedding_accuracy.accuracy_from_embeddings`
+    """
+
+    def compile(
+        self, **kwargs,
+    ):
+        """
+        Compile
+        """
+        super().compile(**kwargs)
+        self.train_loss = tf.keras.metrics.Mean(name="accuracy")
+        self.validation_acc = tf.keras.metrics.Mean(name="accuracy")
+
+    def train_step(self, data):
+        """
+        Train Step
+        """
+
+        X, y = data
+        with tf.GradientTape() as tape:
+            logits, _ = self(X, training=True)
+            loss = self.loss(y, logits)
+
+        trainable_vars = self.trainable_variables
+        self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
+
+        self.compiled_metrics.update_state(y, logits, sample_weight=None)
+        self.train_loss(loss)
+        return {m.name: m.result() for m in self.metrics + [self.train_loss]}
+
+        # self.optimizer.apply_gradients(zip(gradients, trainable_vars))
+        # self.train_loss(loss)
+        # return {m.name: m.result() for m in [self.train_loss]}
+
+    def test_step(self, data):
+        """
+        Test Step
+        """
+
+        images, labels = data
+        logits, prelogits = self(images, training=False)
+        self.validation_acc(accuracy_from_embeddings(labels, prelogits))
+        return {m.name: m.result() for m in [self.validation_acc]}
diff --git a/bob/learn/tensorflow/tests/test_arcface.py b/bob/learn/tensorflow/tests/test_arcface.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac9152003a86751e51415d3e58a9ef4f7d68685e
--- /dev/null
+++ b/bob/learn/tensorflow/tests/test_arcface.py
@@ -0,0 +1,50 @@
+from bob.learn.tensorflow.models import (
+    EmbeddingValidation,
+    ArcFaceLayer,
+    ArcFaceModel,
+    ArcFaceLayer3Penalties,
+)
+from bob.learn.tensorflow.layers import (
+    SphereFaceLayer,
+    ModifiedSoftMaxLayer,
+)
+
+import numpy as np
+
+
+def test_arcface_layer():
+
+    layer = ArcFaceLayer()
+    np.random.seed(10)
+    X = np.random.rand(10, 50)
+    y = [np.random.randint(10) for i in range(10)]
+
+    assert layer(X, y).shape == (10, 10)
+
+
+def test_arcface_layer_3p():
+
+    layer = ArcFaceLayer3Penalties()
+    np.random.seed(10)
+    X = np.random.rand(10, 50)
+    y = [np.random.randint(10) for i in range(10)]
+
+    assert layer(X, y).shape == (10, 10)
+
+
+def test_sphereface():
+
+    layer = SphereFaceLayer()
+    np.random.seed(10)
+    X = np.random.rand(10, 10)
+
+    assert layer(X).shape == (10, 10)
+
+
+def test_modsoftmax():
+
+    layer = ModifiedSoftMaxLayer()
+    np.random.seed(10)
+    X = np.random.rand(10, 10)
+
+    assert layer(X).shape == (10, 10)