Updated from_dict function

parent b0ef4f74
Pipeline #32523 passed with stage
in 6 minutes and 8 seconds
......@@ -24,7 +24,37 @@ class Activation(_Activation_C):
"""
Loads itself from a python dict :py:class:`dict`
"""
return cls()
# TODO: Work it out a clean way to do this search between deprecated and new ids
# 0 (linear), 1 (tanh) or 2 (logistic)
## Deprecated loader
if input_dict["id"] == 0:
return Linear.from_dict(input_dict)
if input_dict["id"] == 1:
return HyperbolicTangent()
if input_dict["id"] == 2:
return Logistic()
# Newer loader
if "Activation.Linear" in input_dict["id"]:
return Linear.from_dict(input_dict)
if "Activation.HyperbolicTangent" in input_dict["id"]:
return HyperbolicTangent()
if "Activation.Logistic" in input_dict["id"]:
return Logistic()
if "Activation.Identity" in input_dict["id"]:
return Identity()
if "Activation.MultipliedHyperbolicTangent" in input_dict["id"]:
return MultipliedHyperbolicTangent.from_dict(input_dict)
raise ValueError(f"Value {input_dict} is invalid for `Activation`")
class Logistic(_Logistic_C, Activation):
......
......@@ -9,7 +9,15 @@
import numpy
import math
from . import Identity, Linear, Logistic, HyperbolicTangent, MultipliedHyperbolicTangent
from . import (
Identity,
Linear,
Logistic,
HyperbolicTangent,
MultipliedHyperbolicTangent,
Activation,
)
from nose.tools import assert_raises
def estimate_gradient(f, x, epsilon=1e-4, args=()):
......@@ -411,19 +419,19 @@ def test_from_dict():
# The first 3 tests don't make much sense, but I'm testing them anyways
input_dict = {"id": "bob.learn.activation.Activation.Logistic"}
logistic = Logistic.from_dict(input_dict)
logistic = Activation.from_dict(input_dict)
assert isinstance(logistic, Logistic)
input_dict = {"id": "bob.learn.activation.Activation.HyperbolicTangent"}
hyperbolic_tangent = HyperbolicTangent.from_dict(input_dict)
hyperbolic_tangent = Activation.from_dict(input_dict)
assert isinstance(hyperbolic_tangent, HyperbolicTangent)
input_dict = {"id": "bob.learn.activation.Activation.Identity"}
identity = Identity.from_dict(input_dict)
identity = Activation.from_dict(input_dict)
assert isinstance(identity, Identity)
input_dict = {"id": "bob.learn.activation.Activation.Linear", "C": 2.0}
linear = Linear.from_dict(input_dict)
linear = Activation.from_dict(input_dict)
assert linear.C == 2
input_dict = {
......@@ -431,6 +439,11 @@ def test_from_dict():
"C": 2.0,
"M": 3.0,
}
multiplied_hyperbolic_tangent = MultipliedHyperbolicTangent.from_dict(input_dict)
multiplied_hyperbolic_tangent = Activation.from_dict(input_dict)
assert multiplied_hyperbolic_tangent.C == 2.0
assert multiplied_hyperbolic_tangent.M == 3.0
with assert_raises(ValueError):
input_dict = {"id": "bob.learn.activation.Activation.Wrong"}
Activation.from_dict(input_dict)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment