Commit 91613d6f authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

implement get_config for arcface layers

parent 082fa660
......@@ -28,9 +28,17 @@ class ArcFaceLayer(tf.keras.layers.Layer):
"""
def __init__(
self, n_classes, s=30, m=0.5, arc=True, kernel_initializer=None, **kwargs
# don't forget to fix get_config when you change init params
self,
n_classes,
s=30,
m=0.5,
arc=True,
kernel_initializer=None,
name="arc_face_logits",
**kwargs,
):
super(ArcFaceLayer, self).__init__(name="arc_face_logits", **kwargs)
super().__init__(name=name, **kwargs)
self.n_classes = n_classes
self.s = s
self.arc = arc
......@@ -84,6 +92,21 @@ class ArcFaceLayer(tf.keras.layers.Layer):
return logits
def get_config(self):
config = dict(super().get_config())
config.update(
{
"n_classes": self.n_classes,
"s": self.s,
"arc": self.arc,
"m": self.m,
"kernel_initializer": tf.keras.initializers.serialize(
self.kernel_initializer
),
}
)
return config
class ArcFaceLayer3Penalties(tf.keras.layers.Layer):
"""
......@@ -94,8 +117,17 @@ class ArcFaceLayer3Penalties(tf.keras.layers.Layer):
:math:`s(cos(m_1\\theta_i + m_2) -m_3`
"""
def __init__(self, n_classes=10, s=30, m1=0.5, m2=0.5, m3=0.5):
super(ArcFaceLayer3Penalties, self).__init__(name="arc_face_logits")
def __init__(
self,
n_classes=10,
s=30,
m1=0.5,
m2=0.5,
m3=0.5,
name="arc_face_logits",
**kwargs,
):
super().__init__(name=name, **kwargs)
self.n_classes = n_classes
self.s = s
self.m1 = m1
......@@ -138,3 +170,16 @@ class ArcFaceLayer3Penalties(tf.keras.layers.Layer):
logits = self.s * logits
return logits
def get_config(self):
config = dict(super().get_config())
config.update(
{
"n_classes": self.n_classes,
"s": self.s,
"m1": self.m1,
"m2": self.m2,
"m3": self.m3,
}
)
return config
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment