From 84c1a36db3150edcb8e35954922e8a0d53bb1acd Mon Sep 17 00:00:00 2001 From: Manuel Guenther <manuel.guenther@idiap.ch> Date: Fri, 6 Mar 2015 14:49:06 +0100 Subject: [PATCH] Made the GMMStats of the GMM training accessible. --- bob/learn/em/cpp/GMMBaseTrainer.cpp | 20 +++++---- bob/learn/em/cpp/MAP_GMMTrainer.cpp | 23 +++++------ bob/learn/em/cpp/ML_GMMTrainer.cpp | 12 +++--- .../em/include/bob.learn.em/GMMBaseTrainer.h | 22 +++++----- .../em/include/bob.learn.em/MAP_GMMTrainer.h | 16 ++++---- .../em/include/bob.learn.em/ML_GMMTrainer.h | 10 +++-- bob/learn/em/map_gmm_trainer.cpp | 41 +++++++++++++++---- bob/learn/em/ml_gmm_trainer.cpp | 36 +++++++++++++++- 8 files changed, 122 insertions(+), 58 deletions(-) diff --git a/bob/learn/em/cpp/GMMBaseTrainer.cpp b/bob/learn/em/cpp/GMMBaseTrainer.cpp index 9351975..a48a520 100644 --- a/bob/learn/em/cpp/GMMBaseTrainer.cpp +++ b/bob/learn/em/cpp/GMMBaseTrainer.cpp @@ -12,12 +12,14 @@ bob::learn::em::GMMBaseTrainer::GMMBaseTrainer(const bool update_means, const bool update_variances, const bool update_weights, const double mean_var_update_responsibilities_threshold): + m_ss(new bob::learn::em::GMMStats()), m_update_means(update_means), m_update_variances(update_variances), m_update_weights(update_weights), m_mean_var_update_responsibilities_threshold(mean_var_update_responsibilities_threshold) {} bob::learn::em::GMMBaseTrainer::GMMBaseTrainer(const bob::learn::em::GMMBaseTrainer& b): + m_ss(new bob::learn::em::GMMStats()), m_update_means(b.m_update_means), m_update_variances(b.m_update_variances), m_mean_var_update_responsibilities_threshold(b.m_mean_var_update_responsibilities_threshold) {} @@ -28,20 +30,20 @@ bob::learn::em::GMMBaseTrainer::~GMMBaseTrainer() void bob::learn::em::GMMBaseTrainer::initialize(bob::learn::em::GMMMachine& gmm) { // Allocate memory for the sufficient statistics and initialise - m_ss.resize(gmm.getNGaussians(),gmm.getNInputs()); + m_ss->resize(gmm.getNGaussians(),gmm.getNInputs()); } void bob::learn::em::GMMBaseTrainer::eStep(bob::learn::em::GMMMachine& gmm, const blitz::Array<double,2>& data) { - m_ss.init(); + m_ss->init(); // Calculate the sufficient statistics and save in m_ss - gmm.accStatistics(data, m_ss); + gmm.accStatistics(data, *m_ss); } double bob::learn::em::GMMBaseTrainer::computeLikelihood(bob::learn::em::GMMMachine& gmm) { - return m_ss.log_likelihood / m_ss.T; + return m_ss->log_likelihood / m_ss->T; } @@ -50,7 +52,7 @@ bob::learn::em::GMMBaseTrainer& bob::learn::em::GMMBaseTrainer::operator= { if (this != &other) { - m_ss = other.m_ss; + *m_ss = *other.m_ss; m_update_means = other.m_update_means; m_update_variances = other.m_update_variances; m_update_weights = other.m_update_weights; @@ -62,7 +64,7 @@ bob::learn::em::GMMBaseTrainer& bob::learn::em::GMMBaseTrainer::operator= bool bob::learn::em::GMMBaseTrainer::operator== (const bob::learn::em::GMMBaseTrainer &other) const { - return m_ss == other.m_ss && + return *m_ss == *other.m_ss && m_update_means == other.m_update_means && m_update_variances == other.m_update_variances && m_update_weights == other.m_update_weights && @@ -79,7 +81,7 @@ bool bob::learn::em::GMMBaseTrainer::is_similar_to (const bob::learn::em::GMMBaseTrainer &other, const double r_epsilon, const double a_epsilon) const { - return m_ss == other.m_ss && + return *m_ss == *other.m_ss && m_update_means == other.m_update_means && m_update_variances == other.m_update_variances && m_update_weights == other.m_update_weights && @@ -87,8 +89,8 @@ bool bob::learn::em::GMMBaseTrainer::is_similar_to other.m_mean_var_update_responsibilities_threshold, r_epsilon, a_epsilon); } -void bob::learn::em::GMMBaseTrainer::setGMMStats(const bob::learn::em::GMMStats& stats) +void bob::learn::em::GMMBaseTrainer::setGMMStats(boost::shared_ptr<bob::learn::em::GMMStats> stats) { - bob::core::array::assertSameShape(m_ss.sumPx, stats.sumPx); + bob::core::array::assertSameShape(m_ss->sumPx, stats->sumPx); m_ss = stats; } diff --git a/bob/learn/em/cpp/MAP_GMMTrainer.cpp b/bob/learn/em/cpp/MAP_GMMTrainer.cpp index 7dcf504..fc9a82e 100644 --- a/bob/learn/em/cpp/MAP_GMMTrainer.cpp +++ b/bob/learn/em/cpp/MAP_GMMTrainer.cpp @@ -14,8 +14,8 @@ bob::learn::em::MAP_GMMTrainer::MAP_GMMTrainer( const bool update_weights, const double mean_var_update_responsibilities_threshold, - const bool reynolds_adaptation, - const double relevance_factor, + const bool reynolds_adaptation, + const double relevance_factor, const double alpha, boost::shared_ptr<bob::learn::em::GMMMachine> prior_gmm): @@ -33,7 +33,7 @@ bob::learn::em::MAP_GMMTrainer::MAP_GMMTrainer(const bob::learn::em::MAP_GMMTrai m_prior_gmm(b.m_prior_gmm) { m_relevance_factor = b.m_relevance_factor; - m_alpha = b.m_alpha; + m_alpha = b.m_alpha; m_reynolds_adaptation = b.m_reynolds_adaptation; } @@ -75,10 +75,10 @@ void bob::learn::em::MAP_GMMTrainer::mStep(bob::learn::em::GMMMachine& gmm) { // Read options and variables double n_gaussians = gmm.getNGaussians(); - + //Checking if it is necessary to resize the cache if((size_t)m_cache_alpha.extent(0) != n_gaussians) - initialize(gmm); //If it is different for some reason, there is no way, you have to initialize + initialize(gmm); //If it is different for some reason, there is no way, you have to initialize // Check that the prior GMM has been specified if (!m_prior_gmm) @@ -92,13 +92,13 @@ void bob::learn::em::MAP_GMMTrainer::mStep(bob::learn::em::GMMMachine& gmm) if (!m_reynolds_adaptation) m_cache_alpha = m_alpha; else - m_cache_alpha = m_gmm_base_trainer.getGMMStats().n(i) / (m_gmm_base_trainer.getGMMStats().n(i) + m_relevance_factor); + m_cache_alpha = m_gmm_base_trainer.getGMMStats()->n(i) / (m_gmm_base_trainer.getGMMStats()->n(i) + m_relevance_factor); // - Update weights if requested // Equation 11 of Reynolds et al., "Speaker Verification Using Adapted Gaussian Mixture Models", Digital Signal Processing, 2000 if (m_gmm_base_trainer.getUpdateWeights()) { // Calculate the maximum likelihood weights - m_cache_ml_weights = m_gmm_base_trainer.getGMMStats().n / static_cast<double>(m_gmm_base_trainer.getGMMStats().T); //cast req. for linux/32-bits & osx + m_cache_ml_weights = m_gmm_base_trainer.getGMMStats()->n / static_cast<double>(m_gmm_base_trainer.getGMMStats()->T); //cast req. for linux/32-bits & osx // Get the prior weights const blitz::Array<double,1>& prior_weights = m_prior_gmm->getWeights(); @@ -123,12 +123,12 @@ void bob::learn::em::MAP_GMMTrainer::mStep(bob::learn::em::GMMMachine& gmm) for (size_t i=0; i<n_gaussians; ++i) { const blitz::Array<double,1>& prior_means = m_prior_gmm->getGaussian(i)->getMean(); blitz::Array<double,1>& means = gmm.getGaussian(i)->updateMean(); - if (m_gmm_base_trainer.getGMMStats().n(i) < m_gmm_base_trainer.getMeanVarUpdateResponsibilitiesThreshold()) { + if (m_gmm_base_trainer.getGMMStats()->n(i) < m_gmm_base_trainer.getMeanVarUpdateResponsibilitiesThreshold()) { means = prior_means; } else { // Use the maximum likelihood means - means = m_cache_alpha(i) * (m_gmm_base_trainer.getGMMStats().sumPx(i,blitz::Range::all()) / m_gmm_base_trainer.getGMMStats().n(i)) + (1-m_cache_alpha(i)) * prior_means; + means = m_cache_alpha(i) * (m_gmm_base_trainer.getGMMStats()->sumPx(i,blitz::Range::all()) / m_gmm_base_trainer.getGMMStats()->n(i)) + (1-m_cache_alpha(i)) * prior_means; } } } @@ -142,11 +142,11 @@ void bob::learn::em::MAP_GMMTrainer::mStep(bob::learn::em::GMMMachine& gmm) blitz::Array<double,1>& means = gmm.getGaussian(i)->updateMean(); const blitz::Array<double,1>& prior_variances = m_prior_gmm->getGaussian(i)->getVariance(); blitz::Array<double,1>& variances = gmm.getGaussian(i)->updateVariance(); - if (m_gmm_base_trainer.getGMMStats().n(i) < m_gmm_base_trainer.getMeanVarUpdateResponsibilitiesThreshold()) { + if (m_gmm_base_trainer.getGMMStats()->n(i) < m_gmm_base_trainer.getMeanVarUpdateResponsibilitiesThreshold()) { variances = (prior_variances + prior_means) - blitz::pow2(means); } else { - variances = m_cache_alpha(i) * m_gmm_base_trainer.getGMMStats().sumPxx(i,blitz::Range::all()) / m_gmm_base_trainer.getGMMStats().n(i) + (1-m_cache_alpha(i)) * (prior_variances + prior_means) - blitz::pow2(means); + variances = m_cache_alpha(i) * m_gmm_base_trainer.getGMMStats()->sumPxx(i,blitz::Range::all()) / m_gmm_base_trainer.getGMMStats()->n(i) + (1-m_cache_alpha(i)) * (prior_variances + prior_means) - blitz::pow2(means); } gmm.getGaussian(i)->applyVarianceThresholds(); } @@ -200,4 +200,3 @@ bool bob::learn::em::MAP_GMMTrainer::is_similar_to bob::core::isClose(m_alpha, other.m_alpha, r_epsilon, a_epsilon) && m_reynolds_adaptation == other.m_reynolds_adaptation; } - diff --git a/bob/learn/em/cpp/ML_GMMTrainer.cpp b/bob/learn/em/cpp/ML_GMMTrainer.cpp index 79f7802..9ac7102 100644 --- a/bob/learn/em/cpp/ML_GMMTrainer.cpp +++ b/bob/learn/em/cpp/ML_GMMTrainer.cpp @@ -10,7 +10,7 @@ bob::learn::em::ML_GMMTrainer::ML_GMMTrainer( const bool update_means, - const bool update_variances, + const bool update_variances, const bool update_weights, const double mean_var_update_responsibilities_threshold ): @@ -29,7 +29,7 @@ bob::learn::em::ML_GMMTrainer::~ML_GMMTrainer() void bob::learn::em::ML_GMMTrainer::initialize(bob::learn::em::GMMMachine& gmm) { m_gmm_base_trainer.initialize(gmm); - + // Allocate cache size_t n_gaussians = gmm.getNGaussians(); m_cache_ss_n_thresholded.resize(n_gaussians); @@ -49,14 +49,14 @@ void bob::learn::em::ML_GMMTrainer::mStep(bob::learn::em::GMMMachine& gmm) // Equation 9.26 of Bishop, "Pattern recognition and machine learning", 2006 if (m_gmm_base_trainer.getUpdateWeights()) { blitz::Array<double,1>& weights = gmm.updateWeights(); - weights = m_gmm_base_trainer.getGMMStats().n / static_cast<double>(m_gmm_base_trainer.getGMMStats().T); //cast req. for linux/32-bits & osx + weights = m_gmm_base_trainer.getGMMStats()->n / static_cast<double>(m_gmm_base_trainer.getGMMStats()->T); //cast req. for linux/32-bits & osx // Recompute the log weights in the cache of the GMMMachine gmm.recomputeLogWeights(); } // Generate a thresholded version of m_ss.n for(size_t i=0; i<n_gaussians; ++i) - m_cache_ss_n_thresholded(i) = std::max(m_gmm_base_trainer.getGMMStats().n(i), m_gmm_base_trainer.getMeanVarUpdateResponsibilitiesThreshold()); + m_cache_ss_n_thresholded(i) = std::max(m_gmm_base_trainer.getGMMStats()->n(i), m_gmm_base_trainer.getMeanVarUpdateResponsibilitiesThreshold()); // Update GMM parameters using the sufficient statistics (m_ss) // - Update means if requested @@ -64,7 +64,7 @@ void bob::learn::em::ML_GMMTrainer::mStep(bob::learn::em::GMMMachine& gmm) if (m_gmm_base_trainer.getUpdateMeans()) { for(size_t i=0; i<n_gaussians; ++i) { blitz::Array<double,1>& means = gmm.getGaussian(i)->updateMean(); - means = m_gmm_base_trainer.getGMMStats().sumPx(i, blitz::Range::all()) / m_cache_ss_n_thresholded(i); + means = m_gmm_base_trainer.getGMMStats()->sumPx(i, blitz::Range::all()) / m_cache_ss_n_thresholded(i); } } @@ -77,7 +77,7 @@ void bob::learn::em::ML_GMMTrainer::mStep(bob::learn::em::GMMMachine& gmm) for(size_t i=0; i<n_gaussians; ++i) { const blitz::Array<double,1>& means = gmm.getGaussian(i)->getMean(); blitz::Array<double,1>& variances = gmm.getGaussian(i)->updateVariance(); - variances = m_gmm_base_trainer.getGMMStats().sumPxx(i, blitz::Range::all()) / m_cache_ss_n_thresholded(i) - blitz::pow2(means); + variances = m_gmm_base_trainer.getGMMStats()->sumPxx(i, blitz::Range::all()) / m_cache_ss_n_thresholded(i) - blitz::pow2(means); gmm.getGaussian(i)->applyVarianceThresholds(); } } diff --git a/bob/learn/em/include/bob.learn.em/GMMBaseTrainer.h b/bob/learn/em/include/bob.learn.em/GMMBaseTrainer.h index 121cdc2..9160c6b 100644 --- a/bob/learn/em/include/bob.learn.em/GMMBaseTrainer.h +++ b/bob/learn/em/include/bob.learn.em/GMMBaseTrainer.h @@ -30,7 +30,7 @@ class GMMBaseTrainer * @brief Default constructor */ GMMBaseTrainer(const bool update_means=true, - const bool update_variances=false, + const bool update_variances=false, const bool update_weights=false, const double mean_var_update_responsibilities_threshold = std::numeric_limits<double>::epsilon()); @@ -93,21 +93,21 @@ class GMMBaseTrainer * @brief Returns the internal GMM statistics. Useful to parallelize the * E-step */ - const bob::learn::em::GMMStats getGMMStats() const + const boost::shared_ptr<bob::learn::em::GMMStats> getGMMStats() const { return m_ss; } /** * @brief Sets the internal GMM statistics. Useful to parallelize the * E-step */ - void setGMMStats(const bob::learn::em::GMMStats& stats); - + void setGMMStats(boost::shared_ptr<bob::learn::em::GMMStats> stats); + /** * update means on each iteration - */ + */ bool getUpdateMeans() {return m_update_means;} - + /** * update variances on each iteration */ @@ -117,19 +117,19 @@ class GMMBaseTrainer bool getUpdateWeights() {return m_update_weights;} - - + + double getMeanVarUpdateResponsibilitiesThreshold() {return m_mean_var_update_responsibilities_threshold;} - + private: - + /** * These are the sufficient statistics, calculated during the * E-step and used during the M-step */ - bob::learn::em::GMMStats m_ss; + boost::shared_ptr<bob::learn::em::GMMStats> m_ss; /** diff --git a/bob/learn/em/include/bob.learn.em/MAP_GMMTrainer.h b/bob/learn/em/include/bob.learn.em/MAP_GMMTrainer.h index 7381e47..aac7087 100644 --- a/bob/learn/em/include/bob.learn.em/MAP_GMMTrainer.h +++ b/bob/learn/em/include/bob.learn.em/MAP_GMMTrainer.h @@ -28,11 +28,11 @@ class MAP_GMMTrainer */ MAP_GMMTrainer( const bool update_means=true, - const bool update_variances=false, + const bool update_variances=false, const bool update_weights=false, const double mean_var_update_responsibilities_threshold = std::numeric_limits<double>::epsilon(), - const bool reynolds_adaptation=false, - const double relevance_factor=4, + const bool reynolds_adaptation=false, + const double relevance_factor=4, const double alpha=0.5, boost::shared_ptr<bob::learn::em::GMMMachine> prior_gmm = boost::shared_ptr<bob::learn::em::GMMMachine>()); @@ -108,14 +108,14 @@ class MAP_GMMTrainer */ double computeLikelihood(bob::learn::em::GMMMachine& gmm){ return m_gmm_base_trainer.computeLikelihood(gmm); - } - + } + bool getReynoldsAdaptation() {return m_reynolds_adaptation;} void setReynoldsAdaptation(const bool reynolds_adaptation) {m_reynolds_adaptation = reynolds_adaptation;} - + double getRelevanceFactor() {return m_relevance_factor;} @@ -130,6 +130,8 @@ class MAP_GMMTrainer void setAlpha(const double alpha) {m_alpha = alpha;} + bob::learn::em::GMMBaseTrainer& base_trainer(){return m_gmm_base_trainer;} + protected: @@ -140,7 +142,7 @@ class MAP_GMMTrainer /** Base Trainer for the MAP algorithm. Basically implements the e-step - */ + */ bob::learn::em::GMMBaseTrainer m_gmm_base_trainer; /** diff --git a/bob/learn/em/include/bob.learn.em/ML_GMMTrainer.h b/bob/learn/em/include/bob.learn.em/ML_GMMTrainer.h index 58e8a76..40f1e70 100644 --- a/bob/learn/em/include/bob.learn.em/ML_GMMTrainer.h +++ b/bob/learn/em/include/bob.learn.em/ML_GMMTrainer.h @@ -28,7 +28,7 @@ class ML_GMMTrainer{ * @brief Default constructor */ ML_GMMTrainer(const bool update_means=true, - const bool update_variances=false, + const bool update_variances=false, const bool update_weights=false, const double mean_var_update_responsibilities_threshold = std::numeric_limits<double>::epsilon()); @@ -97,13 +97,15 @@ class ML_GMMTrainer{ */ bool is_similar_to(const ML_GMMTrainer& b, const double r_epsilon=1e-5, const double a_epsilon=1e-8) const; - - + + + bob::learn::em::GMMBaseTrainer& base_trainer(){return m_gmm_base_trainer;} + protected: /** Base Trainer for the MAP algorithm. Basically implements the e-step - */ + */ bob::learn::em::GMMBaseTrainer m_gmm_base_trainer; diff --git a/bob/learn/em/map_gmm_trainer.cpp b/bob/learn/em/map_gmm_trainer.cpp index a4fdfd4..37a220d 100644 --- a/bob/learn/em/map_gmm_trainer.cpp +++ b/bob/learn/em/map_gmm_trainer.cpp @@ -244,6 +244,30 @@ int PyBobLearnEMMAPGMMTrainer_setAlpha(PyBobLearnEMMAPGMMTrainerObject* self, Py } +static auto gmm_statistics = bob::extension::VariableDoc( + "gmm_statistics", + ":py:class:`GMMStats`", + "The GMM statistics that were used internally in the E- and M-steps", + "Setting and getting the internal GMM statistics might be useful to parallelize the GMM training." +); +PyObject* PyBobLearnEMMAPGMMTrainer_get_gmm_statistics(PyBobLearnEMMAPGMMTrainerObject* self, void*){ + BOB_TRY + PyBobLearnEMGMMStatsObject* stats = (PyBobLearnEMGMMStatsObject*)PyBobLearnEMGMMStats_Type.tp_alloc(&PyBobLearnEMGMMStats_Type, 0); + stats->cxx = self->cxx->base_trainer().getGMMStats(); + return Py_BuildValue("N", stats); + BOB_CATCH_MEMBER("gmm_statistics could not be read", 0) +} +int PyBobLearnEMMAPGMMTrainer_set_gmm_statistics(PyBobLearnEMMAPGMMTrainerObject* self, PyObject* value, void*){ + BOB_TRY + if (!PyBobLearnEMGMMStats_Check(value)){ + PyErr_Format(PyExc_RuntimeError, "%s %s expects a GMMStats object", Py_TYPE(self)->tp_name, gmm_statistics.name()); + return -1; + } + self->cxx->base_trainer().setGMMStats(reinterpret_cast<PyBobLearnEMGMMStatsObject*>(value)->cxx); + return 0; + BOB_CATCH_MEMBER("gmm_statistics could not be set", -1) +} + static PyGetSetDef PyBobLearnEMMAPGMMTrainer_getseters[] = { { @@ -260,7 +284,13 @@ static PyGetSetDef PyBobLearnEMMAPGMMTrainer_getseters[] = { relevance_factor.doc(), 0 }, - + { + gmm_statistics.name(), + (getter)PyBobLearnEMMAPGMMTrainer_get_gmm_statistics, + (setter)PyBobLearnEMMAPGMMTrainer_set_gmm_statistics, + gmm_statistics.doc(), + 0 + }, {0} // Sentinel }; @@ -303,12 +333,11 @@ static PyObject* PyBobLearnEMMAPGMMTrainer_initialize(PyBobLearnEMMAPGMMTrainerO /*** e_step ***/ static auto e_step = bob::extension::FunctionDoc( "e_step", - "Calculates and saves statistics across the dataset," - "and saves these as m_ss. ", + "Calculates and saves statistics across the dataset and saves these as :py:attr`gmm_statistics`. ", "Calculates the average log likelihood of the observations given the GMM," "and returns this in average_log_likelihood." - "The statistics, m_ss, will be used in the m_step() that follows.", + "The statistics, :py:attr`gmm_statistics`, will be used in the :py:meth:`m_step` that follows.", true ) @@ -357,9 +386,7 @@ static PyObject* PyBobLearnEMMAPGMMTrainer_e_step(PyBobLearnEMMAPGMMTrainerObjec static auto m_step = bob::extension::FunctionDoc( "m_step", - "Performs a maximum a posteriori (MAP) update of the GMM:" - "* parameters using the accumulated statistics in :py:class:`bob.learn.em.GMMBaseTrainer.m_ss` and the" - "* parameters of the prior model", + "Performs a maximum a posteriori (MAP) update of the GMM parameters using the accumulated statistics in :py:attr:`gmm_statistics` and the parameters of the prior model", "", true ) diff --git a/bob/learn/em/ml_gmm_trainer.cpp b/bob/learn/em/ml_gmm_trainer.cpp index aa2e95d..2b30ec9 100644 --- a/bob/learn/em/ml_gmm_trainer.cpp +++ b/bob/learn/em/ml_gmm_trainer.cpp @@ -144,7 +144,39 @@ static PyObject* PyBobLearnEMMLGMMTrainer_RichCompare(PyBobLearnEMMLGMMTrainerOb /************ Variables Section ***********************************/ /******************************************************************/ +static auto gmm_statistics = bob::extension::VariableDoc( + "gmm_statistics", + ":py:class:`GMMStats`", + "The GMM statistics that were used internally in the E- and M-steps", + "Setting and getting the internal GMM statistics might be useful to parallelize the GMM training." +); +PyObject* PyBobLearnEMMLGMMTrainer_get_gmm_statistics(PyBobLearnEMMLGMMTrainerObject* self, void*){ + BOB_TRY + PyBobLearnEMGMMStatsObject* stats = (PyBobLearnEMGMMStatsObject*)PyBobLearnEMGMMStats_Type.tp_alloc(&PyBobLearnEMGMMStats_Type, 0); + stats->cxx = self->cxx->base_trainer().getGMMStats(); + return Py_BuildValue("N", stats); + BOB_CATCH_MEMBER("gmm_statistics could not be read", 0) +} +int PyBobLearnEMMLGMMTrainer_set_gmm_statistics(PyBobLearnEMMLGMMTrainerObject* self, PyObject* value, void*){ + BOB_TRY + if (!PyBobLearnEMGMMStats_Check(value)){ + PyErr_Format(PyExc_RuntimeError, "%s %s expects a GMMStats object", Py_TYPE(self)->tp_name, gmm_statistics.name()); + return -1; + } + self->cxx->base_trainer().setGMMStats(reinterpret_cast<PyBobLearnEMGMMStatsObject*>(value)->cxx); + return 0; + BOB_CATCH_MEMBER("gmm_statistics could not be set", -1) +} + + static PyGetSetDef PyBobLearnEMMLGMMTrainer_getseters[] = { + { + gmm_statistics.name(), + (getter)PyBobLearnEMMLGMMTrainer_get_gmm_statistics, + (setter)PyBobLearnEMMLGMMTrainer_set_gmm_statistics, + gmm_statistics.doc(), + 0 + }, {0} // Sentinel }; @@ -190,7 +222,7 @@ static auto e_step = bob::extension::FunctionDoc( "Calculates the average log likelihood of the observations given the GMM," "and returns this in average_log_likelihood." - "The statistics, m_ss, will be used in the :py:func:`m_step` that follows.", + "The statistics, :py:attr:`gmm_statistics`, will be used in the :py:func:`m_step` that follows.", true ) @@ -237,7 +269,7 @@ static PyObject* PyBobLearnEMMLGMMTrainer_e_step(PyBobLearnEMMLGMMTrainerObject* static auto m_step = bob::extension::FunctionDoc( "m_step", "Performs a maximum likelihood (ML) update of the GMM parameters " - "using the accumulated statistics in :py:attr:`bob.learn.em.GMMBaseTrainer.m_ss`", + "using the accumulated statistics in :py:attr:`gmm_statistics`", "See Section 9.2.2 of Bishop, \"Pattern recognition and machine learning\", 2006", -- GitLab