From 057e58d49ef8663b5ae52b08a67d8359957a144e Mon Sep 17 00:00:00 2001
From: Manuel Guenther <manuel.guenther@idiap.ch>
Date: Tue, 19 Aug 2014 15:25:47 +0200
Subject: [PATCH] Updated unique names of activation functions; registered all
 activation functions.

---
 .../activation/cpp/ActivationRegistry.cpp     | 41 +++++++++++++++++++
 .../include/bob.learn.activation/Activation.h | 10 ++---
 2 files changed, 46 insertions(+), 5 deletions(-)

diff --git a/bob/learn/activation/cpp/ActivationRegistry.cpp b/bob/learn/activation/cpp/ActivationRegistry.cpp
index 7a650fe..275524e 100644
--- a/bob/learn/activation/cpp/ActivationRegistry.cpp
+++ b/bob/learn/activation/cpp/ActivationRegistry.cpp
@@ -9,6 +9,7 @@
 
 #include <bob.learn.activation/Activation.h>
 #include <boost/make_shared.hpp>
+#include <bob.core/logging.h>
 
 boost::shared_ptr<bob::learn::activation::ActivationRegistry> bob::learn::activation::ActivationRegistry::instance() {
   static boost::shared_ptr<bob::learn::activation::ActivationRegistry> s_instance(new ActivationRegistry());
@@ -69,6 +70,20 @@ bob::learn::activation::activation_factory_t bob::learn::activation::ActivationR
 
   auto it = s_id2factory.find(id);
 
+  if (it == s_id2factory.end()) {
+    // try to convert the old "machine" name into the new "learn.activation" name
+    auto i = id.find("machine");
+    if (i != std::string::npos){
+      std::string tid = id;
+      tid.replace(i, 7, "learn.activation");
+
+      it = s_id2factory.find(tid);
+      if (it != s_id2factory.end()) {
+        bob::core::warn << "Using the old name of the activation function '" << id << "' is deprecated. Please use '" << tid << "' instead!";
+      }
+    }
+  }
+
   if (it == s_id2factory.end()) {
     boost::format m("unregistered activation function: %s");
     m % id;
@@ -79,3 +94,29 @@ bob::learn::activation::activation_factory_t bob::learn::activation::ActivationR
 
 }
 
+/**
+ * A generalized registration mechanism for all classes above
+ */
+template <typename T> struct register_activation {
+
+  static boost::shared_ptr<bob::learn::activation::Activation> factory (bob::io::base::HDF5File& f) {
+    auto retval = boost::make_shared<T>();
+    retval->load(f);
+    return retval;
+  }
+
+  register_activation() {
+    T obj;
+    bob::learn::activation::ActivationRegistry::instance()->registerActivation(obj.unique_identifier(), register_activation<T>::factory);
+  }
+
+};
+
+// register all extensions
+static register_activation<bob::learn::activation::IdentityActivation> _identity_act_reg;
+static register_activation<bob::learn::activation::LinearActivation> _linear_act_reg;
+static register_activation<bob::learn::activation::HyperbolicTangentActivation> _tanh_act_reg;
+static register_activation<bob::learn::activation::MultipliedHyperbolicTangentActivation> _multanh_act_reg;
+static register_activation<bob::learn::activation::LogisticActivation> _logistic_act_reg;
+
+
diff --git a/bob/learn/activation/include/bob.learn.activation/Activation.h b/bob/learn/activation/include/bob.learn.activation/Activation.h
index 53142e7..643a1ce 100644
--- a/bob/learn/activation/include/bob.learn.activation/Activation.h
+++ b/bob/learn/activation/include/bob.learn.activation/Activation.h
@@ -92,7 +92,7 @@ namespace bob { namespace learn { namespace activation {
       virtual double f (double z) const { return z; }
       virtual double f_prime (double z) const { return 1.; }
       virtual double f_prime_from_f (double a) const { return 1.; }
-      virtual std::string unique_identifier() const { return "bob.machine.Activation.Identity"; }
+      virtual std::string unique_identifier() const { return "bob.learn.activation.Activation.Identity"; }
       virtual std::string str() const { return "f(z) = z"; }
 
   };
@@ -112,7 +112,7 @@ namespace bob { namespace learn { namespace activation {
       double C() const { return m_C; }
       virtual void save(bob::io::base::HDF5File& f) const { Activation::save(f); f.set("C", m_C); }
       virtual void load(bob::io::base::HDF5File& f) { m_C = f.read<double>("C"); }
-      virtual std::string unique_identifier() const { return "bob.machine.Activation.Linear"; }
+      virtual std::string unique_identifier() const { return "bob.learn.activation.Activation.Linear"; }
       virtual std::string str() const { return (boost::format("f(z) = %.5e * z") % m_C).str(); }
 
     private: // representation
@@ -131,7 +131,7 @@ namespace bob { namespace learn { namespace activation {
       virtual ~HyperbolicTangentActivation() {}
       virtual double f (double z) const { return std::tanh(z); }
       virtual double f_prime_from_f (double a) const { return (1. - (a*a)); }
-      virtual std::string unique_identifier() const { return "bob.machine.Activation.HyperbolicTangent"; }
+      virtual std::string unique_identifier() const { return "bob.learn.activation.Activation.HyperbolicTangent"; }
       virtual std::string str() const { return "f(z) = tanh(z)"; }
 
   };
@@ -151,7 +151,7 @@ namespace bob { namespace learn { namespace activation {
       double M() const { return m_M; }
       virtual void save(bob::io::base::HDF5File& f) const { Activation::save(f); f.set("C", m_C); f.set("M", m_C); }
       virtual void load(bob::io::base::HDF5File& f) {m_C = f.read<double>("C"); m_M = f.read<double>("M"); }
-      virtual std::string unique_identifier() const { return "bob.machine.Activation.MultipliedHyperbolicTangent"; }
+      virtual std::string unique_identifier() const { return "bob.learn.activation.Activation.MultipliedHyperbolicTangent"; }
       virtual std::string str() const { return (boost::format("f(z) = %.5e * tanh(%.5e * z)") % m_C % m_M).str(); }
 
     private: // representation
@@ -171,7 +171,7 @@ namespace bob { namespace learn { namespace activation {
       virtual ~LogisticActivation() {}
       virtual double f (double z) const { return 1. / ( 1. + std::exp(-z) ); }
       virtual double f_prime_from_f (double a) const { return a * (1. - a); }
-      virtual std::string unique_identifier() const { return "bob.machine.Activation.Logistic"; }
+      virtual std::string unique_identifier() const { return "bob.learn.activation.Activation.Logistic"; }
       virtual std::string str() const { return "f(z) = 1./(1. + e^-z)"; }
 
   };
-- 
GitLab