From 3aad63f1699d8f14da4a8b7ad45eb415a7c3bbb8 Mon Sep 17 00:00:00 2001
From: Tiago Freitas Pereira <tiagofrepereira@gmail.com>
Date: Sat, 7 Feb 2015 22:14:52 +0100
Subject: [PATCH] Reorganized the MAP and ML trainers

---
 bob/learn/misc/MAP_gmm_trainer.cpp            | 175 +++++++++-------
 bob/learn/misc/ML_gmm_trainer.cpp             | 187 ++++++++++--------
 bob/learn/misc/__MAP_gmm_trainer__.py         |  22 ++-
 bob/learn/misc/__ML_gmm_trainer__.py          |  20 +-
 bob/learn/misc/cpp/MAP_GMMTrainer.cpp         |  34 ++--
 bob/learn/misc/cpp/ML_GMMTrainer.cpp          |  25 ++-
 .../include/bob.learn.misc/GMMBaseTrainer.h   |  12 +-
 .../include/bob.learn.misc/MAP_GMMTrainer.h   |  49 +++--
 .../include/bob.learn.misc/ML_GMMTrainer.h    |  41 ++--
 bob/learn/misc/main.cpp                       |   2 +-
 bob/learn/misc/main.h                         |   5 +-
 bob/learn/misc/test_em.py                     |  14 +-
 setup.py                                      |   2 +-
 13 files changed, 357 insertions(+), 231 deletions(-)

diff --git a/bob/learn/misc/MAP_gmm_trainer.cpp b/bob/learn/misc/MAP_gmm_trainer.cpp
index 1ffa4d2..f3a3f4c 100644
--- a/bob/learn/misc/MAP_gmm_trainer.cpp
+++ b/bob/learn/misc/MAP_gmm_trainer.cpp
@@ -13,6 +13,7 @@
 /************ Constructor Section *********************************/
 /******************************************************************/
 
+static inline bool f(PyObject* o){return o != 0 && PyObject_IsTrue(o) > 0;}  /* converts PyObject to bool and returns false if object is NULL */
 
 static auto MAP_GMMTrainer_doc = bob::extension::ClassDoc(
   BOB_EXT_MODULE_PREFIX ".MAP_GMMTrainer",
@@ -24,18 +25,22 @@ static auto MAP_GMMTrainer_doc = bob::extension::ClassDoc(
     "",
     true
   )
-  
-  
-  .add_prototype("gmm_base_trainer,prior_gmm,relevance_factor","")
-  .add_prototype("gmm_base_trainer,prior_gmm,alpha","")
+
+  .add_prototype("prior_gmm,relevance_factor, update_means, [update_variances], [update_weights], [mean_var_update_responsibilities_threshold]","")
+  .add_prototype("prior_gmm,alpha, update_means, [update_variances], [update_weights], [mean_var_update_responsibilities_threshold]","")
   .add_prototype("other","")
   .add_prototype("","")
 
-  .add_parameter("gmm_base_trainer", ":py:class:`bob.learn.misc.GMMBaseTrainer`", "A GMMBaseTrainer object.")
   .add_parameter("prior_gmm", ":py:class:`bob.learn.misc.GMMMachine`", "The prior GMM to be adapted (Universal Backgroud Model UBM).")
   .add_parameter("reynolds_adaptation", "bool", "Will use the Reynolds adaptation procedure? See Eq (14) from [Reynolds2000]_")
   .add_parameter("relevance_factor", "double", "If set the reynolds_adaptation parameters, will apply the Reynolds Adaptation procedure. See Eq (14) from [Reynolds2000]_")
   .add_parameter("alpha", "double", "Set directly the alpha parameter (Eq (14) from [Reynolds2000]_), ignoring zeroth order statistics as a weighting factor.")
+
+  .add_parameter("update_means", "bool", "Update means on each iteration")
+  .add_parameter("update_variances", "bool", "Update variances on each iteration")
+  .add_parameter("update_weights", "bool", "Update weights on each iteration")
+  .add_parameter("mean_var_update_responsibilities_threshold", "float", "Threshold over the responsibilities of the Gaussians Equations 9.24, 9.25 of Bishop, `Pattern recognition and machine learning`, 2006 require a division by the responsibilities, which might be equal to zero because of numerical issue. This threshold is used to avoid such divisions.")
+
   .add_parameter("other", ":py:class:`bob.learn.misc.MAP_GMMTrainer`", "A MAP_GMMTrainer object to be copied.")
 );
 
@@ -59,43 +64,54 @@ static int PyBobLearnMiscMAPGMMTrainer_init_base_trainer(PyBobLearnMiscMAPGMMTra
   char** kwlist1 = MAP_GMMTrainer_doc.kwlist(0);
   char** kwlist2 = MAP_GMMTrainer_doc.kwlist(1);
   
-  PyBobLearnMiscGMMBaseTrainerObject* gmm_base_trainer;
   PyBobLearnMiscGMMMachineObject* gmm_machine;
   bool reynolds_adaptation   = false;
   double alpha = 0.5;
   double relevance_factor = 4.0;
   double aux = 0;
 
-  PyObject* keyword_relevance_factor = Py_BuildValue("s", kwlist1[2]);
-  PyObject* keyword_alpha            = Py_BuildValue("s", kwlist2[2]);
+  PyObject* update_means     = 0;
+  PyObject* update_variances = 0;
+  PyObject* update_weights   = 0;
+  double mean_var_update_responsibilities_threshold = std::numeric_limits<double>::epsilon();
+
+  PyObject* keyword_relevance_factor = Py_BuildValue("s", kwlist1[1]);
+  PyObject* keyword_alpha            = Py_BuildValue("s", kwlist2[1]);
 
   //Here we have to select which keyword argument to read  
-  if (kwargs && PyDict_Contains(kwargs, keyword_relevance_factor) && (PyArg_ParseTupleAndKeywords(args, kwargs, "O!O!|d", kwlist1, 
-                                                                      &PyBobLearnMiscGMMBaseTrainer_Type, &gmm_base_trainer,
+  if (kwargs && PyDict_Contains(kwargs, keyword_relevance_factor) && (PyArg_ParseTupleAndKeywords(args, kwargs, "O!dO!|O!O!d", kwlist1, 
                                                                       &PyBobLearnMiscGMMMachine_Type, &gmm_machine,
-                                                                      &aux)))
+                                                                      &aux,
+                                                                      &PyBool_Type, &update_means, 
+                                                                      &PyBool_Type, &update_variances, 
+                                                                      &PyBool_Type, &update_weights, 
+                                                                      &mean_var_update_responsibilities_threshold)))
     reynolds_adaptation = true;    
-  else if (kwargs && PyDict_Contains(kwargs, keyword_alpha) && (PyArg_ParseTupleAndKeywords(args, kwargs, "O!O!|d", kwlist2, 
-                                                                 &PyBobLearnMiscGMMBaseTrainer_Type, &gmm_base_trainer,
+  else if (kwargs && PyDict_Contains(kwargs, keyword_alpha) && (PyArg_ParseTupleAndKeywords(args, kwargs, "O!dO!|O!O!d", kwlist2, 
                                                                  &PyBobLearnMiscGMMMachine_Type, &gmm_machine,
-                                                                 &aux)))
+                                                                 &aux,
+                                                                 &PyBool_Type, &update_means, 
+                                                                 &PyBool_Type, &update_variances, 
+                                                                 &PyBool_Type, &update_weights, 
+                                                                 &mean_var_update_responsibilities_threshold)))
     reynolds_adaptation = false;
   else{
-    PyErr_Format(PyExc_RuntimeError, "%s. The third argument must be a keyword argument.", Py_TYPE(self)->tp_name);
+    PyErr_Format(PyExc_RuntimeError, "%s. The second argument must be a keyword argument.", Py_TYPE(self)->tp_name);
     MAP_GMMTrainer_doc.print_usage();
     return -1;
   }
 
-
-
   if (reynolds_adaptation)
     relevance_factor = aux;
   else
     alpha = aux;
   
   
-  self->cxx.reset(new bob::learn::misc::MAP_GMMTrainer(gmm_base_trainer->cxx, gmm_machine->cxx, reynolds_adaptation,relevance_factor, alpha));
+  self->cxx.reset(new bob::learn::misc::MAP_GMMTrainer(f(update_means), f(update_variances), f(update_weights), 
+                                                       mean_var_update_responsibilities_threshold, 
+                                                       reynolds_adaptation,relevance_factor, alpha, gmm_machine->cxx));
   return 0;
+
 }
 
 
@@ -151,47 +167,6 @@ static PyObject* PyBobLearnMiscMAPGMMTrainer_RichCompare(PyBobLearnMiscMAPGMMTra
 /************ Variables Section ***********************************/
 /******************************************************************/
 
-
-/***** gmm_base_trainer *****/
-static auto gmm_base_trainer = bob::extension::VariableDoc(
-  "gmm_base_trainer",
-  ":py:class:`bob.learn.misc.GMMBaseTrainer`",
-  "This class that implements the E-step of the expectation-maximisation algorithm.",
-  ""
-);
-PyObject* PyBobLearnMiscMAPGMMTrainer_getGMMBaseTrainer(PyBobLearnMiscMAPGMMTrainerObject* self, void*){
-  BOB_TRY
-  
-  boost::shared_ptr<bob::learn::misc::GMMBaseTrainer> gmm_base_trainer = self->cxx->getGMMBaseTrainer();
-
-  //Allocating the correspondent python object
-  PyBobLearnMiscGMMBaseTrainerObject* retval =
-    (PyBobLearnMiscGMMBaseTrainerObject*)PyBobLearnMiscGMMBaseTrainer_Type.tp_alloc(&PyBobLearnMiscGMMBaseTrainer_Type, 0);
-
-  retval->cxx = gmm_base_trainer;
-
-  return Py_BuildValue("O",retval);
-  BOB_CATCH_MEMBER("GMMBaseTrainer could not be read", 0)
-}
-int PyBobLearnMiscMAPGMMTrainer_setGMMBaseTrainer(PyBobLearnMiscMAPGMMTrainerObject* self, PyObject* value, void*){
-  BOB_TRY
-
-  if (!PyBobLearnMiscGMMBaseTrainer_Check(value)){
-    PyErr_Format(PyExc_RuntimeError, "%s %s expects a :py:class:`bob.learn.misc.GMMBaseTrainer`", Py_TYPE(self)->tp_name, gmm_base_trainer.name());
-    return -1;
-  }
-
-  PyBobLearnMiscGMMBaseTrainerObject* gmm_base_trainer = 0;
-  PyArg_Parse(value, "O!", &PyBobLearnMiscGMMBaseTrainer_Type,&gmm_base_trainer);
-
-  self->cxx->setGMMBaseTrainer(gmm_base_trainer->cxx);
-
-  return 0;
-  BOB_CATCH_MEMBER("gmm_base_trainer could not be set", -1)  
-}
-
-
-
 /***** relevance_factor *****/
 static auto relevance_factor = bob::extension::VariableDoc(
   "relevance_factor",
@@ -246,13 +221,6 @@ int PyBobLearnMiscMAPGMMTrainer_setAlpha(PyBobLearnMiscMAPGMMTrainerObject* self
 
 
 static PyGetSetDef PyBobLearnMiscMAPGMMTrainer_getseters[] = { 
-  {
-    gmm_base_trainer.name(),
-    (getter)PyBobLearnMiscMAPGMMTrainer_getGMMBaseTrainer,
-    (setter)PyBobLearnMiscMAPGMMTrainer_setGMMBaseTrainer,
-    gmm_base_trainer.doc(),
-    0
-  },
   {
     alpha.name(),
     (getter)PyBobLearnMiscMAPGMMTrainer_getAlpha,
@@ -304,6 +272,40 @@ static PyObject* PyBobLearnMiscMAPGMMTrainer_initialize(PyBobLearnMiscMAPGMMTrai
 }
 
 
+/*** eStep ***/
+static auto eStep = bob::extension::FunctionDoc(
+  "eStep",
+  "Calculates and saves statistics across the dataset,"
+  "and saves these as m_ss. ",
+
+  "Calculates the average log likelihood of the observations given the GMM,"
+  "and returns this in average_log_likelihood."
+  "The statistics, m_ss, will be used in the mStep() that follows.",
+
+  true
+)
+.add_prototype("gmm_machine,data")
+.add_parameter("gmm_machine", ":py:class:`bob.learn.misc.GMMMachine`", "GMMMachine Object")
+.add_parameter("data", "array_like <float, 2D>", "Input data");
+static PyObject* PyBobLearnMiscMAPGMMTrainer_eStep(PyBobLearnMiscMAPGMMTrainerObject* self, PyObject* args, PyObject* kwargs) {
+  BOB_TRY
+
+  /* Parses input arguments in a single shot */
+  char** kwlist = eStep.kwlist(0);
+
+  PyBobLearnMiscGMMMachineObject* gmm_machine;
+  PyBlitzArrayObject* data = 0;
+  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!O&", kwlist, &PyBobLearnMiscGMMMachine_Type, &gmm_machine,
+                                                                 &PyBlitzArray_Converter, &data)) Py_RETURN_NONE;
+  auto data_ = make_safe(data);
+
+  self->cxx->eStep(*gmm_machine->cxx, *PyBlitzArrayCxx_AsBlitz<double,2>(data));
+
+  BOB_CATCH_MEMBER("cannot perform the eStep method", 0)
+
+  Py_RETURN_NONE;
+}
+
 
 /*** mStep ***/
 static auto mStep = bob::extension::FunctionDoc(
@@ -335,6 +337,31 @@ static PyObject* PyBobLearnMiscMAPGMMTrainer_mStep(PyBobLearnMiscMAPGMMTrainerOb
 }
 
 
+/*** computeLikelihood ***/
+static auto compute_likelihood = bob::extension::FunctionDoc(
+  "compute_likelihood",
+  "This functions returns the average min (Square Euclidean) distance (average distance to the closest mean)",
+  0,
+  true
+)
+.add_prototype("gmm_machine")
+.add_parameter("gmm_machine", ":py:class:`bob.learn.misc.GMMMachine`", "GMMMachine Object");
+static PyObject* PyBobLearnMiscMAPGMMTrainer_compute_likelihood(PyBobLearnMiscMAPGMMTrainerObject* self, PyObject* args, PyObject* kwargs) {
+  BOB_TRY
+
+  /* Parses input arguments in a single shot */
+  char** kwlist = compute_likelihood.kwlist(0);
+
+  PyBobLearnMiscGMMMachineObject* gmm_machine;
+  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!", kwlist, &PyBobLearnMiscGMMMachine_Type, &gmm_machine)) Py_RETURN_NONE;
+
+  double value = self->cxx->computeLikelihood(*gmm_machine->cxx);
+  return Py_BuildValue("d", value);
+
+  BOB_CATCH_MEMBER("cannot perform the computeLikelihood method", 0)
+}
+
+
 
 static PyMethodDef PyBobLearnMiscMAPGMMTrainer_methods[] = {
   {
@@ -343,13 +370,25 @@ static PyMethodDef PyBobLearnMiscMAPGMMTrainer_methods[] = {
     METH_VARARGS|METH_KEYWORDS,
     initialize.doc()
   },
+  {
+    eStep.name(),
+    (PyCFunction)PyBobLearnMiscMAPGMMTrainer_eStep,
+    METH_VARARGS|METH_KEYWORDS,
+    eStep.doc()
+  },
   {
     mStep.name(),
     (PyCFunction)PyBobLearnMiscMAPGMMTrainer_mStep,
     METH_VARARGS|METH_KEYWORDS,
     mStep.doc()
   },
-  
+  {
+    compute_likelihood.name(),
+    (PyCFunction)PyBobLearnMiscMAPGMMTrainer_compute_likelihood,
+    METH_VARARGS|METH_KEYWORDS,
+    compute_likelihood.doc()
+  },
+
   {0} /* Sentinel */
 };
 
@@ -379,7 +418,7 @@ bool init_BobLearnMiscMAPGMMTrainer(PyObject* module)
   PyBobLearnMiscMAPGMMTrainer_Type.tp_richcompare  = reinterpret_cast<richcmpfunc>(PyBobLearnMiscMAPGMMTrainer_RichCompare);
   PyBobLearnMiscMAPGMMTrainer_Type.tp_methods      = PyBobLearnMiscMAPGMMTrainer_methods;
   PyBobLearnMiscMAPGMMTrainer_Type.tp_getset       = PyBobLearnMiscMAPGMMTrainer_getseters;
-  //PyBobLearnMiscMAPGMMTrainer_Type.tp_call         = reinterpret_cast<ternaryfunc>(PyBobLearnMiscMAPGMMTrainer_compute_likelihood);
+  PyBobLearnMiscMAPGMMTrainer_Type.tp_call         = reinterpret_cast<ternaryfunc>(PyBobLearnMiscMAPGMMTrainer_compute_likelihood);
 
 
   // check that everything is fine
diff --git a/bob/learn/misc/ML_gmm_trainer.cpp b/bob/learn/misc/ML_gmm_trainer.cpp
index f9a4598..ff72609 100644
--- a/bob/learn/misc/ML_gmm_trainer.cpp
+++ b/bob/learn/misc/ML_gmm_trainer.cpp
@@ -13,6 +13,8 @@
 /************ Constructor Section *********************************/
 /******************************************************************/
 
+static inline bool f(PyObject* o){return o != 0 && PyObject_IsTrue(o) > 0;}  /* converts PyObject to bool and returns false if object is NULL */
+
 static auto ML_GMMTrainer_doc = bob::extension::ClassDoc(
   BOB_EXT_MODULE_PREFIX ".ML_GMMTrainer",
   "This class implements the maximum likelihood M-step of the expectation-maximisation algorithm for a GMM Machine."
@@ -23,11 +25,16 @@ static auto ML_GMMTrainer_doc = bob::extension::ClassDoc(
     "",
     true
   )
-  .add_prototype("gmm_base_trainer","")
+  .add_prototype("update_means, [update_variances], [update_weights], [mean_var_update_responsibilities_threshold]","")
   .add_prototype("other","")
   .add_prototype("","")
 
-  .add_parameter("gmm_base_trainer", ":py:class:`bob.learn.misc.GMMBaseTrainer`", "A set GMMBaseTrainer object.")
+  .add_parameter("update_means", "bool", "Update means on each iteration")
+  .add_parameter("update_variances", "bool", "Update variances on each iteration")
+  .add_parameter("update_weights", "bool", "Update weights on each iteration")
+  .add_parameter("mean_var_update_responsibilities_threshold", "float", "Threshold over the responsibilities of the Gaussians Equations 9.24, 9.25 of Bishop, `Pattern recognition and machine learning`, 2006 require a division by the responsibilities, which might be equal to zero because of numerical issue. This threshold is used to avoid such divisions.")
+
+
   .add_parameter("other", ":py:class:`bob.learn.misc.ML_GMMTrainer`", "A ML_GMMTrainer object to be copied.")
 );
 
@@ -48,15 +55,24 @@ static int PyBobLearnMiscMLGMMTrainer_init_copy(PyBobLearnMiscMLGMMTrainerObject
 
 static int PyBobLearnMiscMLGMMTrainer_init_base_trainer(PyBobLearnMiscMLGMMTrainerObject* self, PyObject* args, PyObject* kwargs) {
 
-  char** kwlist = ML_GMMTrainer_doc.kwlist(1);
-  PyBobLearnMiscGMMBaseTrainerObject* o;
+  char** kwlist = ML_GMMTrainer_doc.kwlist(0);
   
-  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!", kwlist, &PyBobLearnMiscGMMBaseTrainer_Type, &o)){
+  PyObject* update_means     = 0;
+  PyObject* update_variances = 0;
+  PyObject* update_weights   = 0;
+  double mean_var_update_responsibilities_threshold = std::numeric_limits<double>::epsilon();
+
+  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!|O!O!d", kwlist, 
+                                   &PyBool_Type, &update_means, 
+                                   &PyBool_Type, &update_variances, 
+                                   &PyBool_Type, &update_weights, 
+                                   &mean_var_update_responsibilities_threshold)){
     ML_GMMTrainer_doc.print_usage();
     return -1;
   }
 
-  self->cxx.reset(new bob::learn::misc::ML_GMMTrainer(o->cxx));
+  self->cxx.reset(new bob::learn::misc::ML_GMMTrainer(f(update_means), f(update_variances), f(update_weights), 
+                                                       mean_var_update_responsibilities_threshold));
   return 0;
 }
 
@@ -65,31 +81,24 @@ static int PyBobLearnMiscMLGMMTrainer_init_base_trainer(PyBobLearnMiscMLGMMTrain
 static int PyBobLearnMiscMLGMMTrainer_init(PyBobLearnMiscMLGMMTrainerObject* self, PyObject* args, PyObject* kwargs) {
   BOB_TRY
 
-  int nargs = (args?PyTuple_Size(args):0) + (kwargs?PyDict_Size(kwargs):0);
-
-  if (nargs==1){ //default initializer ()
-    //Reading the input argument
-    PyObject* arg = 0;
-    if (PyTuple_Size(args))
-      arg = PyTuple_GET_ITEM(args, 0);
-    else {
-      PyObject* tmp = PyDict_Values(kwargs);
-      auto tmp_ = make_safe(tmp);
-      arg = PyList_GET_ITEM(tmp, 0);
-    }
-
-    // If the constructor input is GMMBaseTrainer object
-    if (PyBobLearnMiscGMMBaseTrainer_Check(arg))
-      return PyBobLearnMiscMLGMMTrainer_init_base_trainer(self, args, kwargs);
-    else
-      return PyBobLearnMiscMLGMMTrainer_init_copy(self, args, kwargs);
-  }
-  else{
-    PyErr_Format(PyExc_RuntimeError, "number of arguments mismatch - %s requires 1, but you provided %d (see help)", Py_TYPE(self)->tp_name, nargs);
-    ML_GMMTrainer_doc.print_usage();
-    return -1;  
+  //Reading the input argument
+  PyObject* arg = 0;
+  if (PyTuple_Size(args))
+    arg = PyTuple_GET_ITEM(args, 0);
+  else {
+    PyObject* tmp = PyDict_Values(kwargs);
+    auto tmp_ = make_safe(tmp);
+    arg = PyList_GET_ITEM(tmp, 0);
   }
 
+  // If the constructor input is a ML_GMMTrainer object, copy-construct from it
+  if (PyBobLearnMiscMLGMMTrainer_Check(arg))
+    return PyBobLearnMiscMLGMMTrainer_init_copy(self, args, kwargs);
+  else
+    return PyBobLearnMiscMLGMMTrainer_init_base_trainer(self, args, kwargs);
+
+
+
   BOB_CATCH_MEMBER("cannot create GMMBaseTrainer_init_bool", 0)
   return 0;
 }
@@ -131,54 +140,7 @@ static PyObject* PyBobLearnMiscMLGMMTrainer_RichCompare(PyBobLearnMiscMLGMMTrain
 /************ Variables Section ***********************************/
 /******************************************************************/
 
-
-/***** gmm_base_trainer *****/
-static auto gmm_base_trainer = bob::extension::VariableDoc(
-  "gmm_base_trainer",
-  ":py:class:`bob.learn.misc.GMMBaseTrainer`",
-  "This class that implements the E-step of the expectation-maximisation algorithm.",
-  ""
-);
-PyObject* PyBobLearnMiscMLGMMTrainer_getGMMBaseTrainer(PyBobLearnMiscMLGMMTrainerObject* self, void*){
-  BOB_TRY
-  
-  boost::shared_ptr<bob::learn::misc::GMMBaseTrainer> gmm_base_trainer = self->cxx->getGMMBaseTrainer();
-
-  //Allocating the correspondent python object
-  PyBobLearnMiscGMMBaseTrainerObject* retval =
-    (PyBobLearnMiscGMMBaseTrainerObject*)PyBobLearnMiscGMMBaseTrainer_Type.tp_alloc(&PyBobLearnMiscGMMBaseTrainer_Type, 0);
-
-  retval->cxx = gmm_base_trainer;
-
-  return Py_BuildValue("O",retval);
-  BOB_CATCH_MEMBER("GMMBaseTrainer could not be read", 0)
-}
-int PyBobLearnMiscMLGMMTrainer_setGMMBaseTrainer(PyBobLearnMiscMLGMMTrainerObject* self, PyObject* value, void*){
-  BOB_TRY
-
-  if (!PyBobLearnMiscGMMBaseTrainer_Check(value)){
-    PyErr_Format(PyExc_RuntimeError, "%s %s expects a :py:class:`bob.learn.misc.GMMBaseTrainer`", Py_TYPE(self)->tp_name, gmm_base_trainer.name());
-    return -1;
-  }
-
-  PyBobLearnMiscGMMBaseTrainerObject* gmm_base_trainer = 0;
-  PyArg_Parse(value, "O!", &PyBobLearnMiscGMMBaseTrainer_Type,&gmm_base_trainer);
-
-  self->cxx->setGMMBaseTrainer(gmm_base_trainer->cxx);
-
-  return 0;
-  BOB_CATCH_MEMBER("gmm_base_trainer could not be set", -1)  
-}
-
-
 static PyGetSetDef PyBobLearnMiscMLGMMTrainer_getseters[] = { 
-  {
-    gmm_base_trainer.name(),
-    (getter)PyBobLearnMiscMLGMMTrainer_getGMMBaseTrainer,
-    (setter)PyBobLearnMiscMLGMMTrainer_setGMMBaseTrainer,
-    gmm_base_trainer.doc(),
-    0
-  },
   {0}  // Sentinel
 };
 
@@ -215,6 +177,40 @@ static PyObject* PyBobLearnMiscMLGMMTrainer_initialize(PyBobLearnMiscMLGMMTraine
 }
 
 
+/*** eStep ***/
+static auto eStep = bob::extension::FunctionDoc(
+  "eStep",
+  "Calculates and saves statistics across the dataset,"
+  "and saves these as m_ss. ",
+
+  "Calculates the average log likelihood of the observations given the GMM,"
+  "and returns this in average_log_likelihood."
+  "The statistics, m_ss, will be used in the mStep() that follows.",
+
+  true
+)
+.add_prototype("gmm_machine,data")
+.add_parameter("gmm_machine", ":py:class:`bob.learn.misc.GMMMachine`", "GMMMachine Object")
+.add_parameter("data", "array_like <float, 2D>", "Input data");
+static PyObject* PyBobLearnMiscMLGMMTrainer_eStep(PyBobLearnMiscMLGMMTrainerObject* self, PyObject* args, PyObject* kwargs) {
+  BOB_TRY
+
+  /* Parses input arguments in a single shot */
+  char** kwlist = eStep.kwlist(0);
+
+  PyBobLearnMiscGMMMachineObject* gmm_machine;
+  PyBlitzArrayObject* data = 0;
+  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!O&", kwlist, &PyBobLearnMiscGMMMachine_Type, &gmm_machine,
+                                                                 &PyBlitzArray_Converter, &data)) Py_RETURN_NONE;
+  auto data_ = make_safe(data);
+
+  self->cxx->eStep(*gmm_machine->cxx, *PyBlitzArrayCxx_AsBlitz<double,2>(data));
+
+  BOB_CATCH_MEMBER("cannot perform the eStep method", 0)
+
+  Py_RETURN_NONE;
+}
+
 
 /*** mStep ***/
 static auto mStep = bob::extension::FunctionDoc(
@@ -246,6 +242,31 @@ static PyObject* PyBobLearnMiscMLGMMTrainer_mStep(PyBobLearnMiscMLGMMTrainerObje
 }
 
 
+/*** computeLikelihood ***/
+static auto compute_likelihood = bob::extension::FunctionDoc(
+  "compute_likelihood",
+  "This functions returns the average min (Square Euclidean) distance (average distance to the closest mean)",
+  0,
+  true
+)
+.add_prototype("gmm_machine")
+.add_parameter("gmm_machine", ":py:class:`bob.learn.misc.GMMMachine`", "GMMMachine Object");
+static PyObject* PyBobLearnMiscMLGMMTrainer_compute_likelihood(PyBobLearnMiscMLGMMTrainerObject* self, PyObject* args, PyObject* kwargs) {
+  BOB_TRY
+
+  /* Parses input arguments in a single shot */
+  char** kwlist = compute_likelihood.kwlist(0);
+
+  PyBobLearnMiscGMMMachineObject* gmm_machine;
+  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!", kwlist, &PyBobLearnMiscGMMMachine_Type, &gmm_machine)) Py_RETURN_NONE;
+
+  double value = self->cxx->computeLikelihood(*gmm_machine->cxx);
+  return Py_BuildValue("d", value);
+
+  BOB_CATCH_MEMBER("cannot perform the computeLikelihood method", 0)
+}
+
+
 
 static PyMethodDef PyBobLearnMiscMLGMMTrainer_methods[] = {
   {
@@ -254,12 +275,24 @@ static PyMethodDef PyBobLearnMiscMLGMMTrainer_methods[] = {
     METH_VARARGS|METH_KEYWORDS,
     initialize.doc()
   },
+  {
+    eStep.name(),
+    (PyCFunction)PyBobLearnMiscMLGMMTrainer_eStep,
+    METH_VARARGS|METH_KEYWORDS,
+    eStep.doc()
+  },
   {
     mStep.name(),
     (PyCFunction)PyBobLearnMiscMLGMMTrainer_mStep,
     METH_VARARGS|METH_KEYWORDS,
     mStep.doc()
   },
+  {
+    compute_likelihood.name(),
+    (PyCFunction)PyBobLearnMiscMLGMMTrainer_compute_likelihood,
+    METH_VARARGS|METH_KEYWORDS,
+    compute_likelihood.doc()
+  },
   {0} /* Sentinel */
 };
 
@@ -289,7 +322,7 @@ bool init_BobLearnMiscMLGMMTrainer(PyObject* module)
   PyBobLearnMiscMLGMMTrainer_Type.tp_richcompare  = reinterpret_cast<richcmpfunc>(PyBobLearnMiscMLGMMTrainer_RichCompare);
   PyBobLearnMiscMLGMMTrainer_Type.tp_methods      = PyBobLearnMiscMLGMMTrainer_methods;
   PyBobLearnMiscMLGMMTrainer_Type.tp_getset       = PyBobLearnMiscMLGMMTrainer_getseters;
-  //PyBobLearnMiscMLGMMTrainer_Type.tp_call         = reinterpret_cast<ternaryfunc>(PyBobLearnMiscMLGMMTrainer_compute_likelihood);
+  PyBobLearnMiscMLGMMTrainer_Type.tp_call         = reinterpret_cast<ternaryfunc>(PyBobLearnMiscMLGMMTrainer_compute_likelihood);
 
 
   // check that everything is fine
diff --git a/bob/learn/misc/__MAP_gmm_trainer__.py b/bob/learn/misc/__MAP_gmm_trainer__.py
index 18f215d..4258b08 100644
--- a/bob/learn/misc/__MAP_gmm_trainer__.py
+++ b/bob/learn/misc/__MAP_gmm_trainer__.py
@@ -11,13 +11,17 @@ import numpy
 # define the class
 class MAP_GMMTrainer(_MAP_GMMTrainer):
 
-  def __init__(self, gmm_base_trainer, prior_gmm, convergence_threshold=0.001, max_iterations=10, converge_by_likelihood=True, **kwargs):
+  def __init__(self, prior_gmm, update_means=True, update_variances=False, update_weights=False, convergence_threshold=0.001, max_iterations=10, converge_by_likelihood=True, **kwargs):
     """
     :py:class:`bob.learn.misc.MAP_GMMTrainer` constructor
 
     Keyword Parameters:
-      gmm_base_trainer
-        The base trainer (:py:class:`bob.learn.misc.GMMBaseTrainer`)
+      update_means
+
+      update_variances
+
+      update_weights
+
       prior_gmm
         A :py:class:`bob.learn.misc.GMMMachine` to be adapted
       convergence_threshold
@@ -34,10 +38,10 @@ class MAP_GMMTrainer(_MAP_GMMTrainer):
 
     if kwargs.get('alpha')!=None:
       alpha = kwargs.get('alpha')
-      _MAP_GMMTrainer.__init__(self, gmm_base_trainer, prior_gmm, alpha=alpha)
+      _MAP_GMMTrainer.__init__(self, prior_gmm,alpha=alpha, update_means=update_means, update_variances=update_variances,update_weights=update_weights)
     else:
       relevance_factor = kwargs.get('relevance_factor')
-      _MAP_GMMTrainer.__init__(self, gmm_base_trainer, prior_gmm, relevance_factor=relevance_factor)
+      _MAP_GMMTrainer.__init__(self, prior_gmm, relevance_factor=relevance_factor, update_means=update_means, update_variances=update_variances,update_weights=update_weights)
     
     self.convergence_threshold  = convergence_threshold
     self.max_iterations         = max_iterations
@@ -67,10 +71,10 @@ class MAP_GMMTrainer(_MAP_GMMTrainer):
 
 
     #eStep
-    self.gmm_base_trainer.eStep(gmm_machine, data);
+    self.eStep(gmm_machine, data);
 
     if(self.converge_by_likelihood):
-      average_output = self.gmm_base_trainer.compute_likelihood(gmm_machine);    
+      average_output = self.compute_likelihood(gmm_machine);    
 
     for i in range(self.max_iterations):
       #saves average output from last iteration
@@ -80,11 +84,11 @@ class MAP_GMMTrainer(_MAP_GMMTrainer):
       self.mStep(gmm_machine);
 
       #eStep
-      self.gmm_base_trainer.eStep(gmm_machine, data);
+      self.eStep(gmm_machine, data);
 
       #Computes log likelihood if required
       if(self.converge_by_likelihood):
-        average_output = self.gmm_base_trainer.compute_likelihood(gmm_machine);
+        average_output = self.compute_likelihood(gmm_machine);
 
         #Terminates if converged (and likelihood computation is set)
         if abs((average_output_previous - average_output)/average_output_previous) <= self.convergence_threshold:
diff --git a/bob/learn/misc/__ML_gmm_trainer__.py b/bob/learn/misc/__ML_gmm_trainer__.py
index 53c5a75..93a3c6c 100644
--- a/bob/learn/misc/__ML_gmm_trainer__.py
+++ b/bob/learn/misc/__ML_gmm_trainer__.py
@@ -11,13 +11,17 @@ import numpy
 # define the class
 class ML_GMMTrainer(_ML_GMMTrainer):
 
-  def __init__(self, gmm_base_trainer, convergence_threshold=0.001, max_iterations=10, converge_by_likelihood=True):
+  def __init__(self, update_means=True, update_variances=False, update_weights=False, convergence_threshold=0.001, max_iterations=10, converge_by_likelihood=True):
     """
     :py:class:bob.learn.misc.ML_GMMTrainer constructor
 
     Keyword Parameters:
-      gmm_base_trainer
-        The base trainer (:py:class:`bob.learn.misc.GMMBaseTrainer`
+      update_means
+
+      update_variances
+
+      update_weights
+ 
       convergence_threshold
         Convergence threshold
       max_iterations
@@ -27,7 +31,7 @@ class ML_GMMTrainer(_ML_GMMTrainer):
         
     """
 
-    _ML_GMMTrainer.__init__(self, gmm_base_trainer)
+    _ML_GMMTrainer.__init__(self, update_means=update_means, update_variances=update_variances, update_weights=update_weights)
     self.convergence_threshold  = convergence_threshold
     self.max_iterations         = max_iterations
     self.converge_by_likelihood = converge_by_likelihood
@@ -53,10 +57,10 @@ class ML_GMMTrainer(_ML_GMMTrainer):
 
 
     #eStep
-    self.gmm_base_trainer.eStep(gmm_machine, data);
+    self.eStep(gmm_machine, data);
 
     if(self.converge_by_likelihood):
-      average_output = self.gmm_base_trainer.compute_likelihood(gmm_machine);    
+      average_output = self.compute_likelihood(gmm_machine);    
 
     for i in range(self.max_iterations):
       #saves average output from last iteration
@@ -66,11 +70,11 @@ class ML_GMMTrainer(_ML_GMMTrainer):
       self.mStep(gmm_machine);
 
       #eStep
-      self.gmm_base_trainer.eStep(gmm_machine, data);
+      self.eStep(gmm_machine, data);
 
       #Computes log likelihood if required
       if(self.converge_by_likelihood):
-        average_output = self.gmm_base_trainer.compute_likelihood(gmm_machine);
+        average_output = self.compute_likelihood(gmm_machine);
 
         #Terminates if converged (and likelihood computation is set)
         if abs((average_output_previous - average_output)/average_output_previous) <= self.convergence_threshold:
diff --git a/bob/learn/misc/cpp/MAP_GMMTrainer.cpp b/bob/learn/misc/cpp/MAP_GMMTrainer.cpp
index 0645c2c..d20b150 100644
--- a/bob/learn/misc/cpp/MAP_GMMTrainer.cpp
+++ b/bob/learn/misc/cpp/MAP_GMMTrainer.cpp
@@ -8,8 +8,18 @@
 #include <bob.learn.misc/MAP_GMMTrainer.h>
 #include <bob.core/check.h>
 
-bob::learn::misc::MAP_GMMTrainer::MAP_GMMTrainer(boost::shared_ptr<bob::learn::misc::GMMBaseTrainer> gmm_base_trainer, boost::shared_ptr<bob::learn::misc::GMMMachine> prior_gmm, const bool reynolds_adaptation, const double relevance_factor, const double alpha):
-  m_gmm_base_trainer(gmm_base_trainer),
+bob::learn::misc::MAP_GMMTrainer::MAP_GMMTrainer(
+   const bool update_means,
+   const bool update_variances,
+   const bool update_weights,
+   const double mean_var_update_responsibilities_threshold,
+
+   const bool reynolds_adaptation, 
+   const double relevance_factor, 
+   const double alpha,
+   boost::shared_ptr<bob::learn::misc::GMMMachine> prior_gmm):
+
+  m_gmm_base_trainer(update_means, update_variances, update_weights, mean_var_update_responsibilities_threshold),
   m_prior_gmm(prior_gmm)
 {
   m_reynolds_adaptation = reynolds_adaptation;
@@ -37,7 +47,7 @@ void bob::learn::misc::MAP_GMMTrainer::initialize(bob::learn::misc::GMMMachine&
     throw std::runtime_error("MAP_GMMTrainer: Prior GMM distribution has not been set");
 
   // Allocate memory for the sufficient statistics and initialise
-  m_gmm_base_trainer->initialize(gmm);
+  m_gmm_base_trainer.initialize(gmm);
 
   const size_t n_gaussians = gmm.getNGaussians();
   // TODO: check size?
@@ -78,13 +88,13 @@ void bob::learn::misc::MAP_GMMTrainer::mStep(bob::learn::misc::GMMMachine& gmm)
   if (!m_reynolds_adaptation)
     m_cache_alpha = m_alpha;
   else
-    m_cache_alpha = m_gmm_base_trainer->getGMMStats().n(i) / (m_gmm_base_trainer->getGMMStats().n(i) + m_relevance_factor);
+    m_cache_alpha = m_gmm_base_trainer.getGMMStats().n(i) / (m_gmm_base_trainer.getGMMStats().n(i) + m_relevance_factor);
 
   // - Update weights if requested
   //   Equation 11 of Reynolds et al., "Speaker Verification Using Adapted Gaussian Mixture Models", Digital Signal Processing, 2000
-  if (m_gmm_base_trainer->getUpdateWeights()) {
+  if (m_gmm_base_trainer.getUpdateWeights()) {
     // Calculate the maximum likelihood weights
-    m_cache_ml_weights = m_gmm_base_trainer->getGMMStats().n / static_cast<double>(m_gmm_base_trainer->getGMMStats().T); //cast req. for linux/32-bits & osx
+    m_cache_ml_weights = m_gmm_base_trainer.getGMMStats().n / static_cast<double>(m_gmm_base_trainer.getGMMStats().T); //cast req. for linux/32-bits & osx
 
     // Get the prior weights
     const blitz::Array<double,1>& prior_weights = m_prior_gmm->getWeights();
@@ -104,35 +114,35 @@ void bob::learn::misc::MAP_GMMTrainer::mStep(bob::learn::misc::GMMMachine& gmm)
   // Update GMM parameters
   // - Update means if requested
   //   Equation 12 of Reynolds et al., "Speaker Verification Using Adapted Gaussian Mixture Models", Digital Signal Processing, 2000
-  if (m_gmm_base_trainer->getUpdateMeans()) {
+  if (m_gmm_base_trainer.getUpdateMeans()) {
     // Calculate new means
     for (size_t i=0; i<n_gaussians; ++i) {
       const blitz::Array<double,1>& prior_means = m_prior_gmm->getGaussian(i)->getMean();
       blitz::Array<double,1>& means = gmm.getGaussian(i)->updateMean();
-      if (m_gmm_base_trainer->getGMMStats().n(i) < m_gmm_base_trainer->getMeanVarUpdateResponsibilitiesThreshold()) {
+      if (m_gmm_base_trainer.getGMMStats().n(i) < m_gmm_base_trainer.getMeanVarUpdateResponsibilitiesThreshold()) {
         means = prior_means;
       }
       else {
         // Use the maximum likelihood means
-        means = m_cache_alpha(i) * (m_gmm_base_trainer->getGMMStats().sumPx(i,blitz::Range::all()) / m_gmm_base_trainer->getGMMStats().n(i)) + (1-m_cache_alpha(i)) * prior_means;
+        means = m_cache_alpha(i) * (m_gmm_base_trainer.getGMMStats().sumPx(i,blitz::Range::all()) / m_gmm_base_trainer.getGMMStats().n(i)) + (1-m_cache_alpha(i)) * prior_means;
       }
     }
   }
 
   // - Update variance if requested
   //   Equation 13 of Reynolds et al., "Speaker Verification Using Adapted Gaussian Mixture Models", Digital Signal Processing, 2000
-  if (m_gmm_base_trainer->getUpdateVariances()) {
+  if (m_gmm_base_trainer.getUpdateVariances()) {
     // Calculate new variances (equation 13)
     for (size_t i=0; i<n_gaussians; ++i) {
       const blitz::Array<double,1>& prior_means = m_prior_gmm->getGaussian(i)->getMean();
       blitz::Array<double,1>& means = gmm.getGaussian(i)->updateMean();
       const blitz::Array<double,1>& prior_variances = m_prior_gmm->getGaussian(i)->getVariance();
       blitz::Array<double,1>& variances = gmm.getGaussian(i)->updateVariance();
-      if (m_gmm_base_trainer->getGMMStats().n(i) < m_gmm_base_trainer->getMeanVarUpdateResponsibilitiesThreshold()) {
+      if (m_gmm_base_trainer.getGMMStats().n(i) < m_gmm_base_trainer.getMeanVarUpdateResponsibilitiesThreshold()) {
         variances = (prior_variances + prior_means) - blitz::pow2(means);
       }
       else {
-        variances = m_cache_alpha(i) * m_gmm_base_trainer->getGMMStats().sumPxx(i,blitz::Range::all()) / m_gmm_base_trainer->getGMMStats().n(i) + (1-m_cache_alpha(i)) * (prior_variances + prior_means) - blitz::pow2(means);
+        variances = m_cache_alpha(i) * m_gmm_base_trainer.getGMMStats().sumPxx(i,blitz::Range::all()) / m_gmm_base_trainer.getGMMStats().n(i) + (1-m_cache_alpha(i)) * (prior_variances + prior_means) - blitz::pow2(means);
       }
       gmm.getGaussian(i)->applyVarianceThresholds();
     }
diff --git a/bob/learn/misc/cpp/ML_GMMTrainer.cpp b/bob/learn/misc/cpp/ML_GMMTrainer.cpp
index 84ecc5c..f08fb2f 100644
--- a/bob/learn/misc/cpp/ML_GMMTrainer.cpp
+++ b/bob/learn/misc/cpp/ML_GMMTrainer.cpp
@@ -8,8 +8,13 @@
 #include <bob.learn.misc/ML_GMMTrainer.h>
 #include <algorithm>
 
-bob::learn::misc::ML_GMMTrainer::ML_GMMTrainer(boost::shared_ptr<bob::learn::misc::GMMBaseTrainer> gmm_base_trainer):
-  m_gmm_base_trainer(gmm_base_trainer)
+bob::learn::misc::ML_GMMTrainer::ML_GMMTrainer(
+   const bool update_means,
+   const bool update_variances, 
+   const bool update_weights,
+   const double mean_var_update_responsibilities_threshold
+):
+  m_gmm_base_trainer(update_means, update_variances, update_weights, mean_var_update_responsibilities_threshold)
 {}
 
 
@@ -23,7 +28,7 @@ bob::learn::misc::ML_GMMTrainer::~ML_GMMTrainer()
 
 void bob::learn::misc::ML_GMMTrainer::initialize(bob::learn::misc::GMMMachine& gmm)
 {
-  m_gmm_base_trainer->initialize(gmm);
+  m_gmm_base_trainer.initialize(gmm);
   
   // Allocate cache
   size_t n_gaussians = gmm.getNGaussians();
@@ -38,24 +43,24 @@ void bob::learn::misc::ML_GMMTrainer::mStep(bob::learn::misc::GMMMachine& gmm)
 
   // - Update weights if requested
   //   Equation 9.26 of Bishop, "Pattern recognition and machine learning", 2006
-  if (m_gmm_base_trainer->getUpdateWeights()) {
+  if (m_gmm_base_trainer.getUpdateWeights()) {
     blitz::Array<double,1>& weights = gmm.updateWeights();
-    weights = m_gmm_base_trainer->getGMMStats().n / static_cast<double>(m_gmm_base_trainer->getGMMStats().T); //cast req. for linux/32-bits & osx
+    weights = m_gmm_base_trainer.getGMMStats().n / static_cast<double>(m_gmm_base_trainer.getGMMStats().T); //cast req. for linux/32-bits & osx
     // Recompute the log weights in the cache of the GMMMachine
     gmm.recomputeLogWeights();
   }
 
   // Generate a thresholded version of m_ss.n
   for(size_t i=0; i<n_gaussians; ++i)
-    m_cache_ss_n_thresholded(i) = std::max(m_gmm_base_trainer->getGMMStats().n(i), m_gmm_base_trainer->getMeanVarUpdateResponsibilitiesThreshold());
+    m_cache_ss_n_thresholded(i) = std::max(m_gmm_base_trainer.getGMMStats().n(i), m_gmm_base_trainer.getMeanVarUpdateResponsibilitiesThreshold());
 
   // Update GMM parameters using the sufficient statistics (m_ss)
   // - Update means if requested
   //   Equation 9.24 of Bishop, "Pattern recognition and machine learning", 2006
-  if (m_gmm_base_trainer->getUpdateMeans()) {
+  if (m_gmm_base_trainer.getUpdateMeans()) {
     for(size_t i=0; i<n_gaussians; ++i) {
       blitz::Array<double,1>& means = gmm.getGaussian(i)->updateMean();
-      means = m_gmm_base_trainer->getGMMStats().sumPx(i, blitz::Range::all()) / m_cache_ss_n_thresholded(i);
+      means = m_gmm_base_trainer.getGMMStats().sumPx(i, blitz::Range::all()) / m_cache_ss_n_thresholded(i);
     }
   }
 
@@ -64,11 +69,11 @@ void bob::learn::misc::ML_GMMTrainer::mStep(bob::learn::misc::GMMMachine& gmm)
   //   ...but we use the "computational formula for the variance", i.e.
   //   var = 1/n * sum (P(x-mean)(x-mean))
   //       = 1/n * sum (Pxx) - mean^2
-  if (m_gmm_base_trainer->getUpdateVariances()) {
+  if (m_gmm_base_trainer.getUpdateVariances()) {
     for(size_t i=0; i<n_gaussians; ++i) {
       const blitz::Array<double,1>& means = gmm.getGaussian(i)->getMean();
       blitz::Array<double,1>& variances = gmm.getGaussian(i)->updateVariance();
-      variances = m_gmm_base_trainer->getGMMStats().sumPxx(i, blitz::Range::all()) / m_cache_ss_n_thresholded(i) - blitz::pow2(means);
+      variances = m_gmm_base_trainer.getGMMStats().sumPxx(i, blitz::Range::all()) / m_cache_ss_n_thresholded(i) - blitz::pow2(means);
       gmm.getGaussian(i)->applyVarianceThresholds();
     }
   }
diff --git a/bob/learn/misc/include/bob.learn.misc/GMMBaseTrainer.h b/bob/learn/misc/include/bob.learn.misc/GMMBaseTrainer.h
index 086fdc2..4015055 100644
--- a/bob/learn/misc/include/bob.learn.misc/GMMBaseTrainer.h
+++ b/bob/learn/misc/include/bob.learn.misc/GMMBaseTrainer.h
@@ -30,9 +30,9 @@ class GMMBaseTrainer
      * @brief Default constructor
      */
     GMMBaseTrainer(const bool update_means=true,
-      const bool update_variances=false, const bool update_weights=false,
-      const double mean_var_update_responsibilities_threshold =
-        std::numeric_limits<double>::epsilon());
+                   const bool update_variances=false, 
+                   const bool update_weights=false,
+                   const double mean_var_update_responsibilities_threshold = std::numeric_limits<double>::epsilon());
 
     /**
      * @brief Copy constructor
@@ -47,7 +47,7 @@ class GMMBaseTrainer
     /**
      * @brief Initialization before the EM steps
      */
-    virtual void initialize(bob::learn::misc::GMMMachine& gmm);
+    void initialize(bob::learn::misc::GMMMachine& gmm);
 
     /**
      * @brief Calculates and saves statistics across the dataset,
@@ -58,14 +58,14 @@ class GMMBaseTrainer
      * The statistics, m_ss, will be used in the mStep() that follows.
      * Implements EMTrainer::eStep(double &)
      */
-    virtual void eStep(bob::learn::misc::GMMMachine& gmm,
+     void eStep(bob::learn::misc::GMMMachine& gmm,
       const blitz::Array<double,2>& data);
 
     /**
      * @brief Computes the likelihood using current estimates of the latent
      * variables
      */
-    virtual double computeLikelihood(bob::learn::misc::GMMMachine& gmm);
+    double computeLikelihood(bob::learn::misc::GMMMachine& gmm);
 
 
     /**
diff --git a/bob/learn/misc/include/bob.learn.misc/MAP_GMMTrainer.h b/bob/learn/misc/include/bob.learn.misc/MAP_GMMTrainer.h
index 42e4552..c6c7cf7 100644
--- a/bob/learn/misc/include/bob.learn.misc/MAP_GMMTrainer.h
+++ b/bob/learn/misc/include/bob.learn.misc/MAP_GMMTrainer.h
@@ -26,7 +26,15 @@ class MAP_GMMTrainer
     /**
      * @brief Default constructor
      */
-    MAP_GMMTrainer(boost::shared_ptr<bob::learn::misc::GMMBaseTrainer> gmm_base_trainer, boost::shared_ptr<bob::learn::misc::GMMMachine> prior_gmm, const bool reynolds_adaptation=false, const double relevance_factor=4, const double alpha=0.5);
+    MAP_GMMTrainer(
+      const bool update_means=true,
+      const bool update_variances=false, 
+      const bool update_weights=false,
+      const double mean_var_update_responsibilities_threshold = std::numeric_limits<double>::epsilon(),
+      const bool reynolds_adaptation=false, 
+      const double relevance_factor=4, 
+      const double alpha=0.5,
+      boost::shared_ptr<bob::learn::misc::GMMMachine> prior_gmm = boost::shared_ptr<bob::learn::misc::GMMMachine>());
 
     /**
      * @brief Copy constructor
@@ -41,7 +49,7 @@ class MAP_GMMTrainer
     /**
      * @brief Initialization
      */
-    virtual void initialize(bob::learn::misc::GMMMachine& gmm);
+    void initialize(bob::learn::misc::GMMMachine& gmm);
 
     /**
      * @brief Assigns from a different MAP_GMMTrainer
@@ -71,6 +79,21 @@ class MAP_GMMTrainer
      */
     bool setPriorGMM(boost::shared_ptr<bob::learn::misc::GMMMachine> prior_gmm);
 
+    /**
+     * @brief Calculates and saves statistics across the dataset,
+     * and saves these as m_ss. Calculates the average
+     * log likelihood of the observations given the GMM,
+     * and returns this in average_log_likelihood.
+     *
+     * The statistics, m_ss, will be used in the mStep() that follows.
+     * Implements EMTrainer::eStep(double &)
+     */
+     void eStep(bob::learn::misc::GMMMachine& gmm,
+      const blitz::Array<double,2>& data){
+      m_gmm_base_trainer.eStep(gmm,data);
+     }
+
+
     /**
      * @brief Performs a maximum a posteriori (MAP) update of the GMM
      * parameters using the accumulated statistics in m_ss and the
@@ -80,13 +103,12 @@ class MAP_GMMTrainer
     void mStep(bob::learn::misc::GMMMachine& gmm);
 
     /**
-     * @brief Use a Torch3-like adaptation rule rather than Reynolds'one
-     * In this case, alpha is a configuration variable rather than a function of the zeroth
-     * order statistics and a relevance factor (should be in range [0,1])
+     * @brief Computes the likelihood using current estimates of the latent
+     * variables
      */
-    //void setT3MAP(const double alpha) { m_T3_adaptation = true; m_T3_alpha = alpha; }
-    //void unsetT3MAP() { m_T3_adaptation = false; }
-    
+    double computeLikelihood(bob::learn::misc::GMMMachine& gmm){
+      return m_gmm_base_trainer.computeLikelihood(gmm);
+    }    
     
     bool getReynoldsAdaptation()
     {return m_reynolds_adaptation;}
@@ -109,13 +131,6 @@ class MAP_GMMTrainer
     {m_alpha = alpha;}
 
 
-    boost::shared_ptr<bob::learn::misc::GMMBaseTrainer> getGMMBaseTrainer()
-    {return m_gmm_base_trainer;}
-    
-    void setGMMBaseTrainer(boost::shared_ptr<bob::learn::misc::GMMBaseTrainer> gmm_base_trainer)
-    {m_gmm_base_trainer = gmm_base_trainer;}
-    
-
   protected:
 
     /**
@@ -123,12 +138,10 @@ class MAP_GMMTrainer
      */
     double m_relevance_factor;
 
-
     /**
     Base Trainer for the MAP algorithm. Basically implements the e-step
     */ 
-    boost::shared_ptr<bob::learn::misc::GMMBaseTrainer> m_gmm_base_trainer;
-
+    bob::learn::misc::GMMBaseTrainer m_gmm_base_trainer;
 
     /**
      * The GMM to use as a prior for MAP adaptation.
diff --git a/bob/learn/misc/include/bob.learn.misc/ML_GMMTrainer.h b/bob/learn/misc/include/bob.learn.misc/ML_GMMTrainer.h
index 09b3db1..13cda74 100644
--- a/bob/learn/misc/include/bob.learn.misc/ML_GMMTrainer.h
+++ b/bob/learn/misc/include/bob.learn.misc/ML_GMMTrainer.h
@@ -27,7 +27,10 @@ class ML_GMMTrainer{
     /**
      * @brief Default constructor
      */
-    ML_GMMTrainer(boost::shared_ptr<bob::learn::misc::GMMBaseTrainer> gmm_base_trainer);
+    ML_GMMTrainer(const bool update_means=true,
+                  const bool update_variances=false, 
+                  const bool update_weights=false,
+                  const double mean_var_update_responsibilities_threshold = std::numeric_limits<double>::epsilon());
 
     /**
      * @brief Copy constructor
@@ -42,14 +45,37 @@ class ML_GMMTrainer{
     /**
      * @brief Initialisation before the EM steps
      */
-    virtual void initialize(bob::learn::misc::GMMMachine& gmm);
+    void initialize(bob::learn::misc::GMMMachine& gmm);
+
+    /**
+     * @brief Calculates and saves statistics across the dataset,
+     * and saves these as m_ss. Calculates the average
+     * log likelihood of the observations given the GMM,
+     * and returns this in average_log_likelihood.
+     *
+     * The statistics, m_ss, will be used in the mStep() that follows.
+     * Implements EMTrainer::eStep(double &)
+     */
+     void eStep(bob::learn::misc::GMMMachine& gmm,
+      const blitz::Array<double,2>& data){
+      m_gmm_base_trainer.eStep(gmm,data);
+     }
 
     /**
      * @brief Performs a maximum likelihood (ML) update of the GMM parameters
      * using the accumulated statistics in m_ss
      * Implements EMTrainer::mStep()
      */
-    virtual void mStep(bob::learn::misc::GMMMachine& gmm);
+    void mStep(bob::learn::misc::GMMMachine& gmm);
+
+    /**
+     * @brief Computes the likelihood using current estimates of the latent
+     * variables
+     */
+    double computeLikelihood(bob::learn::misc::GMMMachine& gmm){
+      return m_gmm_base_trainer.computeLikelihood(gmm);
+    }
+
 
     /**
      * @brief Assigns from a different ML_GMMTrainer
@@ -73,19 +99,12 @@ class ML_GMMTrainer{
       const double a_epsilon=1e-8) const;
       
     
-    boost::shared_ptr<bob::learn::misc::GMMBaseTrainer> getGMMBaseTrainer()
-    {return m_gmm_base_trainer;}
-    
-    void setGMMBaseTrainer(boost::shared_ptr<bob::learn::misc::GMMBaseTrainer> gmm_base_trainer)
-    {m_gmm_base_trainer = gmm_base_trainer;}
-    
-
   protected:
 
     /**
     Base Trainer for the MAP algorithm. Basically implements the e-step
     */ 
-    boost::shared_ptr<bob::learn::misc::GMMBaseTrainer> m_gmm_base_trainer;
+    bob::learn::misc::GMMBaseTrainer m_gmm_base_trainer;
 
 
   private:
diff --git a/bob/learn/misc/main.cpp b/bob/learn/misc/main.cpp
index 9e6dc2c..10e1e8a 100644
--- a/bob/learn/misc/main.cpp
+++ b/bob/learn/misc/main.cpp
@@ -75,7 +75,7 @@ static PyObject* create_module (void) {
   if (!init_BobLearnMiscGMMMachine(module)) return 0;
   if (!init_BobLearnMiscKMeansMachine(module)) return 0;
   if (!init_BobLearnMiscKMeansTrainer(module)) return 0;
-  if (!init_BobLearnMiscGMMBaseTrainer(module)) return 0;
+  //if (!init_BobLearnMiscGMMBaseTrainer(module)) return 0;
   if (!init_BobLearnMiscMLGMMTrainer(module)) return 0;  
   if (!init_BobLearnMiscMAPGMMTrainer(module)) return 0;
 
diff --git a/bob/learn/misc/main.h b/bob/learn/misc/main.h
index 42f5f12..5be119c 100644
--- a/bob/learn/misc/main.h
+++ b/bob/learn/misc/main.h
@@ -26,7 +26,7 @@
 #include <bob.learn.misc/KMeansMachine.h>
 
 #include <bob.learn.misc/KMeansTrainer.h>
-#include <bob.learn.misc/GMMBaseTrainer.h>
+//#include <bob.learn.misc/GMMBaseTrainer.h>
 #include <bob.learn.misc/ML_GMMTrainer.h>
 #include <bob.learn.misc/MAP_GMMTrainer.h>
 
@@ -145,6 +145,7 @@ int PyBobLearnMiscKMeansTrainer_Check(PyObject* o);
 
 
 // GMMBaseTrainer
+/*
 typedef struct {
   PyObject_HEAD
   boost::shared_ptr<bob::learn::misc::GMMBaseTrainer> cxx;
@@ -153,7 +154,7 @@ typedef struct {
 extern PyTypeObject PyBobLearnMiscGMMBaseTrainer_Type;
 bool init_BobLearnMiscGMMBaseTrainer(PyObject* module);
 int PyBobLearnMiscGMMBaseTrainer_Check(PyObject* o);
-
+*/
 
 // ML_GMMTrainer
 typedef struct {
diff --git a/bob/learn/misc/test_em.py b/bob/learn/misc/test_em.py
index dffb86a..88070de 100644
--- a/bob/learn/misc/test_em.py
+++ b/bob/learn/misc/test_em.py
@@ -14,7 +14,7 @@ import bob.io.base
 from bob.io.base.test_utils import datafile
 
 from . import KMeansMachine, GMMMachine, KMeansTrainer, \
-    GMMBaseTrainer, ML_GMMTrainer, MAP_GMMTrainer
+    ML_GMMTrainer, MAP_GMMTrainer
 
 #, MAP_GMMTrainer
 
@@ -50,7 +50,7 @@ def test_gmm_ML_1():
   ar = bob.io.base.load(datafile("faithful.torch3_f64.hdf5", __name__))  
   gmm = loadGMM()
   
-  ml_gmmtrainer = ML_GMMTrainer(GMMBaseTrainer(True, True, True))
+  ml_gmmtrainer = ML_GMMTrainer(True, True, True)
   ml_gmmtrainer.train(gmm, ar)
 
   #config = bob.io.base.HDF5File(datafile('gmm_ML.hdf5", __name__), 'w')
@@ -82,7 +82,7 @@ def test_gmm_ML_2():
   prior = 0.001
   max_iter_gmm = 25
   accuracy = 0.00001
-  ml_gmmtrainer = ML_GMMTrainer(GMMBaseTrainer(True, True, True, prior), converge_by_likelihood=True)
+  ml_gmmtrainer = ML_GMMTrainer(True, True, True, prior, converge_by_likelihood=True)
   ml_gmmtrainer.max_iterations = max_iter_gmm
   ml_gmmtrainer.convergence_threshold = accuracy
   
@@ -97,8 +97,6 @@ def test_gmm_ML_2():
   weightsML_ref = bob.io.base.load(datafile('weightsAfterML.hdf5', __name__))
 
 
-  print sum(sum(gmm.means - meansML_ref))
-
   # Compare to current results
   assert equals(gmm.means, meansML_ref, 3e-3)
   assert equals(gmm.variances, variancesML_ref, 3e-3)
@@ -115,7 +113,7 @@ def test_gmm_MAP_1():
   gmm = GMMMachine(bob.io.base.HDF5File(datafile("gmm_ML.hdf5", __name__)))
   gmmprior = GMMMachine(bob.io.base.HDF5File(datafile("gmm_ML.hdf5", __name__)))
 
-  map_gmmtrainer = MAP_GMMTrainer(GMMBaseTrainer(True, False, False),gmmprior, relevance_factor=4.)  
+  map_gmmtrainer = MAP_GMMTrainer(update_means=True, update_variances=False, update_weights=False, prior_gmm=gmmprior, relevance_factor=4.)  
   #map_gmmtrainer.set_prior_gmm(gmmprior)
   map_gmmtrainer.train(gmm, ar)
 
@@ -142,7 +140,7 @@ def test_gmm_MAP_2():
   gmm.variances = variances
   gmm.weights = weights
 
-  map_adapt = MAP_GMMTrainer(GMMBaseTrainer(True, False, False, mean_var_update_responsibilities_threshold=0.),gmm, relevance_factor=4.)
+  map_adapt = MAP_GMMTrainer(update_means=True, update_variances=False, update_weights=False, mean_var_update_responsibilities_threshold=0.,prior_gmm=gmm, relevance_factor=4.)
   #map_adapt.set_prior_gmm(gmm)
 
   gmm_adapted = GMMMachine(2,50)
@@ -186,7 +184,7 @@ def test_gmm_MAP_3():
   max_iter_gmm = 1
   accuracy = 0.00001
   map_factor = 0.5
-  map_gmmtrainer = MAP_GMMTrainer(GMMBaseTrainer(True, False, False, prior), prior_gmm, alpha=map_factor)
+  map_gmmtrainer = MAP_GMMTrainer(update_means=True, update_variances=False, update_weights=False, mean_var_update_responsibilities_threshold=prior, prior_gmm=prior_gmm, alpha=map_factor)
   map_gmmtrainer.max_iterations = max_iter_gmm
   map_gmmtrainer.convergence_threshold = accuracy
 
diff --git a/setup.py b/setup.py
index aa58b9c..875085d 100644
--- a/setup.py
+++ b/setup.py
@@ -113,7 +113,7 @@ setup(
           "bob/learn/misc/gmm_machine.cpp",
           "bob/learn/misc/kmeans_machine.cpp",
           "bob/learn/misc/kmeans_trainer.cpp",
-          "bob/learn/misc/gmm_base_trainer.cpp",
+
           "bob/learn/misc/ML_gmm_trainer.cpp",
           "bob/learn/misc/MAP_gmm_trainer.cpp",
 
-- 
GitLab