From d255f06e25cee2b0de0c21d59da2f07a6732ecfc Mon Sep 17 00:00:00 2001 From: Tiago Freitas Pereira <tiagofrepereira@gmail.com> Date: Wed, 9 Jun 2021 17:20:11 +0200 Subject: [PATCH] Updated ARCFACE --- bob/learn/tensorflow/models/arcface.py | 59 +++++++++++++++++--------- 1 file changed, 39 insertions(+), 20 deletions(-) diff --git a/bob/learn/tensorflow/models/arcface.py b/bob/learn/tensorflow/models/arcface.py index 8bbb7a7d..5c856900 100644 --- a/bob/learn/tensorflow/models/arcface.py +++ b/bob/learn/tensorflow/models/arcface.py @@ -17,9 +17,18 @@ class ArcFaceModel(EmbeddingValidation): loss = self.compiled_loss( y, logits, sample_weight=None, regularization_losses=self.losses ) - self.optimizer.minimize(loss, self.trainable_variables, tape=tape) + reg_loss = tf.reduce_sum(self.losses) + total_loss = loss + reg_loss + + trainable_vars = self.trainable_variables + + self.optimizer.minimize(total_loss, trainable_vars, tape=tape) + self.compiled_metrics.update_state(y, logits, sample_weight=None) + tf.summary.scalar("arc_face_loss", data=loss, step=self._train_counter) + tf.summary.scalar("total_loss", data=total_loss, step=self._train_counter) + self.train_loss(loss) return {m.name: m.result() for m in self.metrics + [self.train_loss]} @@ -55,12 +64,16 @@ class ArcFaceLayer(tf.keras.layers.Layer): s: int Scale + + arc: bool + If `True`, uses arcface loss. If `False`, it's a regular dense layer """ - def __init__(self, n_classes=10, s=30, m=0.5): + def __init__(self, n_classes=10, s=30, m=0.5, arc=True): super(ArcFaceLayer, self).__init__(name="arc_face_logits") self.n_classes = n_classes self.s = s + self.arc = arc self.m = m def build(self, input_shape): @@ -75,29 +88,31 @@ class ArcFaceLayer(tf.keras.layers.Layer): self.mm = tf.identity(math.sin(math.pi - self.m) * self.m) def call(self, X, y, training=None): + if self.arc: + # normalize feature + X = tf.nn.l2_normalize(X, axis=1) + W = tf.nn.l2_normalize(self.W, axis=0) - # normalize feature - X = tf.nn.l2_normalize(X, axis=1) - W = tf.nn.l2_normalize(self.W, axis=0) - - # cos between X and W - cos_yi = tf.matmul(X, W) + # cos between X and W + cos_yi = tf.matmul(X, W) - # sin_yi = tf.math.sqrt(1-cos_yi**2) - sin_yi = tf.clip_by_value(tf.math.sqrt(1 - cos_yi ** 2), 0, 1) + # sin_yi = tf.math.sqrt(1-cos_yi**2) + sin_yi = tf.clip_by_value(tf.math.sqrt(1 - cos_yi ** 2), 0, 1) - # cos(x+m) = cos(x)*cos(m) - sin(x)*sin(m) - cos_yi_m = cos_yi * self.cos_m - sin_yi * self.sin_m + # cos(x+m) = cos(x)*cos(m) - sin(x)*sin(m) + cos_yi_m = cos_yi * self.cos_m - sin_yi * self.sin_m - cos_yi_m = tf.where(cos_yi > self.th, cos_yi_m, cos_yi - self.mm) + cos_yi_m = tf.where(cos_yi > self.th, cos_yi_m, cos_yi - self.mm) - # Preparing the hot-output - one_hot = tf.one_hot( - tf.cast(y, tf.int32), depth=self.n_classes, name="one_hot_mask" - ) + # Preparing the hot-output + one_hot = tf.one_hot( + tf.cast(y, tf.int32), depth=self.n_classes, name="one_hot_mask" + ) - logits = (one_hot * cos_yi_m) + ((1.0 - one_hot) * cos_yi) - logits = self.s * logits + logits = (one_hot * cos_yi_m) + ((1.0 - one_hot) * cos_yi) + logits = self.s * logits + else: + logits = tf.matmul(X, self.W) return logits @@ -136,6 +151,9 @@ class ArcFaceLayer3Penalties(tf.keras.layers.Layer): # Getting the angle theta = tf.math.acos(cos_yi) + theta = tf.clip_by_value( + theta, -1.0 + tf.keras.backend.epsilon(), 1 - tf.keras.backend.epsilon() + ) cos_yi_m = tf.math.cos(self.m1 * theta + self.m2) - self.m3 @@ -146,8 +164,9 @@ class ArcFaceLayer3Penalties(tf.keras.layers.Layer): tf.cast(y, tf.int32), depth=self.n_classes, name="one_hot_mask" ) + one_hot = tf.cast(one_hot, cos_yi_m.dtype) + logits = (one_hot * cos_yi_m) + ((1.0 - one_hot) * cos_yi) logits = self.s * logits - return logits -- GitLab