From 057e58d49ef8663b5ae52b08a67d8359957a144e Mon Sep 17 00:00:00 2001 From: Manuel Guenther <manuel.guenther@idiap.ch> Date: Tue, 19 Aug 2014 15:25:47 +0200 Subject: [PATCH] Updated unique names of activation functions; registered all activation functions. --- .../activation/cpp/ActivationRegistry.cpp | 41 +++++++++++++++++++ .../include/bob.learn.activation/Activation.h | 10 ++--- 2 files changed, 46 insertions(+), 5 deletions(-) diff --git a/bob/learn/activation/cpp/ActivationRegistry.cpp b/bob/learn/activation/cpp/ActivationRegistry.cpp index 7a650fe..275524e 100644 --- a/bob/learn/activation/cpp/ActivationRegistry.cpp +++ b/bob/learn/activation/cpp/ActivationRegistry.cpp @@ -9,6 +9,7 @@ #include <bob.learn.activation/Activation.h> #include <boost/make_shared.hpp> +#include <bob.core/logging.h> boost::shared_ptr<bob::learn::activation::ActivationRegistry> bob::learn::activation::ActivationRegistry::instance() { static boost::shared_ptr<bob::learn::activation::ActivationRegistry> s_instance(new ActivationRegistry()); @@ -69,6 +70,20 @@ bob::learn::activation::activation_factory_t bob::learn::activation::ActivationR auto it = s_id2factory.find(id); + if (it == s_id2factory.end()) { + // try to convert the old "machine" name into the new "learn.activation" name + auto i = id.find("machine"); + if (i != std::string::npos){ + std::string tid = id; + tid.replace(i, 7, "learn.activation"); + + it = s_id2factory.find(tid); + if (it != s_id2factory.end()) { + bob::core::warn << "Using the old name of the activation function '" << id << "' is deprecated. Please use '" << tid << "' instead!"; + } + } + } + if (it == s_id2factory.end()) { boost::format m("unregistered activation function: %s"); m % id; @@ -79,3 +94,29 @@ bob::learn::activation::activation_factory_t bob::learn::activation::ActivationR } +/** + * A generalized registration mechanism for all classes above + */ +template <typename T> struct register_activation { + + static boost::shared_ptr<bob::learn::activation::Activation> factory (bob::io::base::HDF5File& f) { + auto retval = boost::make_shared<T>(); + retval->load(f); + return retval; + } + + register_activation() { + T obj; + bob::learn::activation::ActivationRegistry::instance()->registerActivation(obj.unique_identifier(), register_activation<T>::factory); + } + +}; + +// register all extensions +static register_activation<bob::learn::activation::IdentityActivation> _identity_act_reg; +static register_activation<bob::learn::activation::LinearActivation> _linear_act_reg; +static register_activation<bob::learn::activation::HyperbolicTangentActivation> _tanh_act_reg; +static register_activation<bob::learn::activation::MultipliedHyperbolicTangentActivation> _multanh_act_reg; +static register_activation<bob::learn::activation::LogisticActivation> _logistic_act_reg; + + diff --git a/bob/learn/activation/include/bob.learn.activation/Activation.h b/bob/learn/activation/include/bob.learn.activation/Activation.h index 53142e7..643a1ce 100644 --- a/bob/learn/activation/include/bob.learn.activation/Activation.h +++ b/bob/learn/activation/include/bob.learn.activation/Activation.h @@ -92,7 +92,7 @@ namespace bob { namespace learn { namespace activation { virtual double f (double z) const { return z; } virtual double f_prime (double z) const { return 1.; } virtual double f_prime_from_f (double a) const { return 1.; } - virtual std::string unique_identifier() const { return "bob.machine.Activation.Identity"; } + virtual std::string unique_identifier() const { return "bob.learn.activation.Activation.Identity"; } virtual std::string str() const { return "f(z) = z"; } }; @@ -112,7 +112,7 @@ namespace bob { namespace learn { namespace activation { double C() const { return m_C; } virtual void save(bob::io::base::HDF5File& f) const { Activation::save(f); f.set("C", m_C); } virtual void load(bob::io::base::HDF5File& f) { m_C = f.read<double>("C"); } - virtual std::string unique_identifier() const { return "bob.machine.Activation.Linear"; } + virtual std::string unique_identifier() const { return "bob.learn.activation.Activation.Linear"; } virtual std::string str() const { return (boost::format("f(z) = %.5e * z") % m_C).str(); } private: // representation @@ -131,7 +131,7 @@ namespace bob { namespace learn { namespace activation { virtual ~HyperbolicTangentActivation() {} virtual double f (double z) const { return std::tanh(z); } virtual double f_prime_from_f (double a) const { return (1. - (a*a)); } - virtual std::string unique_identifier() const { return "bob.machine.Activation.HyperbolicTangent"; } + virtual std::string unique_identifier() const { return "bob.learn.activation.Activation.HyperbolicTangent"; } virtual std::string str() const { return "f(z) = tanh(z)"; } }; @@ -151,7 +151,7 @@ namespace bob { namespace learn { namespace activation { double M() const { return m_M; } virtual void save(bob::io::base::HDF5File& f) const { Activation::save(f); f.set("C", m_C); f.set("M", m_C); } virtual void load(bob::io::base::HDF5File& f) {m_C = f.read<double>("C"); m_M = f.read<double>("M"); } - virtual std::string unique_identifier() const { return "bob.machine.Activation.MultipliedHyperbolicTangent"; } + virtual std::string unique_identifier() const { return "bob.learn.activation.Activation.MultipliedHyperbolicTangent"; } virtual std::string str() const { return (boost::format("f(z) = %.5e * tanh(%.5e * z)") % m_C % m_M).str(); } private: // representation @@ -171,7 +171,7 @@ namespace bob { namespace learn { namespace activation { virtual ~LogisticActivation() {} virtual double f (double z) const { return 1. / ( 1. + std::exp(-z) ); } virtual double f_prime_from_f (double a) const { return a * (1. - a); } - virtual std::string unique_identifier() const { return "bob.machine.Activation.Logistic"; } + virtual std::string unique_identifier() const { return "bob.learn.activation.Activation.Logistic"; } virtual std::string str() const { return "f(z) = 1./(1. + e^-z)"; } }; -- GitLab