Skip to content
Snippets Groups Projects
Commit 2e56ab7b authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

The whole arcface layer is 32bit in insightface pytorch code

parent 44d5075f
No related branches found
No related tags found
No related merge requests found
......@@ -27,19 +27,21 @@ class ArcFaceLayer(tf.keras.layers.Layer):
If `True`, uses arcface loss. If `False`, it's a regular dense layer
"""
def __init__(self, n_classes, s=30, m=0.5, arc=True):
super(ArcFaceLayer, self).__init__(name="arc_face_logits")
def __init__(
self, n_classes, s=30, m=0.5, arc=True, kernel_initializer=None, **kwargs
):
super(ArcFaceLayer, self).__init__(name="arc_face_logits", **kwargs)
self.n_classes = n_classes
self.s = s
self.arc = arc
self.m = m
self.act32bit = tf.keras.layers.Activation("linear", dtype="float32")
self.kernel_initializer = kernel_initializer
def build(self, input_shape):
super(ArcFaceLayer, self).build(input_shape[0])
shape = [input_shape[-1], self.n_classes]
self.W = self.add_variable("W", shape=shape)
self.W = self.add_weight("W", shape=shape, initializer=self.kernel_initializer)
self.cos_m = tf.identity(math.cos(self.m), name="cos_m")
self.sin_m = tf.identity(math.sin(self.m), name="sin_m")
......@@ -80,7 +82,6 @@ class ArcFaceLayer(tf.keras.layers.Layer):
else:
logits = tf.matmul(X, self.W)
logits = self.act32bit(logits)
return logits
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment