From 9331cef2de5c397c7394156a39295eea2f8e8bc3 Mon Sep 17 00:00:00 2001 From: Tiago Freitas Pereira <tiagofrepereira@gmail.com> Date: Wed, 28 Jan 2015 11:12:52 +0100 Subject: [PATCH] Binded JFABase --- .../misc/include/bob.learn.misc/JFAMachine.h | 924 +----------------- bob/learn/misc/jfa_base.cpp | 568 +++++++++++ bob/learn/misc/main.cpp | 1 + bob/learn/misc/main.h | 13 + 4 files changed, 596 insertions(+), 910 deletions(-) create mode 100644 bob/learn/misc/jfa_base.cpp diff --git a/bob/learn/misc/include/bob.learn.misc/JFAMachine.h b/bob/learn/misc/include/bob.learn.misc/JFAMachine.h index a1fc17c..48d1f4a 100644 --- a/bob/learn/misc/include/bob.learn.misc/JFAMachine.h +++ b/bob/learn/misc/include/bob.learn.misc/JFAMachine.h @@ -1,18 +1,19 @@ /** - * @date Sat Jul 23 21:41:15 2011 +0200 + * @date Tue Jan 27 16:47:00 2015 +0200 * @author Laurent El Shafey <Laurent.El-Shafey@idiap.ch> + * @author Tiago de Freitas Pereira <tiago.pereira@idiap.ch> * * @brief A base class for Joint Factor Analysis-like machines * * Copyright (C) Idiap Research Institute, Martigny, Switzerland */ -#ifndef BOB_LEARN_MISC_FABASE_H -#define BOB_LEARN_MISC_FABASE_H +#ifndef BOB_LEARN_MISC_JFAMACHINE_H +#define BOB_LEARN_MISC_JFAMACHINE_H #include <stdexcept> -#include <bob.learn.misc/Machine.h> +#include <bob.learn.misc/JFABase.h> #include <bob.learn.misc/GMMMachine.h> #include <bob.learn.misc/LinearScoring.h> @@ -21,701 +22,6 @@ namespace bob { namespace learn { namespace misc { -/** - * @brief A FA Base class which contains U, V and D matrices - * TODO: add a reference to the journal articles - */ -class FABase -{ - public: - /** - * @brief Default constructor. Builds an otherwise invalid 0 x 0 FABase - * The Universal Background Model and the matrices U, V and diag(d) are - * not initialized. - */ - FABase(); - - /** - * @brief Constructor. Builds a new FABase. - * The Universal Background Model and the matrices U, V and diag(d) are - * not initialized. - * - * @param ubm The Universal Background Model - * @param ru size of U (CD x ru) - * @param rv size of U (CD x rv) - * @warning ru and rv SHOULD BE >= 1. Just set U/V/D to zero if you want - * to ignore one subspace. This is the case for ISV. - */ - FABase(const boost::shared_ptr<bob::learn::misc::GMMMachine> ubm, const size_t ru=1, const size_t rv=1); - - /** - * @brief Copy constructor - */ - FABase(const FABase& other); - - /** - * @brief Just to virtualise the destructor - */ - virtual ~FABase(); - - /** - * @brief Assigns from a different JFA machine - */ - FABase& operator=(const FABase &other); - - /** - * @brief Equal to - */ - bool operator==(const FABase& b) const; - - /** - * @brief Not equal to - */ - bool operator!=(const FABase& b) const; - - /** - * @brief Similar to - */ - bool is_similar_to(const FABase& b, const double r_epsilon=1e-5, - const double a_epsilon=1e-8) const; - - /** - * @brief Returns the UBM - */ - const boost::shared_ptr<bob::learn::misc::GMMMachine> getUbm() const - { return m_ubm; } - - /** - * @brief Returns the U matrix - */ - const blitz::Array<double,2>& getU() const - { return m_U; } - - /** - * @brief Returns the V matrix - */ - const blitz::Array<double,2>& getV() const - { return m_V; } - - /** - * @brief Returns the diagonal matrix diag(d) (as a 1D vector) - */ - const blitz::Array<double,1>& getD() const - { return m_d; } - - /** - * @brief Returns the UBM mean supervector (as a 1D vector) - */ - const blitz::Array<double,1>& getUbmMean() const - { return m_cache_mean; } - - /** - * @brief Returns the UBM variance supervector (as a 1D vector) - */ - const blitz::Array<double,1>& getUbmVariance() const - { return m_cache_sigma; } - - /** - * @brief Returns the number of Gaussian components C - * @warning An exception is thrown if no Universal Background Model has - * been set yet. - */ - const size_t getDimC() const - { if(!m_ubm) throw std::runtime_error("No UBM was set in the JFA machine."); - return m_ubm->getNGaussians(); } - - /** - * @brief Returns the feature dimensionality D - * @warning An exception is thrown if no Universal Background Model has - * been set yet. - */ - const size_t getDimD() const - { if(!m_ubm) throw std::runtime_error("No UBM was set in the JFA machine."); - return m_ubm->getNInputs(); } - - /** - * @brief Returns the supervector length CD - * (CxD: Number of Gaussian components by the feature dimensionality) - * @warning An exception is thrown if no Universal Background Model has - * been set yet. - */ - const size_t getDimCD() const - { if(!m_ubm) throw std::runtime_error("No UBM was set in the JFA machine."); - return m_ubm->getNInputs()*m_ubm->getNGaussians(); } - - /** - * @brief Returns the size/rank ru of the U matrix - */ - const size_t getDimRu() const - { return m_ru; } - - /** - * @brief Returns the size/rank rv of the V matrix - */ - const size_t getDimRv() const - { return m_rv; } - - /** - * @brief Resets the dimensionality of the subspace U and V - * U and V are hence uninitialized. - */ - void resize(const size_t ru, const size_t rv); - - /** - * @brief Resets the dimensionality of the subspace U and V, - * assuming that no UBM has yet been set - * U and V are hence uninitialized. - */ - void resize(const size_t ru, const size_t rv, const size_t cd); - - /** - * @brief Returns the U matrix in order to update it - * @warning Should only be used by the trainer for efficiency reason, - * or for testing purpose. - */ - blitz::Array<double,2>& updateU() - { return m_U; } - - /** - * @brief Returns the V matrix in order to update it - * @warning Should only be used by the trainer for efficiency reason, - * or for testing purpose. - */ - blitz::Array<double,2>& updateV() - { return m_V; } - - /** - * @brief Returns the diagonal matrix diag(d) (as a 1D vector) in order - * to update it - * @warning Should only be used by the trainer for efficiency reason, - * or for testing purpose. - */ - blitz::Array<double,1>& updateD() - { return m_d; } - - - /** - * @brief Sets (the mean supervector of) the Universal Background Model - * U, V and d are uninitialized in case of dimensions update (C or D) - */ - void setUbm(const boost::shared_ptr<bob::learn::misc::GMMMachine> ubm); - - /** - * @brief Sets the U matrix - */ - void setU(const blitz::Array<double,2>& U); - - /** - * @brief Sets the V matrix - */ - void setV(const blitz::Array<double,2>& V); - - /** - * @brief Sets the diagonal matrix diag(d) - * (a 1D vector is expected as an argument) - */ - void setD(const blitz::Array<double,1>& d); - - - /** - * @brief Estimates x from the GMM statistics considering the LPT - * assumption, that is the latent session variable x is approximated - * using the UBM - */ - void estimateX(const bob::learn::misc::GMMStats& gmm_stats, blitz::Array<double,1>& x) const; - - /** - * @brief Compute and put U^{T}.Sigma^{-1} matrix in cache - * @warning Should only be used by the trainer for efficiency reason, - * or for testing purpose. - */ - void updateCacheUbmUVD(); - - - private: - /** - * @brief Update cache arrays/variables - */ - void updateCache(); - /** - * @brief Put GMM mean/variance supervector in cache - */ - void updateCacheUbm(); - /** - * @brief Resize working arrays - */ - void resizeTmp(); - /** - * @brief Computes (Id + U^T.Sigma^-1.U.N_{i,h}.U)^-1 = - * (Id + sum_{c=1..C} N_{i,h}.U_{c}^T.Sigma_{c}^-1.U_{c})^-1 - */ - void computeIdPlusUSProdInv(const bob::learn::misc::GMMStats& gmm_stats, - blitz::Array<double,2>& out) const; - /** - * @brief Computes Fn_x = sum_{sessions h}(N*(o - m)) - * (Normalised first order statistics) - */ - void computeFn_x(const bob::learn::misc::GMMStats& gmm_stats, - blitz::Array<double,1>& out) const; - /** - * @brief Estimates the value of x from the passed arguments - * (IdPlusUSProdInv and Fn_x), considering the LPT assumption - */ - void estimateX(const blitz::Array<double,2>& IdPlusUSProdInv, - const blitz::Array<double,1>& Fn_x, blitz::Array<double,1>& x) const; - - - // UBM - boost::shared_ptr<bob::learn::misc::GMMMachine> m_ubm; - - // dimensionality - size_t m_ru; // size of U (CD x ru) - size_t m_rv; // size of V (CD x rv) - - // U, V, D matrices - // D is assumed to be diagonal, and only the diagonal is stored - blitz::Array<double,2> m_U; - blitz::Array<double,2> m_V; - blitz::Array<double,1> m_d; - - // Vectors/Matrices precomputed in cache - blitz::Array<double,1> m_cache_mean; - blitz::Array<double,1> m_cache_sigma; - blitz::Array<double,2> m_cache_UtSigmaInv; - - mutable blitz::Array<double,2> m_tmp_IdPlusUSProdInv; - mutable blitz::Array<double,1> m_tmp_Fn_x; - mutable blitz::Array<double,1> m_tmp_ru; - mutable blitz::Array<double,2> m_tmp_ruD; - mutable blitz::Array<double,2> m_tmp_ruru; -}; - - -/** - * @brief A JFA Base class which contains U, V and D matrices - * TODO: add a reference to the journal articles - */ -class JFABase -{ - public: - /** - * @brief Default constructor. Builds a 1 x 1 JFABase - * The Universal Background Model and the matrices U, V and diag(d) are - * not initialized. - */ - JFABase(); - - /** - * @brief Constructor. Builds a new JFABase. - * The Universal Background Model and the matrices U, V and diag(d) are - * not initialized. - * - * @param ubm The Universal Background Model - * @param ru size of U (CD x ru) - * @param rv size of U (CD x rv) - * @warning ru and rv SHOULD BE >= 1. - */ - JFABase(const boost::shared_ptr<bob::learn::misc::GMMMachine> ubm, const size_t ru=1, const size_t rv=1); - - /** - * @brief Copy constructor - */ - JFABase(const JFABase& other); - - /** - * @deprecated Starts a new JFAMachine from an existing Configuration object. - */ - JFABase(bob::io::base::HDF5File& config); - - /** - * @brief Just to virtualise the destructor - */ - virtual ~JFABase(); - - /** - * @brief Assigns from a different JFA machine - */ - JFABase& operator=(const JFABase &other); - - /** - * @brief Equal to - */ - bool operator==(const JFABase& b) const - { return m_base.operator==(b.m_base); } - - /** - * @brief Not equal to - */ - bool operator!=(const JFABase& b) const - { return m_base.operator!=(b.m_base); } - - /** - * @brief Similar to - */ - bool is_similar_to(const JFABase& b, const double r_epsilon=1e-5, - const double a_epsilon=1e-8) const - { return m_base.is_similar_to(b.m_base, r_epsilon, a_epsilon); } - - /** - * @brief Saves model to an HDF5 file - */ - void save(bob::io::base::HDF5File& config) const; - - /** - * @brief Loads data from an existing configuration object. Resets - * the current state. - */ - void load(bob::io::base::HDF5File& config); - - /** - * @brief Returns the UBM - */ - const boost::shared_ptr<bob::learn::misc::GMMMachine> getUbm() const - { return m_base.getUbm(); } - - /** - * @brief Returns the U matrix - */ - const blitz::Array<double,2>& getU() const - { return m_base.getU(); } - - /** - * @brief Returns the V matrix - */ - const blitz::Array<double,2>& getV() const - { return m_base.getV(); } - - /** - * @brief Returns the diagonal matrix diag(d) (as a 1D vector) - */ - const blitz::Array<double,1>& getD() const - { return m_base.getD(); } - - /** - * @brief Returns the number of Gaussian components C - * @warning An exception is thrown if no Universal Background Model has - * been set yet. - */ - const size_t getDimC() const - { return m_base.getDimC(); } - - /** - * @brief Returns the feature dimensionality D - * @warning An exception is thrown if no Universal Background Model has - * been set yet. - */ - const size_t getDimD() const - { return m_base.getDimD(); } - - /** - * @brief Returns the supervector length CD - * (CxD: Number of Gaussian components by the feature dimensionality) - * @warning An exception is thrown if no Universal Background Model has - * been set yet. - */ - const size_t getDimCD() const - { return m_base.getDimCD(); } - - /** - * @brief Returns the size/rank ru of the U matrix - */ - const size_t getDimRu() const - { return m_base.getDimRu(); } - - /** - * @brief Returns the size/rank rv of the V matrix - */ - const size_t getDimRv() const - { return m_base.getDimRv(); } - - /** - * @brief Resets the dimensionality of the subspace U and V - * U and V are hence uninitialized. - */ - void resize(const size_t ru, const size_t rv) - { m_base.resize(ru, rv); } - - /** - * @brief Returns the U matrix in order to update it - * @warning Should only be used by the trainer for efficiency reason, - * or for testing purpose. - */ - blitz::Array<double,2>& updateU() - { return m_base.updateU(); } - - /** - * @brief Returns the V matrix in order to update it - * @warning Should only be used by the trainer for efficiency reason, - * or for testing purpose. - */ - blitz::Array<double,2>& updateV() - { return m_base.updateV(); } - - /** - * @brief Returns the diagonal matrix diag(d) (as a 1D vector) in order - * to update it - * @warning Should only be used by the trainer for efficiency reason, - * or for testing purpose. - */ - blitz::Array<double,1>& updateD() - { return m_base.updateD(); } - - - /** - * @brief Sets (the mean supervector of) the Universal Background Model - * U, V and d are uninitialized in case of dimensions update (C or D) - */ - void setUbm(const boost::shared_ptr<bob::learn::misc::GMMMachine> ubm) - { m_base.setUbm(ubm); } - - /** - * @brief Sets the U matrix - */ - void setU(const blitz::Array<double,2>& U) - { m_base.setU(U); } - - /** - * @brief Sets the V matrix - */ - void setV(const blitz::Array<double,2>& V) - { m_base.setV(V); } - - /** - * @brief Sets the diagonal matrix diag(d) - * (a 1D vector is expected as an argument) - */ - void setD(const blitz::Array<double,1>& d) - { m_base.setD(d); } - - /** - * @brief Estimates x from the GMM statistics considering the LPT - * assumption, that is the latent session variable x is approximated - * using the UBM - */ - void estimateX(const bob::learn::misc::GMMStats& gmm_stats, blitz::Array<double,1>& x) const - { m_base.estimateX(gmm_stats, x); } - - /** - * @brief Precompute (put U^{T}.Sigma^{-1} matrix in cache) - * @warning Should only be used by the trainer for efficiency reason, - * or for testing purpose. - */ - void precompute() - { m_base.updateCacheUbmUVD(); } - - /** - * @brief Returns the FABase member - */ - const bob::learn::misc::FABase& getBase() const - { return m_base; } - - - private: - // FABase - bob::learn::misc::FABase m_base; -}; - - -/** - * @brief An ISV Base class which contains U and D matrices - * TODO: add a reference to the journal articles - */ -class ISVBase -{ - public: - /** - * @brief Default constructor. Builds an otherwise invalid 0 x 0 ISVBase - * The Universal Background Model and the matrices U, V and diag(d) are - * not initialized. - */ - ISVBase(); - - /** - * @brief Constructor. Builds a new ISVBase. - * The Universal Background Model and the matrices U, V and diag(d) are - * not initialized. - * - * @param ubm The Universal Background Model - * @param ru size of U (CD x ru) - * @warning ru SHOULD BE >= 1. - */ - ISVBase(const boost::shared_ptr<bob::learn::misc::GMMMachine> ubm, const size_t ru=1); - - /** - * @brief Copy constructor - */ - ISVBase(const ISVBase& other); - - /** - * @deprecated Starts a new JFAMachine from an existing Configuration object. - */ - ISVBase(bob::io::base::HDF5File& config); - - /** - * @brief Just to virtualise the destructor - */ - virtual ~ISVBase(); - - /** - * @brief Assigns from a different JFA machine - */ - ISVBase& operator=(const ISVBase &other); - - /** - * @brief Equal to - */ - bool operator==(const ISVBase& b) const - { return m_base.operator==(b.m_base); } - - /** - * @brief Not equal to - */ - bool operator!=(const ISVBase& b) const - { return m_base.operator!=(b.m_base); } - - /** - * @brief Similar to - */ - bool is_similar_to(const ISVBase& b, const double r_epsilon=1e-5, - const double a_epsilon=1e-8) const - { return m_base.is_similar_to(b.m_base, r_epsilon, a_epsilon); } - - /** - * @brief Saves machine to an HDF5 file - */ - void save(bob::io::base::HDF5File& config) const; - - /** - * @brief Loads data from an existing configuration object. Resets - * the current state. - */ - void load(bob::io::base::HDF5File& config); - - /** - * @brief Returns the UBM - */ - const boost::shared_ptr<bob::learn::misc::GMMMachine> getUbm() const - { return m_base.getUbm(); } - - /** - * @brief Returns the U matrix - */ - const blitz::Array<double,2>& getU() const - { return m_base.getU(); } - - /** - * @brief Returns the diagonal matrix diag(d) (as a 1D vector) - */ - const blitz::Array<double,1>& getD() const - { return m_base.getD(); } - - /** - * @brief Returns the number of Gaussian components C - * @warning An exception is thrown if no Universal Background Model has - * been set yet. - */ - const size_t getDimC() const - { return m_base.getDimC(); } - - /** - * @brief Returns the feature dimensionality D - * @warning An exception is thrown if no Universal Background Model has - * been set yet. - */ - const size_t getDimD() const - { return m_base.getDimD(); } - - /** - * @brief Returns the supervector length CD - * (CxD: Number of Gaussian components by the feature dimensionality) - * @warning An exception is thrown if no Universal Background Model has - * been set yet. - */ - const size_t getDimCD() const - { return m_base.getDimCD(); } - - /** - * @brief Returns the size/rank ru of the U matrix - */ - const size_t getDimRu() const - { return m_base.getDimRu(); } - - /** - * @brief Resets the dimensionality of the subspace U and V - * U and V are hence uninitialized. - */ - void resize(const size_t ru) - { m_base.resize(ru, 1); - blitz::Array<double,2>& V = m_base.updateV(); - V = 0; - } - - /** - * @brief Returns the U matrix in order to update it - * @warning Should only be used by the trainer for efficiency reason, - * or for testing purpose. - */ - blitz::Array<double,2>& updateU() - { return m_base.updateU(); } - - /** - * @brief Returns the diagonal matrix diag(d) (as a 1D vector) in order - * to update it - * @warning Should only be used by the trainer for efficiency reason, - * or for testing purpose. - */ - blitz::Array<double,1>& updateD() - { return m_base.updateD(); } - - - /** - * @brief Sets (the mean supervector of) the Universal Background Model - * U, V and d are uninitialized in case of dimensions update (C or D) - */ - void setUbm(const boost::shared_ptr<bob::learn::misc::GMMMachine> ubm) - { m_base.setUbm(ubm); } - - /** - * @brief Sets the U matrix - */ - void setU(const blitz::Array<double,2>& U) - { m_base.setU(U); } - - /** - * @brief Sets the diagonal matrix diag(d) - * (a 1D vector is expected as an argument) - */ - void setD(const blitz::Array<double,1>& d) - { m_base.setD(d); } - - /** - * @brief Estimates x from the GMM statistics considering the LPT - * assumption, that is the latent session variable x is approximated - * using the UBM - */ - void estimateX(const bob::learn::misc::GMMStats& gmm_stats, blitz::Array<double,1>& x) const - { m_base.estimateX(gmm_stats, x); } - - /** - * @brief Precompute (put U^{T}.Sigma^{-1} matrix in cache) - * @warning Should only be used by the trainer for efficiency reason, - * or for testing purpose. - */ - void precompute() - { m_base.updateCacheUbmUVD(); } - - /** - * @brief Returns the FABase member - */ - const bob::learn::misc::FABase& getBase() const - { return m_base; } - - - private: - // FABase - bob::learn::misc::FABase m_base; -}; - /** * @brief A JFAMachine which is associated to a JFABase that contains @@ -723,7 +29,7 @@ class ISVBase * (latent variables y and z) * TODO: add a reference to the journal articles */ -class JFAMachine: public Machine<bob::learn::misc::GMMStats, double> +class JFAMachine { public: /** @@ -792,16 +98,16 @@ class JFAMachine: public Machine<bob::learn::misc::GMMStats, double> * @warning An exception is thrown if no Universal Background Model has * been set yet. */ - const size_t getDimC() const - { return m_jfa_base->getDimC(); } + const size_t getNGaussians() const + { return m_jfa_base->getNGaussians(); } /** * @brief Returns the feature dimensionality D * @warning An exception is thrown if no Universal Background Model has * been set yet. */ - const size_t getDimD() const - { return m_jfa_base->getDimD(); } + const size_t getNInputs() const + { return m_jfa_base->getNInputs(); } /** * @brief Returns the supervector length CD @@ -809,8 +115,8 @@ class JFAMachine: public Machine<bob::learn::misc::GMMStats, double> * @warning An exception is thrown if no Universal Background Model has * been set yet. */ - const size_t getDimCD() const - { return m_jfa_base->getDimCD(); } + const size_t getSupervectorLength() const + { return m_jfa_base->getSupervectorLength(); } /** * @brief Returns the size/rank ru of the U matrix @@ -914,7 +220,7 @@ class JFAMachine: public Machine<bob::learn::misc::GMMStats, double> */ void forward_(const bob::learn::misc::GMMStats& input, double& score) const; - protected: + private: /** * @brief Resize latent variable according to the JFABase */ @@ -943,208 +249,6 @@ class JFAMachine: public Machine<bob::learn::misc::GMMStats, double> mutable blitz::Array<double,1> m_tmp_Ux; }; -/** - * @brief A ISVMachine which is associated to a ISVBase that contains - * U D matrices. - * TODO: add a reference to the journal articles - */ -class ISVMachine: public Machine<bob::learn::misc::GMMStats, double> -{ - public: - /** - * @brief Default constructor. Builds an otherwise invalid 0 x 0 ISVMachine - * The Universal Background Model and the matrices U, V and diag(d) are - * not initialized. - */ - ISVMachine(); - - /** - * @brief Constructor. Builds a new ISVMachine. - * - * @param isv_base The ISVBase associated with this machine - */ - ISVMachine(const boost::shared_ptr<bob::learn::misc::ISVBase> isv_base); - - /** - * @brief Copy constructor - */ - ISVMachine(const ISVMachine& other); - - /** - * @brief Starts a new ISVMachine from an existing Configuration object. - */ - ISVMachine(bob::io::base::HDF5File& config); - - /** - * @brief Just to virtualise the destructor - */ - virtual ~ISVMachine(); - - /** - * @brief Assigns from a different ISV machine - */ - ISVMachine& operator=(const ISVMachine &other); - - /** - * @brief Equal to - */ - bool operator==(const ISVMachine& b) const; - - /** - * @brief Not equal to - */ - bool operator!=(const ISVMachine& b) const; - - /** - * @brief Similar to - */ - bool is_similar_to(const ISVMachine& b, const double r_epsilon=1e-5, - const double a_epsilon=1e-8) const; - - /** - * @brief Saves machine to an HDF5 file - */ - void save(bob::io::base::HDF5File& config) const; - - /** - * @brief Loads data from an existing configuration object. Resets - * the current state. - */ - void load(bob::io::base::HDF5File& config); - - - /** - * @brief Returns the number of Gaussian components C - * @warning An exception is thrown if no Universal Background Model has - * been set yet. - */ - const size_t getDimC() const - { return m_isv_base->getDimC(); } - - /** - * @brief Returns the feature dimensionality D - * @warning An exception is thrown if no Universal Background Model has - * been set yet. - */ - const size_t getDimD() const - { return m_isv_base->getDimD(); } - - /** - * @brief Returns the supervector length CD - * (CxD: Number of Gaussian components by the feature dimensionality) - * @warning An exception is thrown if no Universal Background Model has - * been set yet. - */ - const size_t getDimCD() const - { return m_isv_base->getDimCD(); } - - /** - * @brief Returns the size/rank ru of the U matrix - */ - const size_t getDimRu() const - { return m_isv_base->getDimRu(); } - - /** - * @brief Returns the x session factor - */ - const blitz::Array<double,1>& getX() const - { return m_cache_x; } - - /** - * @brief Returns the z speaker factor - */ - const blitz::Array<double,1>& getZ() const - { return m_z; } - - /** - * @brief Returns the z speaker factors in order to update it - */ - blitz::Array<double,1>& updateZ() - { return m_z; } - - /** - * @brief Returns the V matrix - */ - void setZ(const blitz::Array<double,1>& z); - - /** - * @brief Returns the ISVBase - */ - const boost::shared_ptr<bob::learn::misc::ISVBase> getISVBase() const - { return m_isv_base; } - - /** - * @brief Sets the ISVBase - */ - void setISVBase(const boost::shared_ptr<bob::learn::misc::ISVBase> isv_base); - - - /** - * @brief Estimates x from the GMM statistics considering the LPT - * assumption, that is the latent session variable x is approximated - * using the UBM - */ - void estimateX(const bob::learn::misc::GMMStats& gmm_stats, blitz::Array<double,1>& x) const - { m_isv_base->estimateX(gmm_stats, x); } - /** - * @brief Estimates Ux from the GMM statistics considering the LPT - * assumption, that is the latent session variable x is approximated - * using the UBM - */ - void estimateUx(const bob::learn::misc::GMMStats& gmm_stats, blitz::Array<double,1>& Ux); - - /** - * @brief Execute the machine - * - * @param input input data used by the machine - * @param score value computed by the machine - * @warning Inputs are checked - */ - void forward(const bob::learn::misc::GMMStats& input, double& score) const; - /** - * @brief Computes a score for the given UBM statistics and given the - * Ux vector - */ - void forward(const bob::learn::misc::GMMStats& gmm_stats, - const blitz::Array<double,1>& Ux, double& score) const; - - /** - * @brief Execute the machine - * - * @param input input data used by the machine - * @param score value computed by the machine - * @warning Inputs are NOT checked - */ - void forward_(const bob::learn::misc::GMMStats& input, double& score) const; - - protected: - /** - * @brief Resize latent variable according to the ISVBase - */ - void resize(); - /** - * @ Update cache - */ - void updateCache(); - /** - * @brief Resize working arrays - */ - void resizeTmp(); - - // UBM - boost::shared_ptr<bob::learn::misc::ISVBase> m_isv_base; - - // y and z vectors/factors learned during the enrolment procedure - blitz::Array<double,1> m_z; - - // cache - blitz::Array<double,1> m_cache_mDz; - mutable blitz::Array<double,1> m_cache_x; - - // x vector/factor in cache when computing scores - mutable blitz::Array<double,1> m_tmp_Ux; -}; - } } } // namespaces -#endif // BOB_LEARN_MISC_FABASE_H +#endif // BOB_LEARN_MISC_JFAMACHINE_H diff --git a/bob/learn/misc/jfa_base.cpp b/bob/learn/misc/jfa_base.cpp new file mode 100644 index 0000000..7fa60b5 --- /dev/null +++ b/bob/learn/misc/jfa_base.cpp @@ -0,0 +1,568 @@ +/** + * @date Tue Jan 27 17:03:15 2015 +0200 + * @author Tiago de Freitas Pereira <tiago.pereira@idiap.ch> + * + * @brief Python API for bob::learn::em + * + * Copyright (C) 2011-2014 Idiap Research Institute, Martigny, Switzerland + */ + +#include "main.h" + +/******************************************************************/ +/************ Constructor Section *********************************/ +/******************************************************************/ + +static auto JFABase_doc = bob::extension::ClassDoc( + BOB_EXT_MODULE_PREFIX ".JFABase", + + "Constructor. Builds a new JFABase. " + "TODO: add a reference to the journal articles", + "" +).add_constructor( + bob::extension::FunctionDoc( + "__init__", + "Creates a FABase", + "", + true + ) + .add_prototype("gmm,ru,rv","") + .add_prototype("other","") + .add_prototype("hdf5","") + .add_prototype("","") + + .add_parameter("gmm", ":py:class:`bob.learn.misc.GMMMachine`", "The Universal Background Model.") + .add_parameter("ru", "int", "Size of U (Within client variation matrix). In the end the U matrix will have (number_of_gaussians * feature_dimension x ru)") + .add_parameter("rv", "int", "Size of V (Between client variation matrix). In the end the U matrix will have (number_of_gaussians * feature_dimension x rv)") + .add_parameter("other", ":py:class:`bob.learn.misc.JFABase`", "A JFABase object to be copied.") + .add_parameter("hdf5", ":py:class:`bob.io.base.HDF5File`", "An HDF5 file open for reading") + +); + + +static int PyBobLearnMiscJFABase_init_copy(PyBobLearnMiscJFABaseObject* self, PyObject* args, PyObject* kwargs) { + + char** kwlist = JFABase_doc.kwlist(1); + PyBobLearnMiscJFABaseObject* o; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!", kwlist, &PyBobLearnMiscJFABase_Type, &o)){ + JFABase_doc.print_usage(); + return -1; + } + + self->cxx.reset(new bob::learn::misc::JFABase(*o->cxx)); + return 0; +} + + +static int PyBobLearnMiscJFABase_init_hdf5(PyBobLearnMiscJFABaseObject* self, PyObject* args, PyObject* kwargs) { + + char** kwlist = JFABase_doc.kwlist(2); + + PyBobIoHDF5FileObject* config = 0; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O&", kwlist, &PyBobIoHDF5File_Converter, &config)){ + JFABase_doc.print_usage(); + return -1; + } + + self->cxx.reset(new bob::learn::misc::JFABase(*(config->f))); + + return 0; +} + + +static int PyBobLearnMiscJFABase_init_ubm(PyBobLearnMiscJFABaseObject* self, PyObject* args, PyObject* kwargs) { + + char** kwlist = JFABase_doc.kwlist(0); + + PyBobLearnMiscGMMMachineObject* ubm; + int ru = 1; + int rv = 1; + + //Here we have to select which keyword argument to read + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!ii", kwlist, &PyBobLearnMiscGMMMachine_Type, &ubm, + &ru, &rv)){ + JFABase_doc.print_usage(); + return -1; + } + + self->cxx.reset(new bob::learn::misc::JFABase(ubm->cxx, ru, rv)); + return 0; +} + + +static int PyBobLearnMiscJFABase_init(PyBobLearnMiscJFABaseObject* self, PyObject* args, PyObject* kwargs) { + BOB_TRY + + // get the number of command line arguments + int nargs = (args?PyTuple_Size(args):0) + (kwargs?PyDict_Size(kwargs):0); + + switch (nargs) { + + case 1:{ + //Reading the input argument + PyObject* arg = 0; + if (PyTuple_Size(args)) + arg = PyTuple_GET_ITEM(args, 0); + else { + PyObject* tmp = PyDict_Values(kwargs); + auto tmp_ = make_safe(tmp); + arg = PyList_GET_ITEM(tmp, 0); + } + + // If the constructor input is Gaussian object + if (PyBobLearnMiscJFABase_Check(arg)) + return PyBobLearnMiscJFABase_init_copy(self, args, kwargs); + // If the constructor input is a HDF5 + else if (PyBobIoHDF5File_Check(arg)) + return PyBobLearnMiscJFABase_init_hdf5(self, args, kwargs); + } + case 3: + return PyBobLearnMiscJFABase_init_ubm(self, args, kwargs); + default: + PyErr_Format(PyExc_RuntimeError, "number of arguments mismatch - %s requires 1 or 3 arguments, but you provided %d (see help)", Py_TYPE(self)->tp_name, nargs); + JFABase_doc.print_usage(); + return -1; + } + BOB_CATCH_MEMBER("cannot create JFABase", 0) + return 0; +} + + + +static void PyBobLearnMiscJFABase_delete(PyBobLearnMiscJFABaseObject* self) { + self->cxx.reset(); + Py_TYPE(self)->tp_free((PyObject*)self); +} + +static PyObject* PyBobLearnMiscJFABase_RichCompare(PyBobLearnMiscJFABaseObject* self, PyObject* other, int op) { + BOB_TRY + + if (!PyBobLearnMiscJFABase_Check(other)) { + PyErr_Format(PyExc_TypeError, "cannot compare `%s' with `%s'", Py_TYPE(self)->tp_name, Py_TYPE(other)->tp_name); + return 0; + } + auto other_ = reinterpret_cast<PyBobLearnMiscJFABaseObject*>(other); + switch (op) { + case Py_EQ: + if (*self->cxx==*other_->cxx) Py_RETURN_TRUE; else Py_RETURN_FALSE; + case Py_NE: + if (*self->cxx==*other_->cxx) Py_RETURN_FALSE; else Py_RETURN_TRUE; + default: + Py_INCREF(Py_NotImplemented); + return Py_NotImplemented; + } + BOB_CATCH_MEMBER("cannot compare JFABase objects", 0) +} + +int PyBobLearnMiscJFABase_Check(PyObject* o) { + return PyObject_IsInstance(o, reinterpret_cast<PyObject*>(&PyBobLearnMiscJFABase_Type)); +} + + +/******************************************************************/ +/************ Variables Section ***********************************/ +/******************************************************************/ + +/***** shape *****/ +static auto shape = bob::extension::VariableDoc( + "shape", + "(int,int, int, int)", + "A tuple that represents the number of gaussians, dimensionality of each Gaussian, dimensionality of the rU (within client variability matrix) and dimensionality of the rV (between client variability matrix) ``(#Gaussians, #Inputs, #rU, #rV)``.", + "" +); +PyObject* PyBobLearnMiscJFABase_getShape(PyBobLearnMiscJFABaseObject* self, void*) { + BOB_TRY + return Py_BuildValue("(i,i,i,i)", self->cxx->getNGaussians(), self->cxx->getNInputs(), self->cxx->getDimRu(), self->cxx->getDimRv()); + BOB_CATCH_MEMBER("shape could not be read", 0) +} + +/***** supervector_length *****/ +static auto supervector_length = bob::extension::VariableDoc( + "supervector_length", + "int", + + "Returns the supervector length." + "NGaussians x NInputs: Number of Gaussian components by the feature dimensionality", + + "@warning An exception is thrown if no Universal Background Model has been set yet." +); +PyObject* PyBobLearnMiscJFABase_getSupervectorLength(PyBobLearnMiscJFABaseObject* self, void*) { + BOB_TRY + return Py_BuildValue("i", self->cxx->getSupervectorLength()); + BOB_CATCH_MEMBER("supervector_length could not be read", 0) +} + + +/***** u *****/ +static auto U = bob::extension::VariableDoc( + "u", + "array_like <float, 2D>", + "Returns the U matrix (within client variability matrix)", + "" +); +PyObject* PyBobLearnMiscJFABase_getU(PyBobLearnMiscJFABaseObject* self, void*){ + BOB_TRY + return PyBlitzArrayCxx_AsConstNumpy(self->cxx->getU()); + BOB_CATCH_MEMBER("``u`` could not be read", 0) +} +int PyBobLearnMiscJFABase_setU(PyBobLearnMiscJFABaseObject* self, PyObject* value, void*){ + BOB_TRY + PyBlitzArrayObject* o; + if (!PyBlitzArray_Converter(value, &o)){ + PyErr_Format(PyExc_RuntimeError, "%s %s expects a 2D array of floats", Py_TYPE(self)->tp_name, U.name()); + return -1; + } + auto o_ = make_safe(o); + auto b = PyBlitzArrayCxx_AsBlitz<double,2>(o, "u"); + if (!b) return -1; + self->cxx->setU(*b); + return 0; + BOB_CATCH_MEMBER("``u`` matrix could not be set", -1) +} + +/***** v *****/ +static auto V = bob::extension::VariableDoc( + "v", + "array_like <float, 2D>", + "Returns the V matrix (between client variability matrix)", + "" +); +PyObject* PyBobLearnMiscJFABase_getV(PyBobLearnMiscJFABaseObject* self, void*){ + BOB_TRY + return PyBlitzArrayCxx_AsConstNumpy(self->cxx->getV()); + BOB_CATCH_MEMBER("``v`` could not be read", 0) +} +int PyBobLearnMiscJFABase_setV(PyBobLearnMiscJFABaseObject* self, PyObject* value, void*){ + BOB_TRY + PyBlitzArrayObject* o; + if (!PyBlitzArray_Converter(value, &o)){ + PyErr_Format(PyExc_RuntimeError, "%s %s expects a 2D array of floats", Py_TYPE(self)->tp_name, V.name()); + return -1; + } + auto o_ = make_safe(o); + auto b = PyBlitzArrayCxx_AsBlitz<double,2>(o, "v"); + if (!b) return -1; + self->cxx->setV(*b); + return 0; + BOB_CATCH_MEMBER("``v`` matrix could not be set", -1) +} + + +/***** d *****/ +static auto D = bob::extension::VariableDoc( + "d", + "array_like <float, 1D>", + "Returns the diagonal matrix diag(d) (as a 1D vector)", + "" +); +PyObject* PyBobLearnMiscJFABase_getD(PyBobLearnMiscJFABaseObject* self, void*){ + BOB_TRY + return PyBlitzArrayCxx_AsConstNumpy(self->cxx->getD()); + BOB_CATCH_MEMBER("``d`` could not be read", 0) +} +int PyBobLearnMiscJFABase_setD(PyBobLearnMiscJFABaseObject* self, PyObject* value, void*){ + BOB_TRY + PyBlitzArrayObject* o; + if (!PyBlitzArray_Converter(value, &o)){ + PyErr_Format(PyExc_RuntimeError, "%s %s expects a 1D array of floats", Py_TYPE(self)->tp_name, D.name()); + return -1; + } + auto o_ = make_safe(o); + auto b = PyBlitzArrayCxx_AsBlitz<double,1>(o, "d"); + if (!b) return -1; + self->cxx->setD(*b); + return 0; + BOB_CATCH_MEMBER("``d`` matrix could not be set", -1) +} + + +/***** ubm *****/ +static auto ubm = bob::extension::VariableDoc( + "ubm", + ":py:class:`bob.learn.misc.GMMMachine`", + "Returns the UBM (Universal Background Model", + "" +); +PyObject* PyBobLearnMiscJFABase_getUBM(PyBobLearnMiscJFABaseObject* self, void*){ + BOB_TRY + + boost::shared_ptr<bob::learn::misc::GMMMachine> ubm_gmmMachine = self->cxx->getUbm(); + + //Allocating the correspondent python object + PyBobLearnMiscGMMMachineObject* retval = + (PyBobLearnMiscGMMMachineObject*)PyBobLearnMiscGMMMachine_Type.tp_alloc(&PyBobLearnMiscGMMMachine_Type, 0); + retval->cxx = ubm_gmmMachine; + + return Py_BuildValue("O",retval); + BOB_CATCH_MEMBER("ubm could not be read", 0) +} +int PyBobLearnMiscJFABase_setUBM(PyBobLearnMiscJFABaseObject* self, PyObject* value, void*){ + BOB_TRY + + if (!PyBobLearnMiscGMMMachine_Check(value)){ + PyErr_Format(PyExc_RuntimeError, "%s %s expects a :py:class:`bob.learn.misc.GMMMachine`", Py_TYPE(self)->tp_name, ubm.name()); + return -1; + } + + PyBobLearnMiscGMMMachineObject* ubm_gmmMachine = 0; + PyArg_Parse(value, "O!", &PyBobLearnMiscGMMMachine_Type,&ubm_gmmMachine); + + self->cxx->setUbm(ubm_gmmMachine->cxx); + + return 0; + BOB_CATCH_MEMBER("ubm could not be set", -1) +} + + + + +static PyGetSetDef PyBobLearnMiscJFABase_getseters[] = { + { + shape.name(), + (getter)PyBobLearnMiscJFABase_getShape, + 0, + shape.doc(), + 0 + }, + + { + supervector_length.name(), + (getter)PyBobLearnMiscJFABase_getSupervectorLength, + 0, + supervector_length.doc(), + 0 + }, + + { + U.name(), + (getter)PyBobLearnMiscJFABase_getU, + (setter)PyBobLearnMiscJFABase_setU, + U.doc(), + 0 + }, + + { + V.name(), + (getter)PyBobLearnMiscJFABase_getV, + (setter)PyBobLearnMiscJFABase_setV, + V.doc(), + 0 + }, + + { + D.name(), + (getter)PyBobLearnMiscJFABase_getD, + (setter)PyBobLearnMiscJFABase_setD, + D.doc(), + 0 + }, + + { + ubm.name(), + (getter)PyBobLearnMiscJFABase_getUBM, + (setter)PyBobLearnMiscJFABase_setUBM, + ubm.doc(), + 0 + }, + + + {0} // Sentinel +}; + + +/******************************************************************/ +/************ Functions Section ***********************************/ +/******************************************************************/ + + +/*** save ***/ +static auto save = bob::extension::FunctionDoc( + "save", + "Save the configuration of the JFABase to a given HDF5 file" +) +.add_prototype("hdf5") +.add_parameter("hdf5", ":py:class:`bob.io.base.HDF5File`", "An HDF5 file open for writing"); +static PyObject* PyBobLearnMiscJFABase_Save(PyBobLearnMiscJFABaseObject* self, PyObject* args, PyObject* kwargs) { + + BOB_TRY + + // get list of arguments + char** kwlist = save.kwlist(0); + PyBobIoHDF5FileObject* hdf5; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O&", kwlist, PyBobIoHDF5File_Converter, &hdf5)) return 0; + + auto hdf5_ = make_safe(hdf5); + self->cxx->save(*hdf5->f); + + BOB_CATCH_MEMBER("cannot save the data", 0) + Py_RETURN_NONE; +} + +/*** load ***/ +static auto load = bob::extension::FunctionDoc( + "load", + "Load the configuration of the JFABase to a given HDF5 file" +) +.add_prototype("hdf5") +.add_parameter("hdf5", ":py:class:`bob.io.base.HDF5File`", "An HDF5 file open for reading"); +static PyObject* PyBobLearnMiscJFABase_Load(PyBobLearnMiscJFABaseObject* self, PyObject* args, PyObject* kwargs) { + BOB_TRY + + char** kwlist = load.kwlist(0); + PyBobIoHDF5FileObject* hdf5; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O&", kwlist, PyBobIoHDF5File_Converter, &hdf5)) return 0; + + auto hdf5_ = make_safe(hdf5); + self->cxx->load(*hdf5->f); + + BOB_CATCH_MEMBER("cannot load the data", 0) + Py_RETURN_NONE; +} + + +/*** is_similar_to ***/ +static auto is_similar_to = bob::extension::FunctionDoc( + "is_similar_to", + + "Compares this JFABase with the ``other`` one to be approximately the same.", + "The optional values ``r_epsilon`` and ``a_epsilon`` refer to the " + "relative and absolute precision for the ``weights``, ``biases`` " + "and any other values internal to this machine." +) +.add_prototype("other, [r_epsilon], [a_epsilon]","output") +.add_parameter("other", ":py:class:`bob.learn.misc.JFABase`", "A JFABase object to be compared.") +.add_parameter("r_epsilon", "float", "Relative precision.") +.add_parameter("a_epsilon", "float", "Absolute precision.") +.add_return("output","bool","True if it is similar, otherwise false."); +static PyObject* PyBobLearnMiscJFABase_IsSimilarTo(PyBobLearnMiscJFABaseObject* self, PyObject* args, PyObject* kwds) { + + /* Parses input arguments in a single shot */ + char** kwlist = is_similar_to.kwlist(0); + + //PyObject* other = 0; + PyBobLearnMiscJFABaseObject* other = 0; + double r_epsilon = 1.e-5; + double a_epsilon = 1.e-8; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!|dd", kwlist, + &PyBobLearnMiscJFABase_Type, &other, + &r_epsilon, &a_epsilon)){ + + is_similar_to.print_usage(); + return 0; + } + + if (self->cxx->is_similar_to(*other->cxx, r_epsilon, a_epsilon)) + Py_RETURN_TRUE; + else + Py_RETURN_FALSE; +} + + +/*** resize ***/ +static auto resize = bob::extension::FunctionDoc( + "resize", + "Allocates space for the statistics and resets to zero.", + 0, + true +) +.add_prototype("n_gaussians,n_inputs") +.add_parameter("n_gaussians", "int", "Number of gaussians") +.add_parameter("n_inputs", "int", "Dimensionality of the feature vector"); +static PyObject* PyBobLearnMiscJFABase_resize(PyBobLearnMiscJFABaseObject* self, PyObject* args, PyObject* kwargs) { + BOB_TRY + + /* Parses input arguments in a single shot */ + char** kwlist = resize.kwlist(0); + + int n_gaussians = 0; + int n_inputs = 0; + + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "ii", kwlist, &n_gaussians, &n_inputs)) Py_RETURN_NONE; + + if (n_gaussians <= 0){ + PyErr_Format(PyExc_TypeError, "n_gaussians must be greater than zero"); + resize.print_usage(); + return 0; + } + if (n_inputs <= 0){ + PyErr_Format(PyExc_TypeError, "n_inputs must be greater than zero"); + resize.print_usage(); + return 0; + } + self->cxx->resize(n_gaussians, n_inputs); + + BOB_CATCH_MEMBER("cannot perform the resize method", 0) + + Py_RETURN_NONE; +} + + + + +static PyMethodDef PyBobLearnMiscJFABase_methods[] = { + { + save.name(), + (PyCFunction)PyBobLearnMiscJFABase_Save, + METH_VARARGS|METH_KEYWORDS, + save.doc() + }, + { + load.name(), + (PyCFunction)PyBobLearnMiscJFABase_Load, + METH_VARARGS|METH_KEYWORDS, + load.doc() + }, + { + is_similar_to.name(), + (PyCFunction)PyBobLearnMiscJFABase_IsSimilarTo, + METH_VARARGS|METH_KEYWORDS, + is_similar_to.doc() + }, + { + resize.name(), + (PyCFunction)PyBobLearnMiscJFABase_resize, + METH_VARARGS|METH_KEYWORDS, + resize.doc() + }, + + {0} /* Sentinel */ +}; + + +/******************************************************************/ +/************ Module Section **************************************/ +/******************************************************************/ + +// Define the JFA type struct; will be initialized later +PyTypeObject PyBobLearnMiscJFABase_Type = { + PyVarObject_HEAD_INIT(0,0) + 0 +}; + +bool init_BobLearnMiscJFABase(PyObject* module) +{ + // initialize the type struct + PyBobLearnMiscJFABase_Type.tp_name = JFABase_doc.name(); + PyBobLearnMiscJFABase_Type.tp_basicsize = sizeof(PyBobLearnMiscJFABaseObject); + PyBobLearnMiscJFABase_Type.tp_flags = Py_TPFLAGS_DEFAULT; + PyBobLearnMiscJFABase_Type.tp_doc = JFABase_doc.doc(); + + // set the functions + PyBobLearnMiscJFABase_Type.tp_new = PyType_GenericNew; + PyBobLearnMiscJFABase_Type.tp_init = reinterpret_cast<initproc>(PyBobLearnMiscJFABase_init); + PyBobLearnMiscJFABase_Type.tp_dealloc = reinterpret_cast<destructor>(PyBobLearnMiscJFABase_delete); + PyBobLearnMiscJFABase_Type.tp_richcompare = reinterpret_cast<richcmpfunc>(PyBobLearnMiscJFABase_RichCompare); + PyBobLearnMiscJFABase_Type.tp_methods = PyBobLearnMiscJFABase_methods; + PyBobLearnMiscJFABase_Type.tp_getset = PyBobLearnMiscJFABase_getseters; + //PyBobLearnMiscJFABase_Type.tp_call = reinterpret_cast<ternaryfunc>(PyBobLearnMiscJFABase_forward); + + + // check that everything is fine + if (PyType_Ready(&PyBobLearnMiscJFABase_Type) < 0) return false; + + // add the type to the module + Py_INCREF(&PyBobLearnMiscJFABase_Type); + return PyModule_AddObject(module, "JFABase", (PyObject*)&PyBobLearnMiscJFABase_Type) >= 0; +} + diff --git a/bob/learn/misc/main.cpp b/bob/learn/misc/main.cpp index 94b84f9..6456cdc 100644 --- a/bob/learn/misc/main.cpp +++ b/bob/learn/misc/main.cpp @@ -48,6 +48,7 @@ static PyObject* create_module (void) { if (!init_BobLearnMiscGMMBaseTrainer(module)) return 0; if (!init_BobLearnMiscMLGMMTrainer(module)) return 0; if (!init_BobLearnMiscMAPGMMTrainer(module)) return 0; + if (!init_BobLearnMiscJFABase(module)) return 0; static void* PyBobLearnMisc_API[PyBobLearnMisc_API_pointers]; diff --git a/bob/learn/misc/main.h b/bob/learn/misc/main.h index a908307..64b501e 100644 --- a/bob/learn/misc/main.h +++ b/bob/learn/misc/main.h @@ -27,6 +27,8 @@ #include <bob.learn.misc/ML_GMMTrainer.h> #include <bob.learn.misc/MAP_GMMTrainer.h> +#include <bob.learn.misc/JFABase.h> + #if PY_VERSION_HEX >= 0x03000000 #define PyInt_Check PyLong_Check @@ -154,4 +156,15 @@ bool init_BobLearnMiscMAPGMMTrainer(PyObject* module); int PyBobLearnMiscMAPGMMTrainer_Check(PyObject* o); +// JFABase +typedef struct { + PyObject_HEAD + boost::shared_ptr<bob::learn::misc::JFABase> cxx; +} PyBobLearnMiscJFABaseObject; + +extern PyTypeObject PyBobLearnMiscJFABase_Type; +bool init_BobLearnMiscJFABase(PyObject* module); +int PyBobLearnMiscJFABase_Check(PyObject* o); + + #endif // BOB_LEARN_EM_MAIN_H -- GitLab