From 91613d6fd7d75dc26ef7ab995fdb4d3c890a534d Mon Sep 17 00:00:00 2001
From: Amir MOHAMMADI <amir.mohammadi@idiap.ch>
Date: Wed, 30 Jun 2021 14:55:05 +0200
Subject: [PATCH] implement get_config for arcface layers

---
 bob/learn/tensorflow/models/arcface.py | 53 ++++++++++++++++++++++++--
 1 file changed, 49 insertions(+), 4 deletions(-)

diff --git a/bob/learn/tensorflow/models/arcface.py b/bob/learn/tensorflow/models/arcface.py
index 64dc56bf..27cf85f4 100644
--- a/bob/learn/tensorflow/models/arcface.py
+++ b/bob/learn/tensorflow/models/arcface.py
@@ -28,9 +28,17 @@ class ArcFaceLayer(tf.keras.layers.Layer):
     """
 
     def __init__(
-        self, n_classes, s=30, m=0.5, arc=True, kernel_initializer=None, **kwargs
+        # don't forget to fix get_config when you change init params
+        self,
+        n_classes,
+        s=30,
+        m=0.5,
+        arc=True,
+        kernel_initializer=None,
+        name="arc_face_logits",
+        **kwargs,
     ):
-        super(ArcFaceLayer, self).__init__(name="arc_face_logits", **kwargs)
+        super().__init__(name=name, **kwargs)
         self.n_classes = n_classes
         self.s = s
         self.arc = arc
@@ -84,6 +92,21 @@ class ArcFaceLayer(tf.keras.layers.Layer):
 
         return logits
 
+    def get_config(self):
+        config = dict(super().get_config())
+        config.update(
+            {
+                "n_classes": self.n_classes,
+                "s": self.s,
+                "arc": self.arc,
+                "m": self.m,
+                "kernel_initializer": tf.keras.initializers.serialize(
+                    self.kernel_initializer
+                ),
+            }
+        )
+        return config
+
 
 class ArcFaceLayer3Penalties(tf.keras.layers.Layer):
     """
@@ -94,8 +117,17 @@ class ArcFaceLayer3Penalties(tf.keras.layers.Layer):
       :math:`s(cos(m_1\\theta_i + m_2) -m_3`
     """
 
-    def __init__(self, n_classes=10, s=30, m1=0.5, m2=0.5, m3=0.5):
-        super(ArcFaceLayer3Penalties, self).__init__(name="arc_face_logits")
+    def __init__(
+        self,
+        n_classes=10,
+        s=30,
+        m1=0.5,
+        m2=0.5,
+        m3=0.5,
+        name="arc_face_logits",
+        **kwargs,
+    ):
+        super().__init__(name=name, **kwargs)
         self.n_classes = n_classes
         self.s = s
         self.m1 = m1
@@ -138,3 +170,16 @@ class ArcFaceLayer3Penalties(tf.keras.layers.Layer):
 
         logits = self.s * logits
         return logits
+
+    def get_config(self):
+        config = dict(super().get_config())
+        config.update(
+            {
+                "n_classes": self.n_classes,
+                "s": self.s,
+                "m1": self.m1,
+                "m2": self.m2,
+                "m3": self.m3,
+            }
+        )
+        return config
-- 
GitLab