From b0ef4f7488db7468074c1b85e3998c403e394ae6 Mon Sep 17 00:00:00 2001 From: Tiago Freitas Pereira <tiagofrepereira@gmail.com> Date: Mon, 12 Aug 2019 16:07:05 +0200 Subject: [PATCH] Pushed some fixes in the method --- bob/learn/activation/activation.py | 17 +++++++++-------- bob/learn/activation/test.py | 13 +++++++++++++ 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/bob/learn/activation/activation.py b/bob/learn/activation/activation.py index 9a9b6ff..9fc8df3 100644 --- a/bob/learn/activation/activation.py +++ b/bob/learn/activation/activation.py @@ -19,11 +19,12 @@ class Activation(_Activation_C): """ return {"id": self.unique_identifier()} - def from_dict(input_dict): + @classmethod + def from_dict(cls, input_dict): """ Loads itself from a python dict :py:class:`dict` """ - pass + return cls() class Logistic(_Logistic_C, Activation): @@ -42,15 +43,15 @@ class Linear(_Linear_C, Activation): def to_dict(self): return {"id": self.unique_identifier(), "C": self.C} - @staticmethod - def from_dict(input_dict): + @classmethod + def from_dict(cls, input_dict): """ Loads itself from a python dict :py:class:`dict` """ if "C" in input_dict: C = float(input_dict["C"]) - return Linear(C=C) + return cls(C=C) else: raise ValueError("Missing parameter `C` in `input_dict`") @@ -59,8 +60,8 @@ class MultipliedHyperbolicTangent(_MultipliedHyperbolicTangent_C, Activation): def to_dict(self): return {"id": self.unique_identifier(), "C": self.C, "M": self.M} - @staticmethod - def from_dict(input_dict): + @classmethod + def from_dict(cls, input_dict): """ Loads itself from a python dict :py:class:`dict` """ @@ -75,4 +76,4 @@ class MultipliedHyperbolicTangent(_MultipliedHyperbolicTangent_C, Activation): else: raise ValueError("Missing parameter `M` in `input_dict`") - return MultipliedHyperbolicTangent(C=C, M=M) + return cls(C=C, M=M) diff --git a/bob/learn/activation/test.py b/bob/learn/activation/test.py index 960ec4c..7143520 100644 --- a/bob/learn/activation/test.py +++ b/bob/learn/activation/test.py @@ -409,6 +409,19 @@ def test_to_dict(): def test_from_dict(): + # The first 3 tests don't make much sense, but I'm testing them anyways + input_dict = {"id": "bob.learn.activation.Activation.Logistic"} + logistic = Logistic.from_dict(input_dict) + assert isinstance(logistic, Logistic) + + input_dict = {"id": "bob.learn.activation.Activation.HyperbolicTangent"} + hyperbolic_tangent = HyperbolicTangent.from_dict(input_dict) + assert isinstance(hyperbolic_tangent, HyperbolicTangent) + + input_dict = {"id": "bob.learn.activation.Activation.Identity"} + identity = Identity.from_dict(input_dict) + assert isinstance(identity, Identity) + input_dict = {"id": "bob.learn.activation.Activation.Linear", "C": 2.0} linear = Linear.from_dict(input_dict) assert linear.C == 2 -- GitLab