From 64bb75f829bba3387e68c6f524dfd8b9752eaee0 Mon Sep 17 00:00:00 2001
From: Tiago Freitas Pereira <tiagofrepereira@gmail.com>
Date: Thu, 15 Aug 2019 10:06:24 +0200
Subject: [PATCH] Updated from_dict function

---
 bob/learn/activation/activation.py | 32 +++++++++++++++++++++++++++++-
 bob/learn/activation/test.py       | 25 +++++++++++++++++------
 2 files changed, 50 insertions(+), 7 deletions(-)

diff --git a/bob/learn/activation/activation.py b/bob/learn/activation/activation.py
index 9fc8df3..a019f8e 100644
--- a/bob/learn/activation/activation.py
+++ b/bob/learn/activation/activation.py
@@ -24,7 +24,37 @@ class Activation(_Activation_C):
         """
         Loads itself from a python dict :py:class:`dict`
         """
-        return cls()
+
+        # TODO: Work it out a clean way to do this search between deprecated and new ids
+        #  0 (linear), 1 (tanh) or 2 (logistic)
+
+        ## Deprecated loader
+        if input_dict["id"] == 0:
+            return Linear.from_dict(input_dict)
+
+        if input_dict["id"] == 1:
+            return HyperbolicTangent()
+
+        if input_dict["id"] == 2:
+            return Logistic()
+
+        # Newer loader
+        if "Activation.Linear" in input_dict["id"]:
+            return Linear.from_dict(input_dict)
+
+        if "Activation.HyperbolicTangent" in input_dict["id"]:
+            return HyperbolicTangent()
+
+        if "Activation.Logistic" in input_dict["id"]:
+            return Logistic()
+
+        if "Activation.Identity" in input_dict["id"]:
+            return Identity()
+
+        if "Activation.MultipliedHyperbolicTangent" in input_dict["id"]:
+            return MultipliedHyperbolicTangent.from_dict(input_dict)
+
+        raise ValueError(f"Value {input_dict} is invalid for `Activation`")
 
 
 class Logistic(_Logistic_C, Activation):
diff --git a/bob/learn/activation/test.py b/bob/learn/activation/test.py
index 7143520..f4b3918 100644
--- a/bob/learn/activation/test.py
+++ b/bob/learn/activation/test.py
@@ -9,7 +9,15 @@
 import numpy
 import math
 
-from . import Identity, Linear, Logistic, HyperbolicTangent, MultipliedHyperbolicTangent
+from . import (
+    Identity,
+    Linear,
+    Logistic,
+    HyperbolicTangent,
+    MultipliedHyperbolicTangent,
+    Activation,
+)
+from nose.tools import assert_raises
 
 
 def estimate_gradient(f, x, epsilon=1e-4, args=()):
@@ -411,19 +419,19 @@ def test_from_dict():
 
     # The first 3 tests don't make much sense, but I'm testing them anyways
     input_dict = {"id": "bob.learn.activation.Activation.Logistic"}
-    logistic = Logistic.from_dict(input_dict)
+    logistic = Activation.from_dict(input_dict)
     assert isinstance(logistic, Logistic)
 
     input_dict = {"id": "bob.learn.activation.Activation.HyperbolicTangent"}
-    hyperbolic_tangent = HyperbolicTangent.from_dict(input_dict)
+    hyperbolic_tangent = Activation.from_dict(input_dict)
     assert isinstance(hyperbolic_tangent, HyperbolicTangent)
 
     input_dict = {"id": "bob.learn.activation.Activation.Identity"}
-    identity = Identity.from_dict(input_dict)
+    identity = Activation.from_dict(input_dict)
     assert isinstance(identity, Identity)
 
     input_dict = {"id": "bob.learn.activation.Activation.Linear", "C": 2.0}
-    linear = Linear.from_dict(input_dict)
+    linear = Activation.from_dict(input_dict)
     assert linear.C == 2
 
     input_dict = {
@@ -431,6 +439,11 @@ def test_from_dict():
         "C": 2.0,
         "M": 3.0,
     }
-    multiplied_hyperbolic_tangent = MultipliedHyperbolicTangent.from_dict(input_dict)
+
+    multiplied_hyperbolic_tangent = Activation.from_dict(input_dict)
     assert multiplied_hyperbolic_tangent.C == 2.0
     assert multiplied_hyperbolic_tangent.M == 3.0
+
+    with assert_raises(ValueError):
+        input_dict = {"id": "bob.learn.activation.Activation.Wrong"}
+        Activation.from_dict(input_dict)
-- 
GitLab