From 1927b88be168d737ad5d94cc21be6005193ad5b0 Mon Sep 17 00:00:00 2001
From: Tiago Freitas Pereira <tiagofrepereira@gmail.com>
Date: Fri, 23 Jan 2015 16:40:21 +0100
Subject: [PATCH] Mapped ML_GMMTrainer

---
 bob/learn/misc/ML_gmm_trainer.cpp             | 304 ++++++++++++++++++
 bob/learn/misc/__ML_gmm_trainer__.py          |  81 +++++
 bob/learn/misc/cpp/ML_GMMTrainer.cpp          |  58 ++--
 bob/learn/misc/gmm_base_trainer.cpp           | 103 +++++-
 .../include/bob.learn.misc/GMMBaseTrainer.h   |  23 ++
 .../include/bob.learn.misc/ML_GMMTrainer.h    |  30 +-
 bob/learn/misc/main.cpp                       |   4 +-
 bob/learn/misc/main.h                         |  14 +
 setup.py                                      |   5 +-
 9 files changed, 571 insertions(+), 51 deletions(-)
 create mode 100644 bob/learn/misc/ML_gmm_trainer.cpp
 create mode 100644 bob/learn/misc/__ML_gmm_trainer__.py

diff --git a/bob/learn/misc/ML_gmm_trainer.cpp b/bob/learn/misc/ML_gmm_trainer.cpp
new file mode 100644
index 0000000..bddc9de
--- /dev/null
+++ b/bob/learn/misc/ML_gmm_trainer.cpp
@@ -0,0 +1,304 @@
+/**
+ * @author Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
+ * @date Wed 22 Jan 16:45:00 2015
+ *
+ * @brief Python API for bob::learn::misc
+ *
+ * Copyright (C) 2011-2014 Idiap Research Institute, Martigny, Switzerland
+ */
+
+#include "main.h"
+
+/******************************************************************/
+/************ Constructor Section *********************************/
+/******************************************************************/
+
+static inline bool f(PyObject* o){return o != 0 && PyObject_IsTrue(o) > 0;}  /* converts PyObject to bool and returns false if object is NULL */
+
+static auto ML_GMMTrainer_doc = bob::extension::ClassDoc(
+  BOB_EXT_MODULE_PREFIX ".ML_GMMTrainer",
+  "This class implements the maximum likelihood M-step of the expectation-maximisation algorithm for a GMM Machine."
+).add_constructor(
+  bob::extension::FunctionDoc(
+    "__init__",
+    "Creates a ML_GMMTrainer",
+    "",
+    true
+  )
+  .add_prototype("gmm_base_trainer","")
+  .add_prototype("other","")
+  .add_prototype("","")
+
+  .add_parameter("gmm_base_trainer", ":py:class:`bob.learn.misc.GMMBaseTrainer`", "A set GMMBaseTrainer object.")
+  .add_parameter("other", ":py:class:`bob.learn.misc.ML_GMMTrainer`", "A ML_GMMTrainer object to be copied.")
+);
+
+
+static int PyBobLearnMiscMLGMMTrainer_init_copy(PyBobLearnMiscMLGMMTrainerObject* self, PyObject* args, PyObject* kwargs) {
+
+  char** kwlist = ML_GMMTrainer_doc.kwlist(1);
+  PyBobLearnMiscMLGMMTrainerObject* o;
+  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!", kwlist, &PyBobLearnMiscMLGMMTrainer_Type, &o)){
+    ML_GMMTrainer_doc.print_usage();
+    return -1;
+  }
+
+  self->cxx.reset(new bob::learn::misc::ML_GMMTrainer(*o->cxx));
+  return 0;
+}
+
+
+static int PyBobLearnMiscMLGMMTrainer_init_base_trainer(PyBobLearnMiscMLGMMTrainerObject* self, PyObject* args, PyObject* kwargs) {
+
+  char** kwlist = ML_GMMTrainer_doc.kwlist(0);
+  PyBobLearnMiscGMMBaseTrainerObject* o;
+  
+  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!", kwlist, &PyBobLearnMiscGMMBaseTrainer_Type, &o)){
+    ML_GMMTrainer_doc.print_usage();
+    return -1;
+  }
+
+  self->cxx.reset(new bob::learn::misc::ML_GMMTrainer(o->cxx));
+  return 0;
+}
+
+
+
+static int PyBobLearnMiscMLGMMTrainer_init(PyBobLearnMiscMLGMMTrainerObject* self, PyObject* args, PyObject* kwargs) {
+  BOB_TRY
+
+  int nargs = (args?PyTuple_Size(args):0) + (kwargs?PyDict_Size(kwargs):0);
+
+  if (nargs==1){ // one argument: either a GMMBaseTrainer or an ML_GMMTrainer to copy
+    //Reading the input argument
+    PyObject* arg = 0;
+    if (PyTuple_Size(args))
+      arg = PyTuple_GET_ITEM(args, 0);
+    else {
+      PyObject* tmp = PyDict_Values(kwargs);
+      auto tmp_ = make_safe(tmp);
+      arg = PyList_GET_ITEM(tmp, 0);
+    }
+
+    // If the constructor input is a GMMBaseTrainer object
+    if (PyBobLearnMiscGMMBaseTrainer_Check(arg))
+      return PyBobLearnMiscMLGMMTrainer_init_base_trainer(self, args, kwargs);
+    else
+      return PyBobLearnMiscMLGMMTrainer_init_copy(self, args, kwargs);
+  }
+  else{
+    PyErr_Format(PyExc_RuntimeError, "number of arguments mismatch - %s requires 1, but you provided %d (see help)", Py_TYPE(self)->tp_name, nargs);
+    ML_GMMTrainer_doc.print_usage();
+    return -1;  
+  }
+
+  BOB_CATCH_MEMBER("cannot create GMMBaseTrainer_init_bool", 0)
+  return 0;
+}
+
+
+static void PyBobLearnMiscMLGMMTrainer_delete(PyBobLearnMiscMLGMMTrainerObject* self) {
+  self->cxx.reset();
+  Py_TYPE(self)->tp_free((PyObject*)self);
+}
+
+
+int PyBobLearnMiscMLGMMTrainer_Check(PyObject* o) {
+  return PyObject_IsInstance(o, reinterpret_cast<PyObject*>(&PyBobLearnMiscMLGMMTrainer_Type));
+}
+
+
+static PyObject* PyBobLearnMiscMLGMMTrainer_RichCompare(PyBobLearnMiscMLGMMTrainerObject* self, PyObject* other, int op) {
+  BOB_TRY
+
+  if (!PyBobLearnMiscMLGMMTrainer_Check(other)) {
+    PyErr_Format(PyExc_TypeError, "cannot compare `%s' with `%s'", Py_TYPE(self)->tp_name, Py_TYPE(other)->tp_name);
+    return 0;
+  }
+  auto other_ = reinterpret_cast<PyBobLearnMiscMLGMMTrainerObject*>(other);
+  switch (op) {
+    case Py_EQ:
+      if (*self->cxx==*other_->cxx) Py_RETURN_TRUE; else Py_RETURN_FALSE;
+    case Py_NE:
+      if (*self->cxx==*other_->cxx) Py_RETURN_FALSE; else Py_RETURN_TRUE;
+    default:
+      Py_INCREF(Py_NotImplemented);
+      return Py_NotImplemented;
+  }
+  BOB_CATCH_MEMBER("cannot compare ML_GMMTrainer objects", 0)
+}
+
+
+/******************************************************************/
+/************ Variables Section ***********************************/
+/******************************************************************/
+
+
+/***** gmm_base_trainer *****/
+static auto gmm_base_trainer = bob::extension::VariableDoc(
+  "gmm_base_trainer",
+  ":py:class:`bob.learn.misc.GMMBaseTrainer`",
+  "This class that implements the E-step of the expectation-maximisation algorithm.",
+  ""
+);
+PyObject* PyBobLearnMiscMLGMMTrainer_getGMMBaseTrainer(PyBobLearnMiscMLGMMTrainerObject* self, void*){
+  BOB_TRY
+  
+  boost::shared_ptr<bob::learn::misc::GMMBaseTrainer> gmm_base_trainer = self->cxx->getGMMBaseTrainer();
+
+  //Allocating the corresponding Python object
+  PyBobLearnMiscGMMBaseTrainerObject* retval =
+    (PyBobLearnMiscGMMBaseTrainerObject*)PyBobLearnMiscGMMBaseTrainer_Type.tp_alloc(&PyBobLearnMiscGMMBaseTrainer_Type, 0);
+
+  retval->cxx = gmm_base_trainer;
+
+  return Py_BuildValue("O",retval);
+  BOB_CATCH_MEMBER("GMMBaseTrainer could not be read", 0)
+}
+int PyBobLearnMiscMLGMMTrainer_setGMMBaseTrainer(PyBobLearnMiscMLGMMTrainerObject* self, PyObject* value, void*){
+  BOB_TRY
+
+  if (!PyBobLearnMiscGMMBaseTrainer_Check(value)){
+    PyErr_Format(PyExc_RuntimeError, "%s %s expects a :py:class:`bob.learn.misc.GMMBaseTrainer`", Py_TYPE(self)->tp_name, gmm_base_trainer.name());
+    return -1;
+  }
+
+  PyBobLearnMiscGMMBaseTrainerObject* gmm_base_trainer = 0;
+  PyArg_Parse(value, "O!", &PyBobLearnMiscGMMBaseTrainer_Type,&gmm_base_trainer);
+
+  self->cxx->setGMMBaseTrainer(gmm_base_trainer->cxx);
+
+  return 0;
+  BOB_CATCH_MEMBER("gmm_base_trainer could not be set", -1)  
+}
+
+
+static PyGetSetDef PyBobLearnMiscMLGMMTrainer_getseters[] = { 
+  {
+    gmm_base_trainer.name(),
+    (getter)PyBobLearnMiscMLGMMTrainer_getGMMBaseTrainer,
+    (setter)PyBobLearnMiscMLGMMTrainer_setGMMBaseTrainer,
+    gmm_base_trainer.doc(),
+    0
+  },
+  {0}  // Sentinel
+};
+
+
+/******************************************************************/
+/************ Functions Section ***********************************/
+/******************************************************************/
+
+/*** initialize ***/
+static auto initialize = bob::extension::FunctionDoc(
+  "initialize",
+  "Initialization before the EM steps",
+  "",
+  true
+)
+.add_prototype("gmm_machine")
+.add_parameter("gmm_machine", ":py:class:`bob.learn.misc.GMMMachine`", "GMMMachine Object");
+static PyObject* PyBobLearnMiscMLGMMTrainer_initialize(PyBobLearnMiscMLGMMTrainerObject* self, PyObject* args, PyObject* kwargs) {
+  BOB_TRY
+
+  /* Parses input arguments in a single shot */
+  char** kwlist = initialize.kwlist(0);
+
+  PyBobLearnMiscGMMMachineObject* gmm_machine = 0;
+
+  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!", kwlist, &PyBobLearnMiscGMMMachine_Type, &gmm_machine)){
+    PyErr_Format(PyExc_RuntimeError, "%s.%s could not read the :py:class:`bob.learn.misc.GMMMachine` argument", Py_TYPE(self)->tp_name, initialize.name());
+    return 0;
+  }
+  self->cxx->initialize(*gmm_machine->cxx);
+  BOB_CATCH_MEMBER("cannot perform the initialize method", 0)
+
+  Py_RETURN_NONE;
+}
+
+
+
+/*** mStep ***/
+static auto mStep = bob::extension::FunctionDoc(
+  "mStep",
+  "Performs a maximum likelihood (ML) update of the GMM parameters"
+  "using the accumulated statistics in :py:class:`bob.learn.misc.GMMBaseTrainer.m_ss`",
+
+  "See Section 9.2.2 of Bishop, \"Pattern recognition and machine learning\", 2006",
+
+  true
+)
+.add_prototype("gmm_machine")
+.add_parameter("gmm_machine", ":py:class:`bob.learn.misc.GMMMachine`", "GMMMachine Object");
+static PyObject* PyBobLearnMiscMLGMMTrainer_mStep(PyBobLearnMiscMLGMMTrainerObject* self, PyObject* args, PyObject* kwargs) {
+  BOB_TRY
+
+  /* Parses input arguments in a single shot */
+  char** kwlist = mStep.kwlist(0);
+
+  PyBobLearnMiscGMMMachineObject* gmm_machine = 0;
+
+  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!", kwlist, &PyBobLearnMiscGMMMachine_Type, &gmm_machine)) return 0;
+
+  self->cxx->mStep(*gmm_machine->cxx);
+
+  BOB_CATCH_MEMBER("cannot perform the mStep method", 0)
+
+  Py_RETURN_NONE;
+}
+
+
+
+static PyMethodDef PyBobLearnMiscMLGMMTrainer_methods[] = {
+  {
+    initialize.name(),
+    (PyCFunction)PyBobLearnMiscMLGMMTrainer_initialize,
+    METH_VARARGS|METH_KEYWORDS,
+    initialize.doc()
+  },
+  {
+    mStep.name(),
+    (PyCFunction)PyBobLearnMiscMLGMMTrainer_mStep,
+    METH_VARARGS|METH_KEYWORDS,
+    mStep.doc()
+  },
+  {0} /* Sentinel */
+};
+
+
+/******************************************************************/
+/************ Module Section **************************************/
+/******************************************************************/
+
+// Define the ML_GMMTrainer type struct; it will be initialized later
+PyTypeObject PyBobLearnMiscMLGMMTrainer_Type = {
+  PyVarObject_HEAD_INIT(0,0)
+  0
+};
+
+bool init_BobLearnMiscMLGMMTrainer(PyObject* module)
+{
+  // initialize the type struct
+  PyBobLearnMiscMLGMMTrainer_Type.tp_name      = ML_GMMTrainer_doc.name();
+  PyBobLearnMiscMLGMMTrainer_Type.tp_basicsize = sizeof(PyBobLearnMiscMLGMMTrainerObject);
+  PyBobLearnMiscMLGMMTrainer_Type.tp_flags     = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;//Enable the class inheritance
+  PyBobLearnMiscMLGMMTrainer_Type.tp_doc       = ML_GMMTrainer_doc.doc();
+
+  // set the functions
+  PyBobLearnMiscMLGMMTrainer_Type.tp_new          = PyType_GenericNew;
+  PyBobLearnMiscMLGMMTrainer_Type.tp_init         = reinterpret_cast<initproc>(PyBobLearnMiscMLGMMTrainer_init);
+  PyBobLearnMiscMLGMMTrainer_Type.tp_dealloc      = reinterpret_cast<destructor>(PyBobLearnMiscMLGMMTrainer_delete);
+  PyBobLearnMiscMLGMMTrainer_Type.tp_richcompare  = reinterpret_cast<richcmpfunc>(PyBobLearnMiscMLGMMTrainer_RichCompare);
+  PyBobLearnMiscMLGMMTrainer_Type.tp_methods      = PyBobLearnMiscMLGMMTrainer_methods;
+  PyBobLearnMiscMLGMMTrainer_Type.tp_getset       = PyBobLearnMiscMLGMMTrainer_getseters;
+  //PyBobLearnMiscMLGMMTrainer_Type.tp_call         = reinterpret_cast<ternaryfunc>(PyBobLearnMiscMLGMMTrainer_compute_likelihood);
+
+
+  // check that everything is fine
+  if (PyType_Ready(&PyBobLearnMiscMLGMMTrainer_Type) < 0) return false;
+
+  // add the type to the module
+  Py_INCREF(&PyBobLearnMiscMLGMMTrainer_Type);
+  return PyModule_AddObject(module, "_ML_GMMTrainer", (PyObject*)&PyBobLearnMiscMLGMMTrainer_Type) >= 0;
+}
+
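
Note: the binding above accepts exactly one constructor argument, dispatched in
PyBobLearnMiscMLGMMTrainer_init between the copy form and the GMMBaseTrainer form.
A minimal sketch of both forms from Python follows; the _ML_GMMTrainer import matches
the wrapper added below, while the GMMBaseTrainer import path is an assumption (its
binding is not part of this file) and its constructor signature is the one parsed in
gmm_base_trainer.cpp:

    from bob.learn.misc._library import _ML_GMMTrainer
    from bob.learn.misc import GMMBaseTrainer          # assumed package-level name

    base = GMMBaseTrainer(True, False, False)          # E-step accumulator, update means only
    trainer = _ML_GMMTrainer(base)                     # construct from a GMMBaseTrainer
    clone = _ML_GMMTrainer(trainer)                    # copy constructor
    assert clone == trainer                            # rich comparison defined above
    print(trainer.gmm_base_trainer.update_means)       # getters exposed by this patch
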
diff --git a/bob/learn/misc/__ML_gmm_trainer__.py b/bob/learn/misc/__ML_gmm_trainer__.py
new file mode 100644
index 0000000..53c5a75
--- /dev/null
+++ b/bob/learn/misc/__ML_gmm_trainer__.py
@@ -0,0 +1,81 @@
+#!/usr/bin/env python
+# vim: set fileencoding=utf-8 :
+# Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
+# Mon Jan 22 18:29:10 2015
+#
+# Copyright (C) 2011-2015 Idiap Research Institute, Martigny, Switzerland
+
+from ._library import _ML_GMMTrainer
+import numpy
+
+# define the class
+class ML_GMMTrainer(_ML_GMMTrainer):
+
+  def __init__(self, gmm_base_trainer, convergence_threshold=0.001, max_iterations=10, converge_by_likelihood=True):
+    """
+    :py:class:`bob.learn.misc.ML_GMMTrainer` constructor
+
+    Keyword Parameters:
+      gmm_base_trainer
+        The base trainer (:py:class:`bob.learn.misc.GMMBaseTrainer`)
+      convergence_threshold
+        Convergence threshold
+      max_iterations
+        Number of maximum iterations
+      converge_by_likelihood
+        Tells whether the log likelihood is used as the convergence criterion
+
+    """
+
+    _ML_GMMTrainer.__init__(self, gmm_base_trainer)
+    self.convergence_threshold  = convergence_threshold
+    self.max_iterations         = max_iterations
+    self.converge_by_likelihood = converge_by_likelihood
+
+
+  def train(self, gmm_machine, data):
+    """
+    Train the :py:class:`bob.learn.misc.GMMMachine` using the provided data
+
+    Keyword Parameters:
+      gmm_machine
+        The :py:class:`bob.learn.misc.GMMMachine` to be trained
+      data
+        A 2D numpy array with the training data
+    """
+
+    #Initialization
+    self.initialize(gmm_machine)
+
+    #Do the Expectation-Maximization algorithm
+    average_output_previous = 0
+    average_output = -numpy.inf
+
+
+    #eStep
+    self.gmm_base_trainer.eStep(gmm_machine, data)
+
+    if self.converge_by_likelihood:
+      average_output = self.gmm_base_trainer.compute_likelihood(gmm_machine)
+
+    for i in range(self.max_iterations):
+      #saves the average output from the last iteration
+      average_output_previous = average_output
+
+      #mStep
+      self.mStep(gmm_machine)
+
+      #eStep
+      self.gmm_base_trainer.eStep(gmm_machine, data)
+
+      #Computes the log likelihood if required
+      if self.converge_by_likelihood:
+        average_output = self.gmm_base_trainer.compute_likelihood(gmm_machine)
+
+        #Terminates if converged (and likelihood computation is enabled)
+        if abs((average_output_previous - average_output)/average_output_previous) <= self.convergence_threshold:
+          break
+
+
+# copy the documentation from the base class
+__doc__ = _ML_GMMTrainer.__doc__
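
A minimal end-to-end sketch of the train() loop above. The package-level names
(ML_GMMTrainer, GMMBaseTrainer, GMMMachine), the GMMMachine constructor arguments and
its weights attribute are assumptions taken from the wider library, not introduced by
this patch:

    import numpy
    import bob.learn.misc

    data = numpy.random.randn(100, 2)                   # toy 2-D training data (hypothetical)
    machine = bob.learn.misc.GMMMachine(3, 2)           # 3 Gaussians over 2-D features (assumed signature)

    base = bob.learn.misc.GMMBaseTrainer(True, True, True)  # update means, variances and weights
    trainer = bob.learn.misc.ML_GMMTrainer(base, convergence_threshold=1e-3, max_iterations=10)
    trainer.train(machine, data)

    print(machine.weights)                              # ML estimates after EM (assumed attribute)
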
diff --git a/bob/learn/misc/cpp/ML_GMMTrainer.cpp b/bob/learn/misc/cpp/ML_GMMTrainer.cpp
index ec949e0..84ecc5c 100644
--- a/bob/learn/misc/cpp/ML_GMMTrainer.cpp
+++ b/bob/learn/misc/cpp/ML_GMMTrainer.cpp
@@ -8,59 +8,54 @@
 #include <bob.learn.misc/ML_GMMTrainer.h>
 #include <algorithm>
 
-bob::learn::misc::ML_GMMTrainer::ML_GMMTrainer(const bool update_means,
-    const bool update_variances, const bool update_weights,
-    const double mean_var_update_responsibilities_threshold):
-  bob::learn::misc::GMMTrainer(update_means, update_variances, update_weights,
-    mean_var_update_responsibilities_threshold)
-{
-}
+bob::learn::misc::ML_GMMTrainer::ML_GMMTrainer(boost::shared_ptr<bob::learn::misc::GMMBaseTrainer> gmm_base_trainer):
+  m_gmm_base_trainer(gmm_base_trainer)
+{}
+
+
 
 bob::learn::misc::ML_GMMTrainer::ML_GMMTrainer(const bob::learn::misc::ML_GMMTrainer& b):
-  bob::learn::misc::GMMTrainer(b)
-{
-}
+  m_gmm_base_trainer(b.m_gmm_base_trainer)
+{}
 
 bob::learn::misc::ML_GMMTrainer::~ML_GMMTrainer()
-{
-}
+{}
 
-void bob::learn::misc::ML_GMMTrainer::initialize(bob::learn::misc::GMMMachine& gmm,
-  const blitz::Array<double,2>& data)
+void bob::learn::misc::ML_GMMTrainer::initialize(bob::learn::misc::GMMMachine& gmm)
 {
-  bob::learn::misc::GMMTrainer::initialize(gmm, data);
+  m_gmm_base_trainer->initialize(gmm);
+  
   // Allocate cache
   size_t n_gaussians = gmm.getNGaussians();
   m_cache_ss_n_thresholded.resize(n_gaussians);
 }
 
 
-void bob::learn::misc::ML_GMMTrainer::mStep(bob::learn::misc::GMMMachine& gmm,
-  const blitz::Array<double,2>& data)
+void bob::learn::misc::ML_GMMTrainer::mStep(bob::learn::misc::GMMMachine& gmm)
 {
   // Read options and variables
   const size_t n_gaussians = gmm.getNGaussians();
 
   // - Update weights if requested
   //   Equation 9.26 of Bishop, "Pattern recognition and machine learning", 2006
-  if (m_update_weights) {
+  if (m_gmm_base_trainer->getUpdateWeights()) {
     blitz::Array<double,1>& weights = gmm.updateWeights();
-    weights = m_ss.n / static_cast<double>(m_ss.T); //cast req. for linux/32-bits & osx
+    weights = m_gmm_base_trainer->getGMMStats().n / static_cast<double>(m_gmm_base_trainer->getGMMStats().T); //cast req. for linux/32-bits & osx
     // Recompute the log weights in the cache of the GMMMachine
     gmm.recomputeLogWeights();
   }
 
   // Generate a thresholded version of m_ss.n
   for(size_t i=0; i<n_gaussians; ++i)
-    m_cache_ss_n_thresholded(i) = std::max(m_ss.n(i), m_mean_var_update_responsibilities_threshold);
+    m_cache_ss_n_thresholded(i) = std::max(m_gmm_base_trainer->getGMMStats().n(i), m_gmm_base_trainer->getMeanVarUpdateResponsibilitiesThreshold());
 
   // Update GMM parameters using the sufficient statistics (m_ss)
   // - Update means if requested
   //   Equation 9.24 of Bishop, "Pattern recognition and machine learning", 2006
-  if (m_update_means) {
+  if (m_gmm_base_trainer->getUpdateMeans()) {
     for(size_t i=0; i<n_gaussians; ++i) {
-      blitz::Array<double,1>& means = gmm.updateGaussian(i)->updateMean();
-      means = m_ss.sumPx(i, blitz::Range::all()) / m_cache_ss_n_thresholded(i);
+      blitz::Array<double,1>& means = gmm.getGaussian(i)->updateMean();
+      means = m_gmm_base_trainer->getGMMStats().sumPx(i, blitz::Range::all()) / m_cache_ss_n_thresholded(i);
     }
   }
 
@@ -69,12 +64,12 @@ void bob::learn::misc::ML_GMMTrainer::mStep(bob::learn::misc::GMMMachine& gmm,
   //   ...but we use the "computational formula for the variance", i.e.
   //   var = 1/n * sum (P(x-mean)(x-mean))
   //       = 1/n * sum (Pxx) - mean^2
-  if (m_update_variances) {
+  if (m_gmm_base_trainer->getUpdateVariances()) {
     for(size_t i=0; i<n_gaussians; ++i) {
       const blitz::Array<double,1>& means = gmm.getGaussian(i)->getMean();
-      blitz::Array<double,1>& variances = gmm.updateGaussian(i)->updateVariance();
-      variances = m_ss.sumPxx(i, blitz::Range::all()) / m_cache_ss_n_thresholded(i) - blitz::pow2(means);
-      gmm.updateGaussian(i)->applyVarianceThresholds();
+      blitz::Array<double,1>& variances = gmm.getGaussian(i)->updateVariance();
+      variances = m_gmm_base_trainer->getGMMStats().sumPxx(i, blitz::Range::all()) / m_cache_ss_n_thresholded(i) - blitz::pow2(means);
+      gmm.getGaussian(i)->applyVarianceThresholds();
     }
   }
 }
@@ -84,7 +79,7 @@ bob::learn::misc::ML_GMMTrainer& bob::learn::misc::ML_GMMTrainer::operator=
 {
   if (this != &other)
   {
-    bob::learn::misc::GMMTrainer::operator=(other);
+    m_gmm_base_trainer = other.m_gmm_base_trainer;
     m_cache_ss_n_thresholded.resize(other.m_cache_ss_n_thresholded.extent(0));
   }
   return *this;
@@ -93,7 +88,7 @@ bob::learn::misc::ML_GMMTrainer& bob::learn::misc::ML_GMMTrainer::operator=
 bool bob::learn::misc::ML_GMMTrainer::operator==
   (const bob::learn::misc::ML_GMMTrainer &other) const
 {
-  return bob::learn::misc::GMMTrainer::operator==(other);
+  return m_gmm_base_trainer == other.m_gmm_base_trainer;
 }
 
 bool bob::learn::misc::ML_GMMTrainer::operator!=
@@ -102,10 +97,11 @@ bool bob::learn::misc::ML_GMMTrainer::operator!=
   return !(this->operator==(other));
 }
 
+/*
 bool bob::learn::misc::ML_GMMTrainer::is_similar_to
   (const bob::learn::misc::ML_GMMTrainer &other, const double r_epsilon,
    const double a_epsilon) const
 {
-  return bob::learn::misc::GMMTrainer::is_similar_to(other, r_epsilon, a_epsilon);
+  return m_gmm_base_trainer.is_similar_to(other, r_epsilon, a_epsilon);
 }
-
+*/
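
For reference, the update rules implemented by mStep above, written in terms of the
accumulated statistics (N_k = m_ss.n(k), T = m_ss.T, and the first/second order sums
sumPx and sumPxx); these are the standard ML results of Bishop 2006, Eqs. 9.24 and
9.26, with the computational formula used for the variance:

    \pi_k = \frac{N_k}{T}, \qquad
    \mu_k = \frac{\mathrm{sumPx}_k}{N_k}, \qquad
    \sigma_k^2 = \frac{\mathrm{sumPxx}_k}{N_k} - \mu_k^2

where each N_k is first clamped from below to mean_var_update_responsibilities_threshold
before the divisions, and applyVarianceThresholds() clips the resulting variances.
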
diff --git a/bob/learn/misc/gmm_base_trainer.cpp b/bob/learn/misc/gmm_base_trainer.cpp
index 5f133af..0308f16 100644
--- a/bob/learn/misc/gmm_base_trainer.cpp
+++ b/bob/learn/misc/gmm_base_trainer.cpp
@@ -8,6 +8,7 @@
  */
 
 #include "main.h"
+#include <boost/make_shared.hpp>
 
 /******************************************************************/
 /************ Constructor Section *********************************/
@@ -59,7 +60,7 @@ static int PyBobLearnMiscGMMBaseTrainer_init_bool(PyBobLearnMiscGMMBaseTrainerOb
   PyObject* update_means     = 0;
   PyObject* update_variances = 0;
   PyObject* update_weights   = 0;
-  double mean_var_update_responsibilities_threshold = 0; //std::numeric_limits<double>::epsilon();
+  double mean_var_update_responsibilities_threshold = std::numeric_limits<double>::epsilon();
 
   if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!|O!O!d", kwlist, &PyBool_Type, &update_means, &PyBool_Type, 
                                                              &update_variances, &PyBool_Type, &update_weights, &mean_var_update_responsibilities_threshold)){
@@ -150,10 +151,8 @@ static auto gmm_stats = bob::extension::VariableDoc(
 PyObject* PyBobLearnMiscGMMBaseTrainer_getGMMStats(PyBobLearnMiscGMMBaseTrainerObject* self, void*){
   BOB_TRY
 
-  //bob::learn::misc::GMMStats stats = self->cxx->getGMMStats();
-  //boost::shared_ptr<bob::learn::misc::GMMStats> stats_shared = boost::make_shared(stats);
-  boost::shared_ptr<bob::learn::misc::GMMStats> stats_shared = 0;
-
+  bob::learn::misc::GMMStats stats = self->cxx->getGMMStats();
+  boost::shared_ptr<bob::learn::misc::GMMStats> stats_shared = boost::make_shared<bob::learn::misc::GMMStats>(stats);
 
   //Allocating the correspondent python object
   PyBobLearnMiscGMMStatsObject* retval =
@@ -164,6 +163,7 @@ PyObject* PyBobLearnMiscGMMBaseTrainer_getGMMStats(PyBobLearnMiscGMMBaseTrainerO
   return Py_BuildValue("O",retval);
   BOB_CATCH_MEMBER("GMMStats could not be read", 0)
 }
+/*
 int PyBobLearnMiscGMMBaseTrainer_setGMMStats(PyBobLearnMiscGMMBaseTrainerObject* self, PyObject* value, void*){
   BOB_TRY
 
@@ -180,17 +180,106 @@ int PyBobLearnMiscGMMBaseTrainer_setGMMStats(PyBobLearnMiscGMMBaseTrainerObject*
   return 0;
   BOB_CATCH_MEMBER("gmm_stats could not be set", -1)  
 }
+*/
+
+
+/***** update_means *****/
+static auto update_means = bob::extension::VariableDoc(
+  "update_means",
+  "bool",
+  "Update means on each iteration",
+  ""
+);
+PyObject* PyBobLearnMiscGMMBaseTrainer_getUpdateMeans(PyBobLearnMiscGMMBaseTrainerObject* self, void*){
+  BOB_TRY
+  return Py_BuildValue("O",self->cxx->getUpdateMeans()?Py_True:Py_False);
+  BOB_CATCH_MEMBER("update_means could not be read", 0)
+}
+
+/***** update_variances *****/
+static auto update_variances = bob::extension::VariableDoc(
+  "update_variances",
+  "bool",
+  "Update variances on each iteration",
+  ""
+);
+PyObject* PyBobLearnMiscGMMBaseTrainer_getUpdateVariances(PyBobLearnMiscGMMBaseTrainerObject* self, void*){
+  BOB_TRY
+  return Py_BuildValue("O",self->cxx->getUpdateVariances()?Py_True:Py_False);
+  BOB_CATCH_MEMBER("update_variances could not be read", 0)
+}
+
+
+/***** update_weights *****/
+static auto update_weights = bob::extension::VariableDoc(
+  "update_weights",
+  "bool",
+  "Update weights on each iteration",
+  ""
+);
+PyObject* PyBobLearnMiscGMMBaseTrainer_getUpdateWeights(PyBobLearnMiscGMMBaseTrainerObject* self, void*){
+  BOB_TRY
+  return Py_BuildValue("O",self->cxx->getUpdateWeights()?Py_True:Py_False);
+  BOB_CATCH_MEMBER("update_weights could not be read", 0)
+}
 
 
+
+
+
+/***** mean_var_update_responsibilities_threshold *****/
+static auto mean_var_update_responsibilities_threshold = bob::extension::VariableDoc(
+  "mean_var_update_responsibilities_threshold",
+  "bool",
+  "Threshold over the responsibilities of the Gaussians" 
+  "Equations 9.24, 9.25 of Bishop, \"Pattern recognition and machine learning\", 2006" 
+  "require a division by the responsibilities, which might be equal to zero" 
+  "because of numerical issue. This threshold is used to avoid such divisions.",
+  ""
+);
+PyObject* PyBobLearnMiscGMMBaseTrainer_getMeanVarUpdateResponsibilitiesThreshold(PyBobLearnMiscGMMBaseTrainerObject* self, void*){
+  BOB_TRY
+  return Py_BuildValue("d",self->cxx->getMeanVarUpdateResponsibilitiesThreshold());
+  BOB_CATCH_MEMBER("update_weights could not be read", 0)
+}
+
 
 static PyGetSetDef PyBobLearnMiscGMMBaseTrainer_getseters[] = { 
+  {
+    update_means.name(),
+    (getter)PyBobLearnMiscGMMBaseTrainer_getUpdateMeans,
+    0,
+    update_means.doc(),
+    0
+  },
+  {
+    update_variances.name(),
+    (getter)PyBobLearnMiscGMMBaseTrainer_getUpdateVariances,
+    0,
+    update_variances.doc(),
+    0
+  },
+  {
+    update_weights.name(),
+    (getter)PyBobLearnMiscGMMBaseTrainer_getUpdateWeights,
+    0,
+    update_weights.doc(),
+    0
+  },  
+  {
+    mean_var_update_responsibilities_threshold.name(),
+    (getter)PyBobLearnMiscGMMBaseTrainer_getMeanVarUpdateResponsibilitiesThreshold,
+    0,
+    mean_var_update_responsibilities_threshold.doc(),
+    0
+  },  
   {
     gmm_stats.name(),
     (getter)PyBobLearnMiscGMMBaseTrainer_getGMMStats,
-    (setter)PyBobLearnMiscGMMBaseTrainer_setGMMStats,
+    0, //(setter)PyBobLearnMiscGMMBaseTrainer_setGMMStats,
     gmm_stats.doc(),
     0
-  },
+  },  
   {0}  // Sentinel
 };
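
The read-only getters registered above become plain attributes on the Python side; a
short sketch (the package-level GMMBaseTrainer name is assumed):

    import bob.learn.misc

    base = bob.learn.misc.GMMBaseTrainer(True, False, False)
    print(base.update_means)                                  # True
    print(base.update_variances)                              # False
    print(base.update_weights)                                # False
    print(base.mean_var_update_responsibilities_threshold)    # machine epsilon by default
    # base.update_means = False   # would raise AttributeError: no setter is registered
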
 
diff --git a/bob/learn/misc/include/bob.learn.misc/GMMBaseTrainer.h b/bob/learn/misc/include/bob.learn.misc/GMMBaseTrainer.h
index 1cb0788..086fdc2 100644
--- a/bob/learn/misc/include/bob.learn.misc/GMMBaseTrainer.h
+++ b/bob/learn/misc/include/bob.learn.misc/GMMBaseTrainer.h
@@ -101,14 +101,37 @@ class GMMBaseTrainer
      * E-step
      */
     void setGMMStats(const bob::learn::misc::GMMStats& stats);
+    
+    /**
+     * update means on each iteration
+     */    
+    bool getUpdateMeans()
+    {return m_update_means;}
+    
+    /**
+     * update variances on each iteration
+     */
+    bool getUpdateVariances()
+    {return m_update_variances;}
+
+
+    bool getUpdateWeights()
+    {return m_update_weights;}
+    
+    
+    double getMeanVarUpdateResponsibilitiesThreshold()
+    {return m_mean_var_update_responsibilities_threshold;}
+    
 
   private:
+  
     /**
      * These are the sufficient statistics, calculated during the
      * E-step and used during the M-step
      */
     bob::learn::misc::GMMStats m_ss;
 
+
     /**
      * update means on each iteration
      */
diff --git a/bob/learn/misc/include/bob.learn.misc/ML_GMMTrainer.h b/bob/learn/misc/include/bob.learn.misc/ML_GMMTrainer.h
index 0522f27..09b3db1 100644
--- a/bob/learn/misc/include/bob.learn.misc/ML_GMMTrainer.h
+++ b/bob/learn/misc/include/bob.learn.misc/ML_GMMTrainer.h
@@ -11,7 +11,7 @@
 #ifndef BOB_LEARN_MISC_ML_GMMTRAINER_H
 #define BOB_LEARN_MISC_ML_GMMTRAINER_H
 
-#include <bob.learn.misc/GMMTrainer.h>
+#include <bob.learn.misc/GMMBaseTrainer.h>
 #include <limits>
 
 namespace bob { namespace learn { namespace misc {
@@ -22,15 +22,12 @@ namespace bob { namespace learn { namespace misc {
  * @details See Section 9.2.2 of Bishop,
  *  "Pattern recognition and machine learning", 2006
  */
-class ML_GMMTrainer: public GMMTrainer {
+class ML_GMMTrainer{
   public:
     /**
      * @brief Default constructor
      */
-    ML_GMMTrainer(const bool update_means=true,
-      const bool update_variances=false, const bool update_weights=false,
-      const double mean_var_update_responsibilities_threshold =
-        std::numeric_limits<double>::epsilon());
+    ML_GMMTrainer(boost::shared_ptr<bob::learn::misc::GMMBaseTrainer> gmm_base_trainer);
 
     /**
      * @brief Copy constructor
@@ -45,16 +42,14 @@ class ML_GMMTrainer: public GMMTrainer {
     /**
      * @brief Initialisation before the EM steps
      */
-    virtual void initialize(bob::learn::misc::GMMMachine& gmm,
-      const blitz::Array<double,2>& data);
+    virtual void initialize(bob::learn::misc::GMMMachine& gmm);
 
     /**
      * @brief Performs a maximum likelihood (ML) update of the GMM parameters
      * using the accumulated statistics in m_ss
      * Implements EMTrainer::mStep()
      */
-    virtual void mStep(bob::learn::misc::GMMMachine& gmm,
-      const blitz::Array<double,2>& data);
+    virtual void mStep(bob::learn::misc::GMMMachine& gmm);
 
     /**
      * @brief Assigns from a different ML_GMMTrainer
@@ -76,6 +71,21 @@ class ML_GMMTrainer: public GMMTrainer {
      */
     bool is_similar_to(const ML_GMMTrainer& b, const double r_epsilon=1e-5,
       const double a_epsilon=1e-8) const;
+      
+    
+    boost::shared_ptr<bob::learn::misc::GMMBaseTrainer> getGMMBaseTrainer()
+    {return m_gmm_base_trainer;}
+    
+    void setGMMBaseTrainer(boost::shared_ptr<bob::learn::misc::GMMBaseTrainer> gmm_base_trainer)
+    {m_gmm_base_trainer = gmm_base_trainer;}
+    
+
+  protected:
+
+    /**
+    Base trainer, responsible for the E-step of the EM algorithm
+    */ 
+    boost::shared_ptr<bob::learn::misc::GMMBaseTrainer> m_gmm_base_trainer;
 
 
   private:
diff --git a/bob/learn/misc/main.cpp b/bob/learn/misc/main.cpp
index 0431832..404ff64 100644
--- a/bob/learn/misc/main.cpp
+++ b/bob/learn/misc/main.cpp
@@ -45,7 +45,9 @@ static PyObject* create_module (void) {
   if (!init_BobLearnMiscGMMMachine(module)) return 0;
   if (!init_BobLearnMiscKMeansMachine(module)) return 0;
   if (!init_BobLearnMiscKMeansTrainer(module)) return 0;
-  if (!init_BobLearnMiscGMMBaseTrainer(module)) return 0;  
+  if (!init_BobLearnMiscGMMBaseTrainer(module)) return 0;
+  if (!init_BobLearnMiscMLGMMTrainer(module)) return 0;
+
 
 
   static void* PyBobLearnMisc_API[PyBobLearnMisc_API_pointers];
diff --git a/bob/learn/misc/main.h b/bob/learn/misc/main.h
index cffea0e..6dd9bca 100644
--- a/bob/learn/misc/main.h
+++ b/bob/learn/misc/main.h
@@ -24,6 +24,7 @@
 #include <bob.learn.misc/KMeansTrainer.h>
 
 #include <bob.learn.misc/GMMBaseTrainer.h>
+#include <bob.learn.misc/ML_GMMTrainer.h>
 
 
 #if PY_VERSION_HEX >= 0x03000000
@@ -131,4 +132,17 @@ int PyBobLearnMiscGMMBaseTrainer_Check(PyObject* o);
 
 
 
+// ML_GMMTrainer
+typedef struct {
+  PyObject_HEAD
+  boost::shared_ptr<bob::learn::misc::ML_GMMTrainer> cxx;
+} PyBobLearnMiscMLGMMTrainerObject;
+
+extern PyTypeObject PyBobLearnMiscMLGMMTrainer_Type;
+bool init_BobLearnMiscMLGMMTrainer(PyObject* module);
+int PyBobLearnMiscMLGMMTrainer_Check(PyObject* o);
+
+
+
+
 #endif // BOB_LEARN_EM_MAIN_H
diff --git a/setup.py b/setup.py
index 9bc5e23..aab0731 100644
--- a/setup.py
+++ b/setup.py
@@ -69,8 +69,8 @@ setup(
           #"bob/learn/misc/cpp/IVectorTrainer.cpp",
           #"bob/learn/misc/cpp/JFATrainer.cpp",
           "bob/learn/misc/cpp/KMeansTrainer.cpp",
-          #"bob/learn/misc/cpp/MAP_GMMTrainer.cpp",
-          #"bob/learn/misc/cpp/ML_GMMTrainer.cpp",
+          "bob/learn/misc/cpp/MAP_GMMTrainer.cpp",
+          "bob/learn/misc/cpp/ML_GMMTrainer.cpp",
           #"bob/learn/misc/cpp/PLDATrainer.cpp",
         ],
         bob_packages = bob_packages,
@@ -106,6 +106,7 @@ setup(
           "bob/learn/misc/kmeans_machine.cpp",
           "bob/learn/misc/kmeans_trainer.cpp",
           "bob/learn/misc/gmm_base_trainer.cpp",
+          "bob/learn/misc/ML_gmm_trainer.cpp",
 
           "bob/learn/misc/main.cpp",
         ],
-- 
GitLab