diff --git a/bob/learn/activation/cpp/ActivationRegistry.cpp b/bob/learn/activation/cpp/ActivationRegistry.cpp index 5fcbe6e91658fdd6465a68e59d4d4f0e99a1296d..7a650fe642075067a6dd1b389aa37939aa2b2fd7 100644 --- a/bob/learn/activation/cpp/ActivationRegistry.cpp +++ b/bob/learn/activation/cpp/ActivationRegistry.cpp @@ -8,12 +8,35 @@ */ #include <bob.learn.activation/Activation.h> +#include <boost/make_shared.hpp> boost::shared_ptr<bob::learn::activation::ActivationRegistry> bob::learn::activation::ActivationRegistry::instance() { static boost::shared_ptr<bob::learn::activation::ActivationRegistry> s_instance(new ActivationRegistry()); return s_instance; } +boost::shared_ptr<bob::learn::activation::Activation> bob::learn::activation::load_activation(bob::io::base::HDF5File& f) { + auto make = ActivationRegistry::instance()->find(f.read<std::string>("id")); + return make(f); +} + +boost::shared_ptr<bob::learn::activation::Activation> bob::learn::activation::make_deprecated_activation(uint32_t e) { + switch(e) { + case 0: + return boost::make_shared<bob::learn::activation::IdentityActivation>(); + break; + case 1: + return boost::make_shared<bob::learn::activation::HyperbolicTangentActivation>(); + break; + case 2: + return boost::make_shared<bob::learn::activation::LogisticActivation>(); + break; + default: + throw std::runtime_error("unsupported (deprecated) activation read from HDF5 file - not any of 0 (linear), 1 (tanh) or 2 (logistic)"); + } +} + + void bob::learn::activation::ActivationRegistry::deregisterFactory(const std::string& id) { s_id2factory.erase(id); } @@ -35,7 +58,6 @@ void bob::learn::activation::ActivationRegistry::registerActivation(const std::s //replacing with the same factory may be the result of multiple python //modules being loaded. } - } bool bob::learn::activation::ActivationRegistry::isRegistered(const std::string& id) {