From d255f06e25cee2b0de0c21d59da2f07a6732ecfc Mon Sep 17 00:00:00 2001
From: Tiago Freitas Pereira <tiagofrepereira@gmail.com>
Date: Wed, 9 Jun 2021 17:20:11 +0200
Subject: [PATCH] Updated ARCFACE

---
 bob/learn/tensorflow/models/arcface.py | 59 +++++++++++++++++---------
 1 file changed, 39 insertions(+), 20 deletions(-)

diff --git a/bob/learn/tensorflow/models/arcface.py b/bob/learn/tensorflow/models/arcface.py
index 8bbb7a7d..5c856900 100644
--- a/bob/learn/tensorflow/models/arcface.py
+++ b/bob/learn/tensorflow/models/arcface.py
@@ -17,9 +17,18 @@ class ArcFaceModel(EmbeddingValidation):
             loss = self.compiled_loss(
                 y, logits, sample_weight=None, regularization_losses=self.losses
             )
-        self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
+            reg_loss = tf.reduce_sum(self.losses)
+            total_loss = loss + reg_loss
+
+        trainable_vars = self.trainable_variables
+
+        self.optimizer.minimize(total_loss, trainable_vars, tape=tape)
+
         self.compiled_metrics.update_state(y, logits, sample_weight=None)
 
+        tf.summary.scalar("arc_face_loss", data=loss, step=self._train_counter)
+        tf.summary.scalar("total_loss", data=total_loss, step=self._train_counter)
+
         self.train_loss(loss)
         return {m.name: m.result() for m in self.metrics + [self.train_loss]}
 
@@ -55,12 +64,16 @@ class ArcFaceLayer(tf.keras.layers.Layer):
 
       s: int
          Scale
+
+      arc: bool
+         If `True`, uses arcface loss. If `False`, it's a regular dense layer
     """
 
-    def __init__(self, n_classes=10, s=30, m=0.5):
+    def __init__(self, n_classes=10, s=30, m=0.5, arc=True):
         super(ArcFaceLayer, self).__init__(name="arc_face_logits")
         self.n_classes = n_classes
         self.s = s
+        self.arc = arc
         self.m = m
 
     def build(self, input_shape):
@@ -75,29 +88,31 @@ class ArcFaceLayer(tf.keras.layers.Layer):
         self.mm = tf.identity(math.sin(math.pi - self.m) * self.m)
 
     def call(self, X, y, training=None):
+        if self.arc:
+            # normalize feature
+            X = tf.nn.l2_normalize(X, axis=1)
+            W = tf.nn.l2_normalize(self.W, axis=0)
 
-        # normalize feature
-        X = tf.nn.l2_normalize(X, axis=1)
-        W = tf.nn.l2_normalize(self.W, axis=0)
-
-        # cos between X and W
-        cos_yi = tf.matmul(X, W)
+            # cos between X and W
+            cos_yi = tf.matmul(X, W)
 
-        # sin_yi = tf.math.sqrt(1-cos_yi**2)
-        sin_yi = tf.clip_by_value(tf.math.sqrt(1 - cos_yi ** 2), 0, 1)
+            # sin_yi = tf.math.sqrt(1-cos_yi**2)
+            sin_yi = tf.clip_by_value(tf.math.sqrt(1 - cos_yi ** 2), 0, 1)
 
-        # cos(x+m) = cos(x)*cos(m) - sin(x)*sin(m)
-        cos_yi_m = cos_yi * self.cos_m - sin_yi * self.sin_m
+            # cos(x+m) = cos(x)*cos(m) - sin(x)*sin(m)
+            cos_yi_m = cos_yi * self.cos_m - sin_yi * self.sin_m
 
-        cos_yi_m = tf.where(cos_yi > self.th, cos_yi_m, cos_yi - self.mm)
+            cos_yi_m = tf.where(cos_yi > self.th, cos_yi_m, cos_yi - self.mm)
 
-        # Preparing the hot-output
-        one_hot = tf.one_hot(
-            tf.cast(y, tf.int32), depth=self.n_classes, name="one_hot_mask"
-        )
+            # Preparing the hot-output
+            one_hot = tf.one_hot(
+                tf.cast(y, tf.int32), depth=self.n_classes, name="one_hot_mask"
+            )
 
-        logits = (one_hot * cos_yi_m) + ((1.0 - one_hot) * cos_yi)
-        logits = self.s * logits
+            logits = (one_hot * cos_yi_m) + ((1.0 - one_hot) * cos_yi)
+            logits = self.s * logits
+        else:
+            logits = tf.matmul(X, self.W)
 
         return logits
 
@@ -136,6 +151,9 @@ class ArcFaceLayer3Penalties(tf.keras.layers.Layer):
 
         # Getting the angle
         theta = tf.math.acos(cos_yi)
+        theta = tf.clip_by_value(
+            theta, -1.0 + tf.keras.backend.epsilon(), 1 - tf.keras.backend.epsilon()
+        )
 
         cos_yi_m = tf.math.cos(self.m1 * theta + self.m2) - self.m3
 
@@ -146,8 +164,9 @@ class ArcFaceLayer3Penalties(tf.keras.layers.Layer):
             tf.cast(y, tf.int32), depth=self.n_classes, name="one_hot_mask"
         )
 
+        one_hot = tf.cast(one_hot, cos_yi_m.dtype)
+
         logits = (one_hot * cos_yi_m) + ((1.0 - one_hot) * cos_yi)
 
         logits = self.s * logits
-
         return logits
-- 
GitLab