Commit d255f06e authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Updated ARCFACE

parent 324a941d
Pipeline #51355 passed with stage
in 4 minutes and 58 seconds
...@@ -17,9 +17,18 @@ class ArcFaceModel(EmbeddingValidation): ...@@ -17,9 +17,18 @@ class ArcFaceModel(EmbeddingValidation):
loss = self.compiled_loss( loss = self.compiled_loss(
y, logits, sample_weight=None, regularization_losses=self.losses y, logits, sample_weight=None, regularization_losses=self.losses
) )
self.optimizer.minimize(loss, self.trainable_variables, tape=tape) reg_loss = tf.reduce_sum(self.losses)
total_loss = loss + reg_loss
trainable_vars = self.trainable_variables
self.optimizer.minimize(total_loss, trainable_vars, tape=tape)
self.compiled_metrics.update_state(y, logits, sample_weight=None) self.compiled_metrics.update_state(y, logits, sample_weight=None)
tf.summary.scalar("arc_face_loss", data=loss, step=self._train_counter)
tf.summary.scalar("total_loss", data=total_loss, step=self._train_counter)
self.train_loss(loss) self.train_loss(loss)
return {m.name: m.result() for m in self.metrics + [self.train_loss]} return {m.name: m.result() for m in self.metrics + [self.train_loss]}
...@@ -55,12 +64,16 @@ class ArcFaceLayer(tf.keras.layers.Layer): ...@@ -55,12 +64,16 @@ class ArcFaceLayer(tf.keras.layers.Layer):
s: int s: int
Scale Scale
arc: bool
If `True`, uses arcface loss. If `False`, it's a regular dense layer
""" """
def __init__(self, n_classes=10, s=30, m=0.5): def __init__(self, n_classes=10, s=30, m=0.5, arc=True):
super(ArcFaceLayer, self).__init__(name="arc_face_logits") super(ArcFaceLayer, self).__init__(name="arc_face_logits")
self.n_classes = n_classes self.n_classes = n_classes
self.s = s self.s = s
self.arc = arc
self.m = m self.m = m
def build(self, input_shape): def build(self, input_shape):
...@@ -75,7 +88,7 @@ class ArcFaceLayer(tf.keras.layers.Layer): ...@@ -75,7 +88,7 @@ class ArcFaceLayer(tf.keras.layers.Layer):
self.mm = tf.identity(math.sin(math.pi - self.m) * self.m) self.mm = tf.identity(math.sin(math.pi - self.m) * self.m)
def call(self, X, y, training=None): def call(self, X, y, training=None):
if self.arc:
# normalize feature # normalize feature
X = tf.nn.l2_normalize(X, axis=1) X = tf.nn.l2_normalize(X, axis=1)
W = tf.nn.l2_normalize(self.W, axis=0) W = tf.nn.l2_normalize(self.W, axis=0)
...@@ -98,6 +111,8 @@ class ArcFaceLayer(tf.keras.layers.Layer): ...@@ -98,6 +111,8 @@ class ArcFaceLayer(tf.keras.layers.Layer):
logits = (one_hot * cos_yi_m) + ((1.0 - one_hot) * cos_yi) logits = (one_hot * cos_yi_m) + ((1.0 - one_hot) * cos_yi)
logits = self.s * logits logits = self.s * logits
else:
logits = tf.matmul(X, self.W)
return logits return logits
...@@ -136,6 +151,9 @@ class ArcFaceLayer3Penalties(tf.keras.layers.Layer): ...@@ -136,6 +151,9 @@ class ArcFaceLayer3Penalties(tf.keras.layers.Layer):
# Getting the angle # Getting the angle
theta = tf.math.acos(cos_yi) theta = tf.math.acos(cos_yi)
theta = tf.clip_by_value(
theta, -1.0 + tf.keras.backend.epsilon(), 1 - tf.keras.backend.epsilon()
)
cos_yi_m = tf.math.cos(self.m1 * theta + self.m2) - self.m3 cos_yi_m = tf.math.cos(self.m1 * theta + self.m2) - self.m3
...@@ -146,8 +164,9 @@ class ArcFaceLayer3Penalties(tf.keras.layers.Layer): ...@@ -146,8 +164,9 @@ class ArcFaceLayer3Penalties(tf.keras.layers.Layer):
tf.cast(y, tf.int32), depth=self.n_classes, name="one_hot_mask" tf.cast(y, tf.int32), depth=self.n_classes, name="one_hot_mask"
) )
one_hot = tf.cast(one_hot, cos_yi_m.dtype)
logits = (one_hot * cos_yi_m) + ((1.0 - one_hot) * cos_yi) logits = (one_hot * cos_yi_m) + ((1.0 - one_hot) * cos_yi)
logits = self.s * logits logits = self.s * logits
return logits return logits
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment