From 84c1a36db3150edcb8e35954922e8a0d53bb1acd Mon Sep 17 00:00:00 2001
From: Manuel Guenther <manuel.guenther@idiap.ch>
Date: Fri, 6 Mar 2015 14:49:06 +0100
Subject: [PATCH] Made the GMMStats of the GMM training accessible.

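The sufficient statistics of the GMMBaseTrainer are now held in a
boost::shared_ptr<bob::learn::em::GMMStats> and can be read and replaced
through getGMMStats()/setGMMStats(). ML_GMMTrainer and MAP_GMMTrainer expose
their internal base trainer through a new base_trainer() accessor, and the
Python bindings export the statistics as a read/write gmm_statistics
property, so that the E-step can be run in parallel and the accumulated
statistics can be merged before the M-step.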
---
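Patch notes: below is a minimal, untested sketch of the parallelization
workflow that the new gmm_statistics property enables. The
GMMStats(n_gaussians, n_inputs) constructor, the positional trainer flags,
and the Python-side attribute names (n, sum_px, sum_pxx, log_likelihood, t)
are assumed to mirror the C++ members and may differ in the actual bindings:

    import numpy
    import bob.learn.em

    # toy setup: 2 Gaussians, 3-dimensional data
    data = numpy.random.randn(100, 3)
    gmm = bob.learn.em.GMMMachine(2, 3)

    # one trainer per (simulated) worker, each running the E-step on its
    # own chunk of the data; in a real setup these would live in separate
    # processes (flags: update means/variances/weights, order assumed from
    # the C++ constructor)
    chunks = (data[:50], data[50:])
    trainers = [bob.learn.em.ML_GMMTrainer(True, True, True) for _ in chunks]
    for trainer, chunk in zip(trainers, chunks):
        trainer.initialize(gmm)
        trainer.e_step(gmm, chunk)

    # merge the per-chunk sufficient statistics into a single GMMStats
    # (attribute names assumed from the C++ members n/sumPx/sumPxx/T/...)
    total = bob.learn.em.GMMStats(2, 3)
    for trainer in trainers:
        stats = trainer.gmm_statistics
        total.n = total.n + stats.n
        total.sum_px = total.sum_px + stats.sum_px
        total.sum_pxx = total.sum_pxx + stats.sum_pxx
        total.log_likelihood = total.log_likelihood + stats.log_likelihood
        total.t = total.t + stats.t

    # hand the merged statistics back to one trainer and run the M-step
    trainers[0].gmm_statistics = total
    trainers[0].m_step(gmm)

Since setGMMStats() asserts that the shapes of the statistics match, the
merged GMMStats must be allocated with the same number of Gaussians and
input dimensions as the trainer's internal one.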
 bob/learn/em/cpp/GMMBaseTrainer.cpp           | 23 +++++----
 bob/learn/em/cpp/MAP_GMMTrainer.cpp           | 25 +++++------
 bob/learn/em/cpp/ML_GMMTrainer.cpp            | 12 +++---
 .../em/include/bob.learn.em/GMMBaseTrainer.h  | 22 +++++-----
 .../em/include/bob.learn.em/MAP_GMMTrainer.h  | 16 ++++----
 .../em/include/bob.learn.em/ML_GMMTrainer.h   | 12 +++--
 bob/learn/em/map_gmm_trainer.cpp              | 41 +++++++++++++++----
 bob/learn/em/ml_gmm_trainer.cpp               | 36 +++++++++++++++-
 8 files changed, 126 insertions(+), 61 deletions(-)

diff --git a/bob/learn/em/cpp/GMMBaseTrainer.cpp b/bob/learn/em/cpp/GMMBaseTrainer.cpp
index 9351975..a48a520 100644
--- a/bob/learn/em/cpp/GMMBaseTrainer.cpp
+++ b/bob/learn/em/cpp/GMMBaseTrainer.cpp
@@ -12,12 +12,15 @@
 bob::learn::em::GMMBaseTrainer::GMMBaseTrainer(const bool update_means,
     const bool update_variances, const bool update_weights,
     const double mean_var_update_responsibilities_threshold):
+  m_ss(new bob::learn::em::GMMStats()),
   m_update_means(update_means), m_update_variances(update_variances),
   m_update_weights(update_weights),
   m_mean_var_update_responsibilities_threshold(mean_var_update_responsibilities_threshold)
 {}
 
 bob::learn::em::GMMBaseTrainer::GMMBaseTrainer(const bob::learn::em::GMMBaseTrainer& b):
-  m_update_means(b.m_update_means), m_update_variances(b.m_update_variances),
+  m_ss(new bob::learn::em::GMMStats(*b.m_ss)),
+  m_update_means(b.m_update_means), m_update_variances(b.m_update_variances),
+  m_update_weights(b.m_update_weights),
   m_mean_var_update_responsibilities_threshold(b.m_mean_var_update_responsibilities_threshold)
 {}
@@ -28,20 +30,20 @@ bob::learn::em::GMMBaseTrainer::~GMMBaseTrainer()
 void bob::learn::em::GMMBaseTrainer::initialize(bob::learn::em::GMMMachine& gmm)
 {
   // Allocate memory for the sufficient statistics and initialise
-  m_ss.resize(gmm.getNGaussians(),gmm.getNInputs());
+  m_ss->resize(gmm.getNGaussians(),gmm.getNInputs());
 }
 
 void bob::learn::em::GMMBaseTrainer::eStep(bob::learn::em::GMMMachine& gmm,
   const blitz::Array<double,2>& data)
 {
-  m_ss.init();
+  m_ss->init();
   // Calculate the sufficient statistics and save in m_ss
-  gmm.accStatistics(data, m_ss);
+  gmm.accStatistics(data, *m_ss);
 }
 
 double bob::learn::em::GMMBaseTrainer::computeLikelihood(bob::learn::em::GMMMachine& gmm)
 {
-  return m_ss.log_likelihood / m_ss.T;
+  return m_ss->log_likelihood / m_ss->T;
 }
 
 
@@ -50,7 +52,7 @@ bob::learn::em::GMMBaseTrainer& bob::learn::em::GMMBaseTrainer::operator=
 {
   if (this != &other)
   {
-    m_ss = other.m_ss;
+    *m_ss = *other.m_ss;
     m_update_means = other.m_update_means;
     m_update_variances = other.m_update_variances;
     m_update_weights = other.m_update_weights;
@@ -62,7 +64,7 @@ bob::learn::em::GMMBaseTrainer& bob::learn::em::GMMBaseTrainer::operator=
 bool bob::learn::em::GMMBaseTrainer::operator==
   (const bob::learn::em::GMMBaseTrainer &other) const
 {
-  return m_ss == other.m_ss &&
+  return *m_ss == *other.m_ss &&
          m_update_means == other.m_update_means &&
          m_update_variances == other.m_update_variances &&
          m_update_weights == other.m_update_weights &&
@@ -79,7 +81,7 @@ bool bob::learn::em::GMMBaseTrainer::is_similar_to
   (const bob::learn::em::GMMBaseTrainer &other, const double r_epsilon,
    const double a_epsilon) const
 {
-  return m_ss == other.m_ss &&
+  return *m_ss == *other.m_ss &&
          m_update_means == other.m_update_means &&
          m_update_variances == other.m_update_variances &&
          m_update_weights == other.m_update_weights &&
@@ -87,8 +89,8 @@ bool bob::learn::em::GMMBaseTrainer::is_similar_to
           other.m_mean_var_update_responsibilities_threshold, r_epsilon, a_epsilon);
 }
 
-void bob::learn::em::GMMBaseTrainer::setGMMStats(const bob::learn::em::GMMStats& stats)
+void bob::learn::em::GMMBaseTrainer::setGMMStats(boost::shared_ptr<bob::learn::em::GMMStats> stats)
 {
-  bob::core::array::assertSameShape(m_ss.sumPx, stats.sumPx);
+  bob::core::array::assertSameShape(m_ss->sumPx, stats->sumPx);
   m_ss = stats;
 }
diff --git a/bob/learn/em/cpp/MAP_GMMTrainer.cpp b/bob/learn/em/cpp/MAP_GMMTrainer.cpp
index 7dcf504..fc9a82e 100644
--- a/bob/learn/em/cpp/MAP_GMMTrainer.cpp
+++ b/bob/learn/em/cpp/MAP_GMMTrainer.cpp
@@ -14,8 +14,8 @@ bob::learn::em::MAP_GMMTrainer::MAP_GMMTrainer(
    const bool update_weights,
    const double mean_var_update_responsibilities_threshold,
 
-   const bool reynolds_adaptation, 
-   const double relevance_factor, 
+   const bool reynolds_adaptation,
+   const double relevance_factor,
    const double alpha,
    boost::shared_ptr<bob::learn::em::GMMMachine> prior_gmm):
 
@@ -33,7 +33,7 @@ bob::learn::em::MAP_GMMTrainer::MAP_GMMTrainer(const bob::learn::em::MAP_GMMTrai
   m_prior_gmm(b.m_prior_gmm)
 {
   m_relevance_factor    = b.m_relevance_factor;
-  m_alpha               = b.m_alpha; 
+  m_alpha               = b.m_alpha;
   m_reynolds_adaptation = b.m_reynolds_adaptation;
 }
 
@@ -75,10 +75,10 @@ void bob::learn::em::MAP_GMMTrainer::mStep(bob::learn::em::GMMMachine& gmm)
 {
   // Read options and variables
   double n_gaussians = gmm.getNGaussians();
-  
+
   //Checking if it is necessary to resize the cache
   if((size_t)m_cache_alpha.extent(0) != n_gaussians)
-    initialize(gmm); //If it is different for some reason, there is no way, you have to initialize  
+    initialize(gmm); //If the dimensions differ for some reason, the trainer has to be re-initialized
 
   // Check that the prior GMM has been specified
   if (!m_prior_gmm)
@@ -92,13 +92,13 @@ void bob::learn::em::MAP_GMMTrainer::mStep(bob::learn::em::GMMMachine& gmm)
   if (!m_reynolds_adaptation)
     m_cache_alpha = m_alpha;
   else
-    m_cache_alpha = m_gmm_base_trainer.getGMMStats().n(i) / (m_gmm_base_trainer.getGMMStats().n(i) + m_relevance_factor);
+    m_cache_alpha = m_gmm_base_trainer.getGMMStats()->n(i) / (m_gmm_base_trainer.getGMMStats()->n(i) + m_relevance_factor);
 
   // - Update weights if requested
   //   Equation 11 of Reynolds et al., "Speaker Verification Using Adapted Gaussian Mixture Models", Digital Signal Processing, 2000
   if (m_gmm_base_trainer.getUpdateWeights()) {
     // Calculate the maximum likelihood weights
-    m_cache_ml_weights = m_gmm_base_trainer.getGMMStats().n / static_cast<double>(m_gmm_base_trainer.getGMMStats().T); //cast req. for linux/32-bits & osx
+    m_cache_ml_weights = m_gmm_base_trainer.getGMMStats()->n / static_cast<double>(m_gmm_base_trainer.getGMMStats()->T); //cast req. for linux/32-bits & osx
 
     // Get the prior weights
     const blitz::Array<double,1>& prior_weights = m_prior_gmm->getWeights();
@@ -123,12 +123,12 @@ void bob::learn::em::MAP_GMMTrainer::mStep(bob::learn::em::GMMMachine& gmm)
     for (size_t i=0; i<n_gaussians; ++i) {
       const blitz::Array<double,1>& prior_means = m_prior_gmm->getGaussian(i)->getMean();
       blitz::Array<double,1>& means = gmm.getGaussian(i)->updateMean();
-      if (m_gmm_base_trainer.getGMMStats().n(i) < m_gmm_base_trainer.getMeanVarUpdateResponsibilitiesThreshold()) {
+      if (m_gmm_base_trainer.getGMMStats()->n(i) < m_gmm_base_trainer.getMeanVarUpdateResponsibilitiesThreshold()) {
         means = prior_means;
       }
       else {
         // Use the maximum likelihood means
-        means = m_cache_alpha(i) * (m_gmm_base_trainer.getGMMStats().sumPx(i,blitz::Range::all()) / m_gmm_base_trainer.getGMMStats().n(i)) + (1-m_cache_alpha(i)) * prior_means;
+        means = m_cache_alpha(i) * (m_gmm_base_trainer.getGMMStats()->sumPx(i,blitz::Range::all()) / m_gmm_base_trainer.getGMMStats()->n(i)) + (1-m_cache_alpha(i)) * prior_means;
       }
     }
   }
@@ -142,11 +142,11 @@ void bob::learn::em::MAP_GMMTrainer::mStep(bob::learn::em::GMMMachine& gmm)
       blitz::Array<double,1>& means = gmm.getGaussian(i)->updateMean();
       const blitz::Array<double,1>& prior_variances = m_prior_gmm->getGaussian(i)->getVariance();
       blitz::Array<double,1>& variances = gmm.getGaussian(i)->updateVariance();
-      if (m_gmm_base_trainer.getGMMStats().n(i) < m_gmm_base_trainer.getMeanVarUpdateResponsibilitiesThreshold()) {
+      if (m_gmm_base_trainer.getGMMStats()->n(i) < m_gmm_base_trainer.getMeanVarUpdateResponsibilitiesThreshold()) {
-        variances = (prior_variances + prior_means) - blitz::pow2(means);
+        variances = (prior_variances + blitz::pow2(prior_means)) - blitz::pow2(means);
       }
       else {
-        variances = m_cache_alpha(i) * m_gmm_base_trainer.getGMMStats().sumPxx(i,blitz::Range::all()) / m_gmm_base_trainer.getGMMStats().n(i) + (1-m_cache_alpha(i)) * (prior_variances + prior_means) - blitz::pow2(means);
+        variances = m_cache_alpha(i) * m_gmm_base_trainer.getGMMStats()->sumPxx(i,blitz::Range::all()) / m_gmm_base_trainer.getGMMStats()->n(i) + (1-m_cache_alpha(i)) * (prior_variances + blitz::pow2(prior_means)) - blitz::pow2(means);
       }
       gmm.getGaussian(i)->applyVarianceThresholds();
     }
@@ -200,4 +200,3 @@ bool bob::learn::em::MAP_GMMTrainer::is_similar_to
          bob::core::isClose(m_alpha, other.m_alpha, r_epsilon, a_epsilon) &&
          m_reynolds_adaptation == other.m_reynolds_adaptation;
 }
-
diff --git a/bob/learn/em/cpp/ML_GMMTrainer.cpp b/bob/learn/em/cpp/ML_GMMTrainer.cpp
index 79f7802..9ac7102 100644
--- a/bob/learn/em/cpp/ML_GMMTrainer.cpp
+++ b/bob/learn/em/cpp/ML_GMMTrainer.cpp
@@ -10,7 +10,7 @@
 
 bob::learn::em::ML_GMMTrainer::ML_GMMTrainer(
    const bool update_means,
-   const bool update_variances, 
+   const bool update_variances,
    const bool update_weights,
    const double mean_var_update_responsibilities_threshold
 ):
@@ -29,7 +29,7 @@ bob::learn::em::ML_GMMTrainer::~ML_GMMTrainer()
 void bob::learn::em::ML_GMMTrainer::initialize(bob::learn::em::GMMMachine& gmm)
 {
   m_gmm_base_trainer.initialize(gmm);
-  
+
   // Allocate cache
   size_t n_gaussians = gmm.getNGaussians();
   m_cache_ss_n_thresholded.resize(n_gaussians);
@@ -49,14 +49,14 @@ void bob::learn::em::ML_GMMTrainer::mStep(bob::learn::em::GMMMachine& gmm)
   //   Equation 9.26 of Bishop, "Pattern recognition and machine learning", 2006
   if (m_gmm_base_trainer.getUpdateWeights()) {
     blitz::Array<double,1>& weights = gmm.updateWeights();
-    weights = m_gmm_base_trainer.getGMMStats().n / static_cast<double>(m_gmm_base_trainer.getGMMStats().T); //cast req. for linux/32-bits & osx
+    weights = m_gmm_base_trainer.getGMMStats()->n / static_cast<double>(m_gmm_base_trainer.getGMMStats()->T); //cast req. for linux/32-bits & osx
     // Recompute the log weights in the cache of the GMMMachine
     gmm.recomputeLogWeights();
   }
 
   // Generate a thresholded version of m_ss.n
   for(size_t i=0; i<n_gaussians; ++i)
-    m_cache_ss_n_thresholded(i) = std::max(m_gmm_base_trainer.getGMMStats().n(i), m_gmm_base_trainer.getMeanVarUpdateResponsibilitiesThreshold());
+    m_cache_ss_n_thresholded(i) = std::max(m_gmm_base_trainer.getGMMStats()->n(i), m_gmm_base_trainer.getMeanVarUpdateResponsibilitiesThreshold());
 
   // Update GMM parameters using the sufficient statistics (m_ss)
   // - Update means if requested
@@ -64,7 +64,7 @@ void bob::learn::em::ML_GMMTrainer::mStep(bob::learn::em::GMMMachine& gmm)
   if (m_gmm_base_trainer.getUpdateMeans()) {
     for(size_t i=0; i<n_gaussians; ++i) {
       blitz::Array<double,1>& means = gmm.getGaussian(i)->updateMean();
-      means = m_gmm_base_trainer.getGMMStats().sumPx(i, blitz::Range::all()) / m_cache_ss_n_thresholded(i);
+      means = m_gmm_base_trainer.getGMMStats()->sumPx(i, blitz::Range::all()) / m_cache_ss_n_thresholded(i);
     }
   }
 
@@ -77,7 +77,7 @@ void bob::learn::em::ML_GMMTrainer::mStep(bob::learn::em::GMMMachine& gmm)
     for(size_t i=0; i<n_gaussians; ++i) {
       const blitz::Array<double,1>& means = gmm.getGaussian(i)->getMean();
       blitz::Array<double,1>& variances = gmm.getGaussian(i)->updateVariance();
-      variances = m_gmm_base_trainer.getGMMStats().sumPxx(i, blitz::Range::all()) / m_cache_ss_n_thresholded(i) - blitz::pow2(means);
+      variances = m_gmm_base_trainer.getGMMStats()->sumPxx(i, blitz::Range::all()) / m_cache_ss_n_thresholded(i) - blitz::pow2(means);
       gmm.getGaussian(i)->applyVarianceThresholds();
     }
   }
diff --git a/bob/learn/em/include/bob.learn.em/GMMBaseTrainer.h b/bob/learn/em/include/bob.learn.em/GMMBaseTrainer.h
index 121cdc2..9160c6b 100644
--- a/bob/learn/em/include/bob.learn.em/GMMBaseTrainer.h
+++ b/bob/learn/em/include/bob.learn.em/GMMBaseTrainer.h
@@ -30,7 +30,7 @@ class GMMBaseTrainer
      * @brief Default constructor
      */
     GMMBaseTrainer(const bool update_means=true,
-                   const bool update_variances=false, 
+                   const bool update_variances=false,
                    const bool update_weights=false,
                    const double mean_var_update_responsibilities_threshold = std::numeric_limits<double>::epsilon());
 
@@ -93,21 +93,21 @@ class GMMBaseTrainer
      * @brief Returns the internal GMM statistics. Useful to parallelize the
      * E-step
      */
-    const bob::learn::em::GMMStats getGMMStats() const
+    boost::shared_ptr<bob::learn::em::GMMStats> getGMMStats() const
     { return m_ss; }
 
     /**
      * @brief Sets the internal GMM statistics. Useful to parallelize the
      * E-step
      */
-    void setGMMStats(const bob::learn::em::GMMStats& stats);
-    
+    void setGMMStats(boost::shared_ptr<bob::learn::em::GMMStats> stats);
+
     /**
      * update means on each iteration
-     */    
+     */
     bool getUpdateMeans()
     {return m_update_means;}
-    
+
     /**
      * update variances on each iteration
      */
@@ -117,19 +117,19 @@ class GMMBaseTrainer
 
     bool getUpdateWeights()
     {return m_update_weights;}
-    
-    
+
+
     double getMeanVarUpdateResponsibilitiesThreshold()
     {return m_mean_var_update_responsibilities_threshold;}
-    
+
 
   private:
-  
+
     /**
      * These are the sufficient statistics, calculated during the
      * E-step and used during the M-step
      */
-    bob::learn::em::GMMStats m_ss;
+    boost::shared_ptr<bob::learn::em::GMMStats> m_ss;
 
 
     /**
diff --git a/bob/learn/em/include/bob.learn.em/MAP_GMMTrainer.h b/bob/learn/em/include/bob.learn.em/MAP_GMMTrainer.h
index 7381e47..aac7087 100644
--- a/bob/learn/em/include/bob.learn.em/MAP_GMMTrainer.h
+++ b/bob/learn/em/include/bob.learn.em/MAP_GMMTrainer.h
@@ -28,11 +28,11 @@ class MAP_GMMTrainer
      */
     MAP_GMMTrainer(
       const bool update_means=true,
-      const bool update_variances=false, 
+      const bool update_variances=false,
       const bool update_weights=false,
       const double mean_var_update_responsibilities_threshold = std::numeric_limits<double>::epsilon(),
-      const bool reynolds_adaptation=false, 
-      const double relevance_factor=4, 
+      const bool reynolds_adaptation=false,
+      const double relevance_factor=4,
       const double alpha=0.5,
       boost::shared_ptr<bob::learn::em::GMMMachine> prior_gmm = boost::shared_ptr<bob::learn::em::GMMMachine>());
 
@@ -108,14 +108,14 @@ class MAP_GMMTrainer
      */
     double computeLikelihood(bob::learn::em::GMMMachine& gmm){
       return m_gmm_base_trainer.computeLikelihood(gmm);
-    }    
-    
+    }
+
     bool getReynoldsAdaptation()
     {return m_reynolds_adaptation;}
 
     void setReynoldsAdaptation(const bool reynolds_adaptation)
     {m_reynolds_adaptation = reynolds_adaptation;}
-    
+
 
     double getRelevanceFactor()
     {return m_relevance_factor;}
@@ -130,6 +130,8 @@ class MAP_GMMTrainer
     void setAlpha(const double alpha)
     {m_alpha = alpha;}
 
+    bob::learn::em::GMMBaseTrainer& base_trainer(){return m_gmm_base_trainer;}
+
 
   protected:
 
@@ -140,7 +142,7 @@ class MAP_GMMTrainer
 
     /**
     Base Trainer for the MAP algorithm. Basically implements the e-step
-    */ 
+    */
     bob::learn::em::GMMBaseTrainer m_gmm_base_trainer;
 
     /**
diff --git a/bob/learn/em/include/bob.learn.em/ML_GMMTrainer.h b/bob/learn/em/include/bob.learn.em/ML_GMMTrainer.h
index 58e8a76..40f1e70 100644
--- a/bob/learn/em/include/bob.learn.em/ML_GMMTrainer.h
+++ b/bob/learn/em/include/bob.learn.em/ML_GMMTrainer.h
@@ -28,7 +28,7 @@ class ML_GMMTrainer{
      * @brief Default constructor
      */
     ML_GMMTrainer(const bool update_means=true,
-                  const bool update_variances=false, 
+                  const bool update_variances=false,
                   const bool update_weights=false,
                   const double mean_var_update_responsibilities_threshold = std::numeric_limits<double>::epsilon());
 
@@ -97,13 +97,15 @@ class ML_GMMTrainer{
      */
     bool is_similar_to(const ML_GMMTrainer& b, const double r_epsilon=1e-5,
       const double a_epsilon=1e-8) const;
-      
-    
+
+
+    bob::learn::em::GMMBaseTrainer& base_trainer(){return m_gmm_base_trainer;}
+
   protected:
 
     /**
-    Base Trainer for the MAP algorithm. Basically implements the e-step
-    */ 
+    Base Trainer for the ML algorithm. Basically implements the e-step
+    */
     bob::learn::em::GMMBaseTrainer m_gmm_base_trainer;
 
 
diff --git a/bob/learn/em/map_gmm_trainer.cpp b/bob/learn/em/map_gmm_trainer.cpp
index a4fdfd4..37a220d 100644
--- a/bob/learn/em/map_gmm_trainer.cpp
+++ b/bob/learn/em/map_gmm_trainer.cpp
@@ -244,6 +244,30 @@ int PyBobLearnEMMAPGMMTrainer_setAlpha(PyBobLearnEMMAPGMMTrainerObject* self, Py
 }
 
 
+static auto gmm_statistics = bob::extension::VariableDoc(
+  "gmm_statistics",
+  ":py:class:`GMMStats`",
+  "The GMM statistics that were used internally in the E- and M-steps",
+  "Setting and getting the internal GMM statistics might be useful to parallelize the GMM training."
+);
+PyObject* PyBobLearnEMMAPGMMTrainer_get_gmm_statistics(PyBobLearnEMMAPGMMTrainerObject* self, void*){
+  BOB_TRY
+  PyBobLearnEMGMMStatsObject* stats = (PyBobLearnEMGMMStatsObject*)PyBobLearnEMGMMStats_Type.tp_alloc(&PyBobLearnEMGMMStats_Type, 0);
+  stats->cxx = self->cxx->base_trainer().getGMMStats();
+  return Py_BuildValue("N", stats);
+  BOB_CATCH_MEMBER("gmm_statistics could not be read", 0)
+}
+int PyBobLearnEMMAPGMMTrainer_set_gmm_statistics(PyBobLearnEMMAPGMMTrainerObject* self, PyObject* value, void*){
+  BOB_TRY
+  if (!PyBobLearnEMGMMStats_Check(value)){
+    PyErr_Format(PyExc_RuntimeError, "%s %s expects a GMMStats object", Py_TYPE(self)->tp_name, gmm_statistics.name());
+    return -1;
+  }
+  self->cxx->base_trainer().setGMMStats(reinterpret_cast<PyBobLearnEMGMMStatsObject*>(value)->cxx);
+  return 0;
+  BOB_CATCH_MEMBER("gmm_statistics could not be set", -1)
+}
+
 
 static PyGetSetDef PyBobLearnEMMAPGMMTrainer_getseters[] = {
   {
@@ -260,7 +284,13 @@ static PyGetSetDef PyBobLearnEMMAPGMMTrainer_getseters[] = {
     relevance_factor.doc(),
     0
   },
-
+  {
+    gmm_statistics.name(),
+    (getter)PyBobLearnEMMAPGMMTrainer_get_gmm_statistics,
+    (setter)PyBobLearnEMMAPGMMTrainer_set_gmm_statistics,
+    gmm_statistics.doc(),
+    0
+  },
   {0}  // Sentinel
 };
 
@@ -303,12 +333,11 @@ static PyObject* PyBobLearnEMMAPGMMTrainer_initialize(PyBobLearnEMMAPGMMTrainerO
 /*** e_step ***/
 static auto e_step = bob::extension::FunctionDoc(
   "e_step",
-  "Calculates and saves statistics across the dataset,"
-  "and saves these as m_ss. ",
+  "Calculates and saves statistics across the dataset and saves these as :py:attr`gmm_statistics`. ",
 
   "Calculates the average log likelihood of the observations given the GMM,"
   "and returns this in average_log_likelihood."
-  "The statistics, m_ss, will be used in the m_step() that follows.",
+  "The statistics, :py:attr`gmm_statistics`, will be used in the :py:meth:`m_step` that follows.",
 
   true
 )
@@ -357,9 +386,7 @@ static PyObject* PyBobLearnEMMAPGMMTrainer_e_step(PyBobLearnEMMAPGMMTrainerObjec
 static auto m_step = bob::extension::FunctionDoc(
   "m_step",
 
-   "Performs a maximum a posteriori (MAP) update of the GMM:"
-   "* parameters using the accumulated statistics in :py:class:`bob.learn.em.GMMBaseTrainer.m_ss` and the"
-   "* parameters of the prior model",
+   "Performs a maximum a posteriori (MAP) update of the GMM parameters using the accumulated statistics in :py:attr:`gmm_statistics` and the parameters of the prior model",
   "",
   true
 )
diff --git a/bob/learn/em/ml_gmm_trainer.cpp b/bob/learn/em/ml_gmm_trainer.cpp
index aa2e95d..2b30ec9 100644
--- a/bob/learn/em/ml_gmm_trainer.cpp
+++ b/bob/learn/em/ml_gmm_trainer.cpp
@@ -144,7 +144,39 @@ static PyObject* PyBobLearnEMMLGMMTrainer_RichCompare(PyBobLearnEMMLGMMTrainerOb
 /************ Variables Section ***********************************/
 /******************************************************************/
 
+static auto gmm_statistics = bob::extension::VariableDoc(
+  "gmm_statistics",
+  ":py:class:`GMMStats`",
+  "The GMM statistics that were used internally in the E- and M-steps",
+  "Setting and getting the internal GMM statistics might be useful to parallelize the GMM training."
+);
+PyObject* PyBobLearnEMMLGMMTrainer_get_gmm_statistics(PyBobLearnEMMLGMMTrainerObject* self, void*){
+  BOB_TRY
+  PyBobLearnEMGMMStatsObject* stats = (PyBobLearnEMGMMStatsObject*)PyBobLearnEMGMMStats_Type.tp_alloc(&PyBobLearnEMGMMStats_Type, 0);
+  stats->cxx = self->cxx->base_trainer().getGMMStats();
+  return Py_BuildValue("N", stats);
+  BOB_CATCH_MEMBER("gmm_statistics could not be read", 0)
+}
+int PyBobLearnEMMLGMMTrainer_set_gmm_statistics(PyBobLearnEMMLGMMTrainerObject* self, PyObject* value, void*){
+  BOB_TRY
+  if (!PyBobLearnEMGMMStats_Check(value)){
+    PyErr_Format(PyExc_RuntimeError, "%s %s expects a GMMStats object", Py_TYPE(self)->tp_name, gmm_statistics.name());
+    return -1;
+  }
+  self->cxx->base_trainer().setGMMStats(reinterpret_cast<PyBobLearnEMGMMStatsObject*>(value)->cxx);
+  return 0;
+  BOB_CATCH_MEMBER("gmm_statistics could not be set", -1)
+}
+
+
 static PyGetSetDef PyBobLearnEMMLGMMTrainer_getseters[] = {
+  {
+    gmm_statistics.name(),
+    (getter)PyBobLearnEMMLGMMTrainer_get_gmm_statistics,
+    (setter)PyBobLearnEMMLGMMTrainer_set_gmm_statistics,
+    gmm_statistics.doc(),
+    0
+  },
   {0}  // Sentinel
 };
 
@@ -190,7 +222,7 @@ static auto e_step = bob::extension::FunctionDoc(
 
   "Calculates the average log likelihood of the observations given the GMM,"
   "and returns this in average_log_likelihood."
-  "The statistics, m_ss, will be used in the :py:func:`m_step` that follows.",
+  "The statistics, :py:attr:`gmm_statistics`, will be used in the :py:func:`m_step` that follows.",
 
   true
 )
@@ -237,7 +269,7 @@ static PyObject* PyBobLearnEMMLGMMTrainer_e_step(PyBobLearnEMMLGMMTrainerObject*
 static auto m_step = bob::extension::FunctionDoc(
   "m_step",
   "Performs a maximum likelihood (ML) update of the GMM parameters "
-  "using the accumulated statistics in :py:attr:`bob.learn.em.GMMBaseTrainer.m_ss`",
+  "using the accumulated statistics in :py:attr:`gmm_statistics`",
 
   "See Section 9.2.2 of Bishop, \"Pattern recognition and machine learning\", 2006",
 
-- 
GitLab