Pushed some fixes in the method

parent 6404e6c7
Pipeline #32445 passed with stage
in 19 minutes and 52 seconds
...@@ -19,11 +19,12 @@ class Activation(_Activation_C): ...@@ -19,11 +19,12 @@ class Activation(_Activation_C):
""" """
return {"id": self.unique_identifier()} return {"id": self.unique_identifier()}
def from_dict(input_dict): @classmethod
def from_dict(cls, input_dict):
""" """
Loads itself from a python dict :py:class:`dict` Loads itself from a python dict :py:class:`dict`
""" """
pass return cls()
class Logistic(_Logistic_C, Activation): class Logistic(_Logistic_C, Activation):
...@@ -42,15 +43,15 @@ class Linear(_Linear_C, Activation): ...@@ -42,15 +43,15 @@ class Linear(_Linear_C, Activation):
def to_dict(self): def to_dict(self):
return {"id": self.unique_identifier(), "C": self.C} return {"id": self.unique_identifier(), "C": self.C}
@staticmethod @classmethod
def from_dict(input_dict): def from_dict(cls, input_dict):
""" """
Loads itself from a python dict :py:class:`dict` Loads itself from a python dict :py:class:`dict`
""" """
if "C" in input_dict: if "C" in input_dict:
C = float(input_dict["C"]) C = float(input_dict["C"])
return Linear(C=C) return cls(C=C)
else: else:
raise ValueError("Missing parameter `C` in `input_dict`") raise ValueError("Missing parameter `C` in `input_dict`")
...@@ -59,8 +60,8 @@ class MultipliedHyperbolicTangent(_MultipliedHyperbolicTangent_C, Activation): ...@@ -59,8 +60,8 @@ class MultipliedHyperbolicTangent(_MultipliedHyperbolicTangent_C, Activation):
def to_dict(self): def to_dict(self):
return {"id": self.unique_identifier(), "C": self.C, "M": self.M} return {"id": self.unique_identifier(), "C": self.C, "M": self.M}
@staticmethod @classmethod
def from_dict(input_dict): def from_dict(cls, input_dict):
""" """
Loads itself from a python dict :py:class:`dict` Loads itself from a python dict :py:class:`dict`
""" """
...@@ -75,4 +76,4 @@ class MultipliedHyperbolicTangent(_MultipliedHyperbolicTangent_C, Activation): ...@@ -75,4 +76,4 @@ class MultipliedHyperbolicTangent(_MultipliedHyperbolicTangent_C, Activation):
else: else:
raise ValueError("Missing parameter `M` in `input_dict`") raise ValueError("Missing parameter `M` in `input_dict`")
return MultipliedHyperbolicTangent(C=C, M=M) return cls(C=C, M=M)
...@@ -409,6 +409,19 @@ def test_to_dict(): ...@@ -409,6 +409,19 @@ def test_to_dict():
def test_from_dict(): def test_from_dict():
# The first 3 tests don't make much sense, but I'm testing them anyways
input_dict = {"id": "bob.learn.activation.Activation.Logistic"}
logistic = Logistic.from_dict(input_dict)
assert isinstance(logistic, Logistic)
input_dict = {"id": "bob.learn.activation.Activation.HyperbolicTangent"}
hyperbolic_tangent = HyperbolicTangent.from_dict(input_dict)
assert isinstance(hyperbolic_tangent, HyperbolicTangent)
input_dict = {"id": "bob.learn.activation.Activation.Identity"}
identity = Identity.from_dict(input_dict)
assert isinstance(identity, Identity)
input_dict = {"id": "bob.learn.activation.Activation.Linear", "C": 2.0} input_dict = {"id": "bob.learn.activation.Activation.Linear", "C": 2.0}
linear = Linear.from_dict(input_dict) linear = Linear.from_dict(input_dict)
assert linear.C == 2 assert linear.C == 2
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment