Skip to content
Snippets Groups Projects
Commit b0ef4f74 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Pushed some fixes in the method

parent 6404e6c7
Branches
No related tags found
1 merge request!6WIP: First attempt to approach the issue bob.bio.base#106
Pipeline #32445 passed
......@@ -19,11 +19,12 @@ class Activation(_Activation_C):
"""
return {"id": self.unique_identifier()}
def from_dict(input_dict):
@classmethod
def from_dict(cls, input_dict):
"""
Loads itself from a python dict :py:class:`dict`
"""
pass
return cls()
class Logistic(_Logistic_C, Activation):
......@@ -42,15 +43,15 @@ class Linear(_Linear_C, Activation):
def to_dict(self):
return {"id": self.unique_identifier(), "C": self.C}
@staticmethod
def from_dict(input_dict):
@classmethod
def from_dict(cls, input_dict):
"""
Loads itself from a python dict :py:class:`dict`
"""
if "C" in input_dict:
C = float(input_dict["C"])
return Linear(C=C)
return cls(C=C)
else:
raise ValueError("Missing parameter `C` in `input_dict`")
......@@ -59,8 +60,8 @@ class MultipliedHyperbolicTangent(_MultipliedHyperbolicTangent_C, Activation):
def to_dict(self):
return {"id": self.unique_identifier(), "C": self.C, "M": self.M}
@staticmethod
def from_dict(input_dict):
@classmethod
def from_dict(cls, input_dict):
"""
Loads itself from a python dict :py:class:`dict`
"""
......@@ -75,4 +76,4 @@ class MultipliedHyperbolicTangent(_MultipliedHyperbolicTangent_C, Activation):
else:
raise ValueError("Missing parameter `M` in `input_dict`")
return MultipliedHyperbolicTangent(C=C, M=M)
return cls(C=C, M=M)
......@@ -409,6 +409,19 @@ def test_to_dict():
def test_from_dict():
# The first 3 tests don't make much sense, but I'm testing them anyways
input_dict = {"id": "bob.learn.activation.Activation.Logistic"}
logistic = Logistic.from_dict(input_dict)
assert isinstance(logistic, Logistic)
input_dict = {"id": "bob.learn.activation.Activation.HyperbolicTangent"}
hyperbolic_tangent = HyperbolicTangent.from_dict(input_dict)
assert isinstance(hyperbolic_tangent, HyperbolicTangent)
input_dict = {"id": "bob.learn.activation.Activation.Identity"}
identity = Identity.from_dict(input_dict)
assert isinstance(identity, Identity)
input_dict = {"id": "bob.learn.activation.Activation.Linear", "C": 2.0}
linear = Linear.from_dict(input_dict)
assert linear.C == 2
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment