From 64bb75f829bba3387e68c6f524dfd8b9752eaee0 Mon Sep 17 00:00:00 2001 From: Tiago Freitas Pereira <tiagofrepereira@gmail.com> Date: Thu, 15 Aug 2019 10:06:24 +0200 Subject: [PATCH] Updated from_dict function --- bob/learn/activation/activation.py | 32 +++++++++++++++++++++++++++++- bob/learn/activation/test.py | 25 +++++++++++++++++------ 2 files changed, 50 insertions(+), 7 deletions(-) diff --git a/bob/learn/activation/activation.py b/bob/learn/activation/activation.py index 9fc8df3..a019f8e 100644 --- a/bob/learn/activation/activation.py +++ b/bob/learn/activation/activation.py @@ -24,7 +24,37 @@ class Activation(_Activation_C): """ Loads itself from a python dict :py:class:`dict` """ - return cls() + + # TODO: Work it out a clean way to do this search between deprecated and new ids + # 0 (linear), 1 (tanh) or 2 (logistic) + + ## Deprecated loader + if input_dict["id"] == 0: + return Linear.from_dict(input_dict) + + if input_dict["id"] == 1: + return HyperbolicTangent() + + if input_dict["id"] == 2: + return Logistic() + + # Newer loader + if "Activation.Linear" in input_dict["id"]: + return Linear.from_dict(input_dict) + + if "Activation.HyperbolicTangent" in input_dict["id"]: + return HyperbolicTangent() + + if "Activation.Logistic" in input_dict["id"]: + return Logistic() + + if "Activation.Identity" in input_dict["id"]: + return Identity() + + if "Activation.MultipliedHyperbolicTangent" in input_dict["id"]: + return MultipliedHyperbolicTangent.from_dict(input_dict) + + raise ValueError(f"Value {input_dict} is invalid for `Activation`") class Logistic(_Logistic_C, Activation): diff --git a/bob/learn/activation/test.py b/bob/learn/activation/test.py index 7143520..f4b3918 100644 --- a/bob/learn/activation/test.py +++ b/bob/learn/activation/test.py @@ -9,7 +9,15 @@ import numpy import math -from . import Identity, Linear, Logistic, HyperbolicTangent, MultipliedHyperbolicTangent +from . import ( + Identity, + Linear, + Logistic, + HyperbolicTangent, + MultipliedHyperbolicTangent, + Activation, +) +from nose.tools import assert_raises def estimate_gradient(f, x, epsilon=1e-4, args=()): @@ -411,19 +419,19 @@ def test_from_dict(): # The first 3 tests don't make much sense, but I'm testing them anyways input_dict = {"id": "bob.learn.activation.Activation.Logistic"} - logistic = Logistic.from_dict(input_dict) + logistic = Activation.from_dict(input_dict) assert isinstance(logistic, Logistic) input_dict = {"id": "bob.learn.activation.Activation.HyperbolicTangent"} - hyperbolic_tangent = HyperbolicTangent.from_dict(input_dict) + hyperbolic_tangent = Activation.from_dict(input_dict) assert isinstance(hyperbolic_tangent, HyperbolicTangent) input_dict = {"id": "bob.learn.activation.Activation.Identity"} - identity = Identity.from_dict(input_dict) + identity = Activation.from_dict(input_dict) assert isinstance(identity, Identity) input_dict = {"id": "bob.learn.activation.Activation.Linear", "C": 2.0} - linear = Linear.from_dict(input_dict) + linear = Activation.from_dict(input_dict) assert linear.C == 2 input_dict = { @@ -431,6 +439,11 @@ def test_from_dict(): "C": 2.0, "M": 3.0, } - multiplied_hyperbolic_tangent = MultipliedHyperbolicTangent.from_dict(input_dict) + + multiplied_hyperbolic_tangent = Activation.from_dict(input_dict) assert multiplied_hyperbolic_tangent.C == 2.0 assert multiplied_hyperbolic_tangent.M == 3.0 + + with assert_raises(ValueError): + input_dict = {"id": "bob.learn.activation.Activation.Wrong"} + Activation.from_dict(input_dict) -- GitLab