diff --git a/bob/learn/tensorflow/models/arcface.py b/bob/learn/tensorflow/models/arcface.py
index 64dc56bf19b84c693ac291160e4bece528f24005..27cf85f4574ca0c4335310b3ec58803ec6a638c0 100644
--- a/bob/learn/tensorflow/models/arcface.py
+++ b/bob/learn/tensorflow/models/arcface.py
@@ -28,9 +28,17 @@ class ArcFaceLayer(tf.keras.layers.Layer):
"""
def __init__(
- self, n_classes, s=30, m=0.5, arc=True, kernel_initializer=None, **kwargs
+ # don't forget to fix get_config when you change init params
+ self,
+ n_classes,
+ s=30,
+ m=0.5,
+ arc=True,
+ kernel_initializer=None,
+ name="arc_face_logits",
+ **kwargs,
):
- super(ArcFaceLayer, self).__init__(name="arc_face_logits", **kwargs)
+ super().__init__(name=name, **kwargs)
self.n_classes = n_classes
self.s = s
self.arc = arc
@@ -84,6 +92,21 @@ class ArcFaceLayer(tf.keras.layers.Layer):
return logits
+ def get_config(self):
+ config = dict(super().get_config())
+ config.update(
+ {
+ "n_classes": self.n_classes,
+ "s": self.s,
+ "arc": self.arc,
+ "m": self.m,
+ "kernel_initializer": tf.keras.initializers.serialize(
+ self.kernel_initializer
+ ),
+ }
+ )
+ return config
+
class ArcFaceLayer3Penalties(tf.keras.layers.Layer):
"""
@@ -94,8 +117,17 @@ class ArcFaceLayer3Penalties(tf.keras.layers.Layer):
:math:`s(cos(m_1\\theta_i + m_2) -m_3`
"""
- def __init__(self, n_classes=10, s=30, m1=0.5, m2=0.5, m3=0.5):
- super(ArcFaceLayer3Penalties, self).__init__(name="arc_face_logits")
+ def __init__(
+ self,
+ n_classes=10,
+ s=30,
+ m1=0.5,
+ m2=0.5,
+ m3=0.5,
+ name="arc_face_logits",
+ **kwargs,
+ ):
+ super().__init__(name=name, **kwargs)
self.n_classes = n_classes
self.s = s
self.m1 = m1
@@ -138,3 +170,16 @@ class ArcFaceLayer3Penalties(tf.keras.layers.Layer):
logits = self.s * logits
return logits
+
+ def get_config(self):
+ config = dict(super().get_config())
+ config.update(
+ {
+ "n_classes": self.n_classes,
+ "s": self.s,
+ "m1": self.m1,
+ "m2": self.m2,
+ "m3": self.m3,
+ }
+ )
+ return config