From b0ef4f7488db7468074c1b85e3998c403e394ae6 Mon Sep 17 00:00:00 2001
From: Tiago Freitas Pereira <tiagofrepereira@gmail.com>
Date: Mon, 12 Aug 2019 16:07:05 +0200
Subject: [PATCH] Pushed some fixes in the  method

---
 bob/learn/activation/activation.py | 17 +++++++++--------
 bob/learn/activation/test.py       | 13 +++++++++++++
 2 files changed, 22 insertions(+), 8 deletions(-)

diff --git a/bob/learn/activation/activation.py b/bob/learn/activation/activation.py
index 9a9b6ff..9fc8df3 100644
--- a/bob/learn/activation/activation.py
+++ b/bob/learn/activation/activation.py
@@ -19,11 +19,12 @@ class Activation(_Activation_C):
         """
         return {"id": self.unique_identifier()}
 
-    def from_dict(input_dict):
+    @classmethod
+    def from_dict(cls, input_dict):
         """
         Loads itself from a python dict :py:class:`dict`
         """
-        pass
+        return cls()
 
 
 class Logistic(_Logistic_C, Activation):
@@ -42,15 +43,15 @@ class Linear(_Linear_C, Activation):
     def to_dict(self):
         return {"id": self.unique_identifier(), "C": self.C}
 
-    @staticmethod
-    def from_dict(input_dict):
+    @classmethod
+    def from_dict(cls, input_dict):
         """
         Loads itself from a python dict :py:class:`dict`
         """
 
         if "C" in input_dict:
             C = float(input_dict["C"])
-            return Linear(C=C)
+            return cls(C=C)
         else:
             raise ValueError("Missing parameter `C` in `input_dict`")
 
@@ -59,8 +60,8 @@ class MultipliedHyperbolicTangent(_MultipliedHyperbolicTangent_C, Activation):
     def to_dict(self):
         return {"id": self.unique_identifier(), "C": self.C, "M": self.M}
 
-    @staticmethod
-    def from_dict(input_dict):
+    @classmethod
+    def from_dict(cls, input_dict):
         """
         Loads itself from a python dict :py:class:`dict`
         """
@@ -75,4 +76,4 @@ class MultipliedHyperbolicTangent(_MultipliedHyperbolicTangent_C, Activation):
         else:
             raise ValueError("Missing parameter `M` in `input_dict`")
 
-        return MultipliedHyperbolicTangent(C=C, M=M)
+        return cls(C=C, M=M)
diff --git a/bob/learn/activation/test.py b/bob/learn/activation/test.py
index 960ec4c..7143520 100644
--- a/bob/learn/activation/test.py
+++ b/bob/learn/activation/test.py
@@ -409,6 +409,19 @@ def test_to_dict():
 
 def test_from_dict():
 
+    # The first 3 tests don't make much sense, but I'm testing them anyways
+    input_dict = {"id": "bob.learn.activation.Activation.Logistic"}
+    logistic = Logistic.from_dict(input_dict)
+    assert isinstance(logistic, Logistic)
+
+    input_dict = {"id": "bob.learn.activation.Activation.HyperbolicTangent"}
+    hyperbolic_tangent = HyperbolicTangent.from_dict(input_dict)
+    assert isinstance(hyperbolic_tangent, HyperbolicTangent)
+
+    input_dict = {"id": "bob.learn.activation.Activation.Identity"}
+    identity = Identity.from_dict(input_dict)
+    assert isinstance(identity, Identity)
+
     input_dict = {"id": "bob.learn.activation.Activation.Linear", "C": 2.0}
     linear = Linear.from_dict(input_dict)
     assert linear.C == 2
-- 
GitLab