Skip to content
Snippets Groups Projects
Commit 92c9880e authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Merge branch 'resnet101' into 'master'

Updated ARCFACE

See merge request !96
parents 4f80284c d255f06e
No related branches found
No related tags found
1 merge request!96Updated ARCFACE
Pipeline #51357 passed
......@@ -17,9 +17,18 @@ class ArcFaceModel(EmbeddingValidation):
loss = self.compiled_loss(
y, logits, sample_weight=None, regularization_losses=self.losses
)
self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
reg_loss = tf.reduce_sum(self.losses)
total_loss = loss + reg_loss
trainable_vars = self.trainable_variables
self.optimizer.minimize(total_loss, trainable_vars, tape=tape)
self.compiled_metrics.update_state(y, logits, sample_weight=None)
tf.summary.scalar("arc_face_loss", data=loss, step=self._train_counter)
tf.summary.scalar("total_loss", data=total_loss, step=self._train_counter)
self.train_loss(loss)
return {m.name: m.result() for m in self.metrics + [self.train_loss]}
......@@ -55,12 +64,16 @@ class ArcFaceLayer(tf.keras.layers.Layer):
s: int
Scale
arc: bool
If `True`, uses arcface loss. If `False`, it's a regular dense layer
"""
def __init__(self, n_classes=10, s=30, m=0.5):
def __init__(self, n_classes=10, s=30, m=0.5, arc=True):
super(ArcFaceLayer, self).__init__(name="arc_face_logits")
self.n_classes = n_classes
self.s = s
self.arc = arc
self.m = m
def build(self, input_shape):
......@@ -75,29 +88,31 @@ class ArcFaceLayer(tf.keras.layers.Layer):
self.mm = tf.identity(math.sin(math.pi - self.m) * self.m)
def call(self, X, y, training=None):
if self.arc:
# normalize feature
X = tf.nn.l2_normalize(X, axis=1)
W = tf.nn.l2_normalize(self.W, axis=0)
# normalize feature
X = tf.nn.l2_normalize(X, axis=1)
W = tf.nn.l2_normalize(self.W, axis=0)
# cos between X and W
cos_yi = tf.matmul(X, W)
# cos between X and W
cos_yi = tf.matmul(X, W)
# sin_yi = tf.math.sqrt(1-cos_yi**2)
sin_yi = tf.clip_by_value(tf.math.sqrt(1 - cos_yi ** 2), 0, 1)
# sin_yi = tf.math.sqrt(1-cos_yi**2)
sin_yi = tf.clip_by_value(tf.math.sqrt(1 - cos_yi ** 2), 0, 1)
# cos(x+m) = cos(x)*cos(m) - sin(x)*sin(m)
cos_yi_m = cos_yi * self.cos_m - sin_yi * self.sin_m
# cos(x+m) = cos(x)*cos(m) - sin(x)*sin(m)
cos_yi_m = cos_yi * self.cos_m - sin_yi * self.sin_m
cos_yi_m = tf.where(cos_yi > self.th, cos_yi_m, cos_yi - self.mm)
cos_yi_m = tf.where(cos_yi > self.th, cos_yi_m, cos_yi - self.mm)
# Preparing the hot-output
one_hot = tf.one_hot(
tf.cast(y, tf.int32), depth=self.n_classes, name="one_hot_mask"
)
# Preparing the hot-output
one_hot = tf.one_hot(
tf.cast(y, tf.int32), depth=self.n_classes, name="one_hot_mask"
)
logits = (one_hot * cos_yi_m) + ((1.0 - one_hot) * cos_yi)
logits = self.s * logits
logits = (one_hot * cos_yi_m) + ((1.0 - one_hot) * cos_yi)
logits = self.s * logits
else:
logits = tf.matmul(X, self.W)
return logits
......@@ -136,6 +151,9 @@ class ArcFaceLayer3Penalties(tf.keras.layers.Layer):
# Getting the angle
theta = tf.math.acos(cos_yi)
theta = tf.clip_by_value(
theta, -1.0 + tf.keras.backend.epsilon(), 1 - tf.keras.backend.epsilon()
)
cos_yi_m = tf.math.cos(self.m1 * theta + self.m2) - self.m3
......@@ -146,8 +164,9 @@ class ArcFaceLayer3Penalties(tf.keras.layers.Layer):
tf.cast(y, tf.int32), depth=self.n_classes, name="one_hot_mask"
)
one_hot = tf.cast(one_hot, cos_yi_m.dtype)
logits = (one_hot * cos_yi_m) + ((1.0 - one_hot) * cos_yi)
logits = self.s * logits
return logits
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment