Skip to content
Snippets Groups Projects
Commit 057e58d4 authored by Manuel Günther's avatar Manuel Günther
Browse files

Updated unique names of activation functions; registered all activation functions.

parent 653163bb
Branches
Tags
No related merge requests found
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <bob.learn.activation/Activation.h> #include <bob.learn.activation/Activation.h>
#include <boost/make_shared.hpp> #include <boost/make_shared.hpp>
#include <bob.core/logging.h>
boost::shared_ptr<bob::learn::activation::ActivationRegistry> bob::learn::activation::ActivationRegistry::instance() { boost::shared_ptr<bob::learn::activation::ActivationRegistry> bob::learn::activation::ActivationRegistry::instance() {
static boost::shared_ptr<bob::learn::activation::ActivationRegistry> s_instance(new ActivationRegistry()); static boost::shared_ptr<bob::learn::activation::ActivationRegistry> s_instance(new ActivationRegistry());
...@@ -69,6 +70,20 @@ bob::learn::activation::activation_factory_t bob::learn::activation::ActivationR ...@@ -69,6 +70,20 @@ bob::learn::activation::activation_factory_t bob::learn::activation::ActivationR
auto it = s_id2factory.find(id); auto it = s_id2factory.find(id);
if (it == s_id2factory.end()) {
// try to convert the old "machine" name into the new "learn.activation" name
auto i = id.find("machine");
if (i != std::string::npos){
std::string tid = id;
tid.replace(i, 7, "learn.activation");
it = s_id2factory.find(tid);
if (it != s_id2factory.end()) {
bob::core::warn << "Using the old name of the activation function '" << id << "' is deprecated. Please use '" << tid << "' instead!";
}
}
}
if (it == s_id2factory.end()) { if (it == s_id2factory.end()) {
boost::format m("unregistered activation function: %s"); boost::format m("unregistered activation function: %s");
m % id; m % id;
...@@ -79,3 +94,29 @@ bob::learn::activation::activation_factory_t bob::learn::activation::ActivationR ...@@ -79,3 +94,29 @@ bob::learn::activation::activation_factory_t bob::learn::activation::ActivationR
} }
/**
* A generalized registration mechanism for all classes above
*/
template <typename T> struct register_activation {
static boost::shared_ptr<bob::learn::activation::Activation> factory (bob::io::base::HDF5File& f) {
auto retval = boost::make_shared<T>();
retval->load(f);
return retval;
}
register_activation() {
T obj;
bob::learn::activation::ActivationRegistry::instance()->registerActivation(obj.unique_identifier(), register_activation<T>::factory);
}
};
// register all extensions
static register_activation<bob::learn::activation::IdentityActivation> _identity_act_reg;
static register_activation<bob::learn::activation::LinearActivation> _linear_act_reg;
static register_activation<bob::learn::activation::HyperbolicTangentActivation> _tanh_act_reg;
static register_activation<bob::learn::activation::MultipliedHyperbolicTangentActivation> _multanh_act_reg;
static register_activation<bob::learn::activation::LogisticActivation> _logistic_act_reg;
...@@ -92,7 +92,7 @@ namespace bob { namespace learn { namespace activation { ...@@ -92,7 +92,7 @@ namespace bob { namespace learn { namespace activation {
virtual double f (double z) const { return z; } virtual double f (double z) const { return z; }
virtual double f_prime (double z) const { return 1.; } virtual double f_prime (double z) const { return 1.; }
virtual double f_prime_from_f (double a) const { return 1.; } virtual double f_prime_from_f (double a) const { return 1.; }
virtual std::string unique_identifier() const { return "bob.machine.Activation.Identity"; } virtual std::string unique_identifier() const { return "bob.learn.activation.Activation.Identity"; }
virtual std::string str() const { return "f(z) = z"; } virtual std::string str() const { return "f(z) = z"; }
}; };
...@@ -112,7 +112,7 @@ namespace bob { namespace learn { namespace activation { ...@@ -112,7 +112,7 @@ namespace bob { namespace learn { namespace activation {
double C() const { return m_C; } double C() const { return m_C; }
virtual void save(bob::io::base::HDF5File& f) const { Activation::save(f); f.set("C", m_C); } virtual void save(bob::io::base::HDF5File& f) const { Activation::save(f); f.set("C", m_C); }
virtual void load(bob::io::base::HDF5File& f) { m_C = f.read<double>("C"); } virtual void load(bob::io::base::HDF5File& f) { m_C = f.read<double>("C"); }
virtual std::string unique_identifier() const { return "bob.machine.Activation.Linear"; } virtual std::string unique_identifier() const { return "bob.learn.activation.Activation.Linear"; }
virtual std::string str() const { return (boost::format("f(z) = %.5e * z") % m_C).str(); } virtual std::string str() const { return (boost::format("f(z) = %.5e * z") % m_C).str(); }
private: // representation private: // representation
...@@ -131,7 +131,7 @@ namespace bob { namespace learn { namespace activation { ...@@ -131,7 +131,7 @@ namespace bob { namespace learn { namespace activation {
virtual ~HyperbolicTangentActivation() {} virtual ~HyperbolicTangentActivation() {}
virtual double f (double z) const { return std::tanh(z); } virtual double f (double z) const { return std::tanh(z); }
virtual double f_prime_from_f (double a) const { return (1. - (a*a)); } virtual double f_prime_from_f (double a) const { return (1. - (a*a)); }
virtual std::string unique_identifier() const { return "bob.machine.Activation.HyperbolicTangent"; } virtual std::string unique_identifier() const { return "bob.learn.activation.Activation.HyperbolicTangent"; }
virtual std::string str() const { return "f(z) = tanh(z)"; } virtual std::string str() const { return "f(z) = tanh(z)"; }
}; };
...@@ -151,7 +151,7 @@ namespace bob { namespace learn { namespace activation { ...@@ -151,7 +151,7 @@ namespace bob { namespace learn { namespace activation {
double M() const { return m_M; } double M() const { return m_M; }
virtual void save(bob::io::base::HDF5File& f) const { Activation::save(f); f.set("C", m_C); f.set("M", m_C); } virtual void save(bob::io::base::HDF5File& f) const { Activation::save(f); f.set("C", m_C); f.set("M", m_C); }
virtual void load(bob::io::base::HDF5File& f) {m_C = f.read<double>("C"); m_M = f.read<double>("M"); } virtual void load(bob::io::base::HDF5File& f) {m_C = f.read<double>("C"); m_M = f.read<double>("M"); }
virtual std::string unique_identifier() const { return "bob.machine.Activation.MultipliedHyperbolicTangent"; } virtual std::string unique_identifier() const { return "bob.learn.activation.Activation.MultipliedHyperbolicTangent"; }
virtual std::string str() const { return (boost::format("f(z) = %.5e * tanh(%.5e * z)") % m_C % m_M).str(); } virtual std::string str() const { return (boost::format("f(z) = %.5e * tanh(%.5e * z)") % m_C % m_M).str(); }
private: // representation private: // representation
...@@ -171,7 +171,7 @@ namespace bob { namespace learn { namespace activation { ...@@ -171,7 +171,7 @@ namespace bob { namespace learn { namespace activation {
virtual ~LogisticActivation() {} virtual ~LogisticActivation() {}
virtual double f (double z) const { return 1. / ( 1. + std::exp(-z) ); } virtual double f (double z) const { return 1. / ( 1. + std::exp(-z) ); }
virtual double f_prime_from_f (double a) const { return a * (1. - a); } virtual double f_prime_from_f (double a) const { return a * (1. - a); }
virtual std::string unique_identifier() const { return "bob.machine.Activation.Logistic"; } virtual std::string unique_identifier() const { return "bob.learn.activation.Activation.Logistic"; }
virtual std::string str() const { return "f(z) = 1./(1. + e^-z)"; } virtual std::string str() const { return "f(z) = 1./(1. + e^-z)"; }
}; };
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment