diff --git a/bob/learn/activation/activation.py b/bob/learn/activation/activation.py index 9fc8df3ae51163cd3608c04560d368c5128614f2..a019f8e7af51bb810f6a573b4f3f22ac915d69c0 100644 --- a/bob/learn/activation/activation.py +++ b/bob/learn/activation/activation.py @@ -24,7 +24,37 @@ class Activation(_Activation_C): """ Loads itself from a python dict :py:class:`dict` """ - return cls() + + # TODO: Work it out a clean way to do this search between deprecated and new ids + # 0 (linear), 1 (tanh) or 2 (logistic) + + ## Deprecated loader + if input_dict["id"] == 0: + return Linear.from_dict(input_dict) + + if input_dict["id"] == 1: + return HyperbolicTangent() + + if input_dict["id"] == 2: + return Logistic() + + # Newer loader + if "Activation.Linear" in input_dict["id"]: + return Linear.from_dict(input_dict) + + if "Activation.HyperbolicTangent" in input_dict["id"]: + return HyperbolicTangent() + + if "Activation.Logistic" in input_dict["id"]: + return Logistic() + + if "Activation.Identity" in input_dict["id"]: + return Identity() + + if "Activation.MultipliedHyperbolicTangent" in input_dict["id"]: + return MultipliedHyperbolicTangent.from_dict(input_dict) + + raise ValueError(f"Value {input_dict} is invalid for `Activation`") class Logistic(_Logistic_C, Activation): diff --git a/bob/learn/activation/test.py b/bob/learn/activation/test.py index 7143520781fab837dd6f8d952224ba8040761068..f4b39182afca06e5d3c1433dc9193380a4bbcda3 100644 --- a/bob/learn/activation/test.py +++ b/bob/learn/activation/test.py @@ -9,7 +9,15 @@ import numpy import math -from . import Identity, Linear, Logistic, HyperbolicTangent, MultipliedHyperbolicTangent +from . import ( + Identity, + Linear, + Logistic, + HyperbolicTangent, + MultipliedHyperbolicTangent, + Activation, +) +from nose.tools import assert_raises def estimate_gradient(f, x, epsilon=1e-4, args=()): @@ -411,19 +419,19 @@ def test_from_dict(): # The first 3 tests don't make much sense, but I'm testing them anyways input_dict = {"id": "bob.learn.activation.Activation.Logistic"} - logistic = Logistic.from_dict(input_dict) + logistic = Activation.from_dict(input_dict) assert isinstance(logistic, Logistic) input_dict = {"id": "bob.learn.activation.Activation.HyperbolicTangent"} - hyperbolic_tangent = HyperbolicTangent.from_dict(input_dict) + hyperbolic_tangent = Activation.from_dict(input_dict) assert isinstance(hyperbolic_tangent, HyperbolicTangent) input_dict = {"id": "bob.learn.activation.Activation.Identity"} - identity = Identity.from_dict(input_dict) + identity = Activation.from_dict(input_dict) assert isinstance(identity, Identity) input_dict = {"id": "bob.learn.activation.Activation.Linear", "C": 2.0} - linear = Linear.from_dict(input_dict) + linear = Activation.from_dict(input_dict) assert linear.C == 2 input_dict = { @@ -431,6 +439,11 @@ def test_from_dict(): "C": 2.0, "M": 3.0, } - multiplied_hyperbolic_tangent = MultipliedHyperbolicTangent.from_dict(input_dict) + + multiplied_hyperbolic_tangent = Activation.from_dict(input_dict) assert multiplied_hyperbolic_tangent.C == 2.0 assert multiplied_hyperbolic_tangent.M == 3.0 + + with assert_raises(ValueError): + input_dict = {"id": "bob.learn.activation.Activation.Wrong"} + Activation.from_dict(input_dict)