Commit 2e56ab7b authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

The whole arcface layer is 32bit in insightface pytorch code

parent 44d5075f
......@@ -27,19 +27,21 @@ class ArcFaceLayer(tf.keras.layers.Layer):
If `True`, uses arcface loss. If `False`, it's a regular dense layer
"""
def __init__(self, n_classes, s=30, m=0.5, arc=True):
super(ArcFaceLayer, self).__init__(name="arc_face_logits")
def __init__(
self, n_classes, s=30, m=0.5, arc=True, kernel_initializer=None, **kwargs
):
super(ArcFaceLayer, self).__init__(name="arc_face_logits", **kwargs)
self.n_classes = n_classes
self.s = s
self.arc = arc
self.m = m
self.act32bit = tf.keras.layers.Activation("linear", dtype="float32")
self.kernel_initializer = kernel_initializer
def build(self, input_shape):
super(ArcFaceLayer, self).build(input_shape[0])
shape = [input_shape[-1], self.n_classes]
self.W = self.add_variable("W", shape=shape)
self.W = self.add_weight("W", shape=shape, initializer=self.kernel_initializer)
self.cos_m = tf.identity(math.cos(self.m), name="cos_m")
self.sin_m = tf.identity(math.sin(self.m), name="sin_m")
......@@ -80,7 +82,6 @@ class ArcFaceLayer(tf.keras.layers.Layer):
else:
logits = tf.matmul(X, self.W)
logits = self.act32bit(logits)
return logits
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment