Skip to content
Snippets Groups Projects
Commit 64bb75f8 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Updated from_dict function

parent b0ef4f74
No related branches found
No related tags found
1 merge request!6WIP: First attempt to approach the issue bob.bio.base#106
Pipeline #32523 passed
...@@ -24,7 +24,37 @@ class Activation(_Activation_C): ...@@ -24,7 +24,37 @@ class Activation(_Activation_C):
""" """
Loads itself from a python dict :py:class:`dict` Loads itself from a python dict :py:class:`dict`
""" """
return cls()
# TODO: Work it out a clean way to do this search between deprecated and new ids
# 0 (linear), 1 (tanh) or 2 (logistic)
## Deprecated loader
if input_dict["id"] == 0:
return Linear.from_dict(input_dict)
if input_dict["id"] == 1:
return HyperbolicTangent()
if input_dict["id"] == 2:
return Logistic()
# Newer loader
if "Activation.Linear" in input_dict["id"]:
return Linear.from_dict(input_dict)
if "Activation.HyperbolicTangent" in input_dict["id"]:
return HyperbolicTangent()
if "Activation.Logistic" in input_dict["id"]:
return Logistic()
if "Activation.Identity" in input_dict["id"]:
return Identity()
if "Activation.MultipliedHyperbolicTangent" in input_dict["id"]:
return MultipliedHyperbolicTangent.from_dict(input_dict)
raise ValueError(f"Value {input_dict} is invalid for `Activation`")
class Logistic(_Logistic_C, Activation): class Logistic(_Logistic_C, Activation):
......
...@@ -9,7 +9,15 @@ ...@@ -9,7 +9,15 @@
import numpy import numpy
import math import math
from . import Identity, Linear, Logistic, HyperbolicTangent, MultipliedHyperbolicTangent from . import (
Identity,
Linear,
Logistic,
HyperbolicTangent,
MultipliedHyperbolicTangent,
Activation,
)
from nose.tools import assert_raises
def estimate_gradient(f, x, epsilon=1e-4, args=()): def estimate_gradient(f, x, epsilon=1e-4, args=()):
...@@ -411,19 +419,19 @@ def test_from_dict(): ...@@ -411,19 +419,19 @@ def test_from_dict():
# The first 3 tests don't make much sense, but I'm testing them anyways # The first 3 tests don't make much sense, but I'm testing them anyways
input_dict = {"id": "bob.learn.activation.Activation.Logistic"} input_dict = {"id": "bob.learn.activation.Activation.Logistic"}
logistic = Logistic.from_dict(input_dict) logistic = Activation.from_dict(input_dict)
assert isinstance(logistic, Logistic) assert isinstance(logistic, Logistic)
input_dict = {"id": "bob.learn.activation.Activation.HyperbolicTangent"} input_dict = {"id": "bob.learn.activation.Activation.HyperbolicTangent"}
hyperbolic_tangent = HyperbolicTangent.from_dict(input_dict) hyperbolic_tangent = Activation.from_dict(input_dict)
assert isinstance(hyperbolic_tangent, HyperbolicTangent) assert isinstance(hyperbolic_tangent, HyperbolicTangent)
input_dict = {"id": "bob.learn.activation.Activation.Identity"} input_dict = {"id": "bob.learn.activation.Activation.Identity"}
identity = Identity.from_dict(input_dict) identity = Activation.from_dict(input_dict)
assert isinstance(identity, Identity) assert isinstance(identity, Identity)
input_dict = {"id": "bob.learn.activation.Activation.Linear", "C": 2.0} input_dict = {"id": "bob.learn.activation.Activation.Linear", "C": 2.0}
linear = Linear.from_dict(input_dict) linear = Activation.from_dict(input_dict)
assert linear.C == 2 assert linear.C == 2
input_dict = { input_dict = {
...@@ -431,6 +439,11 @@ def test_from_dict(): ...@@ -431,6 +439,11 @@ def test_from_dict():
"C": 2.0, "C": 2.0,
"M": 3.0, "M": 3.0,
} }
multiplied_hyperbolic_tangent = MultipliedHyperbolicTangent.from_dict(input_dict)
multiplied_hyperbolic_tangent = Activation.from_dict(input_dict)
assert multiplied_hyperbolic_tangent.C == 2.0 assert multiplied_hyperbolic_tangent.C == 2.0
assert multiplied_hyperbolic_tangent.M == 3.0 assert multiplied_hyperbolic_tangent.M == 3.0
with assert_raises(ValueError):
input_dict = {"id": "bob.learn.activation.Activation.Wrong"}
Activation.from_dict(input_dict)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment