From 90f09f63432b1b9c1d48bd79a76a1a901cd41c36 Mon Sep 17 00:00:00 2001
From: Tiago Freitas Pereira <tiagofrepereira@gmail.com>
Date: Thu, 22 Jan 2015 12:54:03 +0100
Subject: [PATCH] Binding GMMBaseTrainer class

---
 .../{GMMTrainer.cpp => GMMBaseTrainer.cpp}    |  59 ++-
 bob/learn/misc/gmm_base_trainer.cpp           | 348 ++++++++++++++++++
 .../{GMMTrainer.h => GMMBaseTrainer.h}        |  37 +-
 bob/learn/misc/kmeans_trainer.cpp             |   2 +-
 bob/learn/misc/main.cpp                       |   5 +-
 bob/learn/misc/main.h                         |  14 +
 setup.py                                      |   3 +-
 7 files changed, 404 insertions(+), 64 deletions(-)
 rename bob/learn/misc/cpp/{GMMTrainer.cpp => GMMBaseTrainer.cpp} (53%)
 create mode 100644 bob/learn/misc/gmm_base_trainer.cpp
 rename bob/learn/misc/include/bob.learn.misc/{GMMTrainer.h => GMMBaseTrainer.h} (76%)

diff --git a/bob/learn/misc/cpp/GMMTrainer.cpp b/bob/learn/misc/cpp/GMMBaseTrainer.cpp
similarity index 53%
rename from bob/learn/misc/cpp/GMMTrainer.cpp
rename to bob/learn/misc/cpp/GMMBaseTrainer.cpp
index e1eb8db..8210f2b 100644
--- a/bob/learn/misc/cpp/GMMTrainer.cpp
+++ b/bob/learn/misc/cpp/GMMBaseTrainer.cpp
@@ -5,39 +5,33 @@
  * Copyright (C) Idiap Research Institute, Martigny, Switzerland
  */
 
-#include <bob.learn.misc/GMMTrainer.h>
+#include <bob.learn.misc/GMMBaseTrainer.h>
 #include <bob.core/assert.h>
 #include <bob.core/check.h>
 
-bob::learn::misc::GMMTrainer::GMMTrainer(const bool update_means,
+bob::learn::misc::GMMBaseTrainer::GMMBaseTrainer(const bool update_means,
     const bool update_variances, const bool update_weights,
     const double mean_var_update_responsibilities_threshold):
-  bob::learn::misc::EMTrainer<bob::learn::misc::GMMMachine, blitz::Array<double,2> >(),
   m_update_means(update_means), m_update_variances(update_variances),
   m_update_weights(update_weights),
   m_mean_var_update_responsibilities_threshold(mean_var_update_responsibilities_threshold)
-{
-}
+{}
 
-bob::learn::misc::GMMTrainer::GMMTrainer(const bob::learn::misc::GMMTrainer& b):
-  bob::learn::misc::EMTrainer<bob::learn::misc::GMMMachine, blitz::Array<double,2> >(b),
+bob::learn::misc::GMMBaseTrainer::GMMBaseTrainer(const bob::learn::misc::GMMBaseTrainer& b):
   m_update_means(b.m_update_means), m_update_variances(b.m_update_variances),
   m_mean_var_update_responsibilities_threshold(b.m_mean_var_update_responsibilities_threshold)
-{
-}
+{ m_update_weights = b.m_update_weights; m_ss = b.m_ss; } // copy the members the initialiser list omits, as operator= does
 
-bob::learn::misc::GMMTrainer::~GMMTrainer()
-{
-}
+bob::learn::misc::GMMBaseTrainer::~GMMBaseTrainer()
+{}
 
-void bob::learn::misc::GMMTrainer::initialize(bob::learn::misc::GMMMachine& gmm,
-  const blitz::Array<double,2>& data)
+void bob::learn::misc::GMMBaseTrainer::initialize(bob::learn::misc::GMMMachine& gmm)
 {
   // Allocate memory for the sufficient statistics and initialise
   m_ss.resize(gmm.getNGaussians(),gmm.getNInputs());
 }
 
-void bob::learn::misc::GMMTrainer::eStep(bob::learn::misc::GMMMachine& gmm,
+void bob::learn::misc::GMMBaseTrainer::eStep(bob::learn::misc::GMMMachine& gmm,
   const blitz::Array<double,2>& data)
 {
   m_ss.init();
@@ -45,23 +39,17 @@ void bob::learn::misc::GMMTrainer::eStep(bob::learn::misc::GMMMachine& gmm,
   gmm.accStatistics(data, m_ss);
 }
 
-double bob::learn::misc::GMMTrainer::computeLikelihood(bob::learn::misc::GMMMachine& gmm)
+double bob::learn::misc::GMMBaseTrainer::computeLikelihood(bob::learn::misc::GMMMachine& gmm)
 {
   return m_ss.log_likelihood / m_ss.T;
 }
 
-void bob::learn::misc::GMMTrainer::finalize(bob::learn::misc::GMMMachine& gmm,
-  const blitz::Array<double,2>& data)
-{
-}
 
-bob::learn::misc::GMMTrainer& bob::learn::misc::GMMTrainer::operator=
-  (const bob::learn::misc::GMMTrainer &other)
+bob::learn::misc::GMMBaseTrainer& bob::learn::misc::GMMBaseTrainer::operator=
+  (const bob::learn::misc::GMMBaseTrainer &other)
 {
   if (this != &other)
   {
-    bob::learn::misc::EMTrainer<bob::learn::misc::GMMMachine,
-      blitz::Array<double,2> >::operator=(other);
     m_ss = other.m_ss;
     m_update_means = other.m_update_means;
     m_update_variances = other.m_update_variances;
@@ -71,32 +59,27 @@ bob::learn::misc::GMMTrainer& bob::learn::misc::GMMTrainer::operator=
   return *this;
 }
 
-bool bob::learn::misc::GMMTrainer::operator==
-  (const bob::learn::misc::GMMTrainer &other) const
+bool bob::learn::misc::GMMBaseTrainer::operator==
+  (const bob::learn::misc::GMMBaseTrainer &other) const
 {
-  return bob::learn::misc::EMTrainer<bob::learn::misc::GMMMachine,
-           blitz::Array<double,2> >::operator==(other) &&
-         m_ss == other.m_ss &&
+  return m_ss == other.m_ss &&
          m_update_means == other.m_update_means &&
          m_update_variances == other.m_update_variances &&
          m_update_weights == other.m_update_weights &&
          m_mean_var_update_responsibilities_threshold == other.m_mean_var_update_responsibilities_threshold;
 }
 
-bool bob::learn::misc::GMMTrainer::operator!=
-  (const bob::learn::misc::GMMTrainer &other) const
+bool bob::learn::misc::GMMBaseTrainer::operator!=
+  (const bob::learn::misc::GMMBaseTrainer &other) const
 {
   return !(this->operator==(other));
 }
 
-bool bob::learn::misc::GMMTrainer::is_similar_to
-  (const bob::learn::misc::GMMTrainer &other, const double r_epsilon,
+bool bob::learn::misc::GMMBaseTrainer::is_similar_to
+  (const bob::learn::misc::GMMBaseTrainer &other, const double r_epsilon,
    const double a_epsilon) const
 {
-  return bob::learn::misc::EMTrainer<bob::learn::misc::GMMMachine,
-           blitz::Array<double,2> >::operator==(other) &&
-  // TODO: use is similar to method for the accumulator m_ss
-         m_ss == other.m_ss &&
+  return m_ss == other.m_ss &&
          m_update_means == other.m_update_means &&
          m_update_variances == other.m_update_variances &&
          m_update_weights == other.m_update_weights &&
@@ -104,7 +87,7 @@ bool bob::learn::misc::GMMTrainer::is_similar_to
           other.m_mean_var_update_responsibilities_threshold, r_epsilon, a_epsilon);
 }
 
-void bob::learn::misc::GMMTrainer::setGMMStats(const bob::learn::misc::GMMStats& stats)
+void bob::learn::misc::GMMBaseTrainer::setGMMStats(const bob::learn::misc::GMMStats& stats)
 {
   bob::core::array::assertSameShape(m_ss.sumPx, stats.sumPx);
   m_ss = stats;
diff --git a/bob/learn/misc/gmm_base_trainer.cpp b/bob/learn/misc/gmm_base_trainer.cpp
new file mode 100644
index 0000000..5f133af
--- /dev/null
+++ b/bob/learn/misc/gmm_base_trainer.cpp
@@ -0,0 +1,348 @@
+/**
+ * @author Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
+ * @date Wed 21 Jan 12:30:00 2015
+ *
+ * @brief Python API for bob::learn::em
+ *
+ * Copyright (C) 2011-2014 Idiap Research Institute, Martigny, Switzerland
+ */
+
+#include "main.h"
+
+/******************************************************************/
+/************ Constructor Section *********************************/
+/******************************************************************/
+
+static inline bool f(PyObject* o){ if (!o) return false; return PyObject_IsTrue(o) > 0; }  /* truth value of a Python object; a NULL pointer (or a failed truth test) maps to false */
+
+static auto GMMBaseTrainer_doc = bob::extension::ClassDoc(
+  BOB_EXT_MODULE_PREFIX ".GMMBaseTrainer",
+  // Adjacent string literals are concatenated verbatim, hence the trailing space below.
+  "This class implements the E-step of the expectation-maximisation "
+  "algorithm for a :py:class:`bob.learn.misc.GMMMachine`"
+).add_constructor(
+  bob::extension::FunctionDoc(
+    "__init__",
+    "Creates a GMMBaseTrainer",
+    "",
+    true
+  )
+  .add_prototype("update_means, [update_variances], [update_weights], [mean_var_update_responsibilities_threshold]","")
+  .add_prototype("other","")
+  .add_prototype("","")
+  .add_parameter("update_means", "bool", "Update means on each iteration")
+  .add_parameter("update_variances", "bool", "Update variances on each iteration")
+  .add_parameter("update_weights", "bool", "Update weights on each iteration")
+  .add_parameter("mean_var_update_responsibilities_threshold", "float", "Threshold over the responsibilities of the Gaussians Equations 9.24, 9.25 of Bishop, `Pattern recognition and machine learning`, 2006 require a division by the responsibilities, which might be equal to zero because of numerical issue. This threshold is used to avoid such divisions.")
+  .add_parameter("other", ":py:class:`bob.learn.misc.GMMBaseTrainer`", "A GMMBaseTrainer object to be copied.")
+);
+
+
+
+static int PyBobLearnMiscGMMBaseTrainer_init_copy(PyBobLearnMiscGMMBaseTrainerObject* self, PyObject* args, PyObject* kwargs) {
+  // Copy-construction path: the single argument must be another GMMBaseTrainer.
+  PyBobLearnMiscGMMBaseTrainerObject* source;
+  char** keywords = GMMBaseTrainer_doc.kwlist(1);
+  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!", keywords, &PyBobLearnMiscGMMBaseTrainer_Type, &source)){
+    GMMBaseTrainer_doc.print_usage();
+    return -1;
+  }
+  // Copy-construct a new C++ trainer from the one wrapped by `source`.
+  self->cxx.reset(new bob::learn::misc::GMMBaseTrainer(*source->cxx));
+  return 0;
+}
+
+
+static int PyBobLearnMiscGMMBaseTrainer_init_bool(PyBobLearnMiscGMMBaseTrainerObject* self, PyObject* args, PyObject* kwargs) {
+  // Flag-based construction: update_means is mandatory, the remaining arguments are optional.
+  char** kwlist = GMMBaseTrainer_doc.kwlist(0);
+  PyObject* update_means     = 0;
+  PyObject* update_variances = 0;  // omitted -> f() yields false
+  PyObject* update_weights   = 0;  // omitted -> f() yields false
+  // Use the same default as the C++ constructor instead of 0 when the threshold is omitted.
+  double mean_var_update_responsibilities_threshold = std::numeric_limits<double>::epsilon();
+  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!|O!O!d", kwlist, &PyBool_Type, &update_means, &PyBool_Type,
+                                                             &update_variances, &PyBool_Type, &update_weights, &mean_var_update_responsibilities_threshold)){
+    GMMBaseTrainer_doc.print_usage();
+    return -1;
+  }
+  self->cxx.reset(new bob::learn::misc::GMMBaseTrainer(f(update_means), f(update_variances), f(update_weights), mean_var_update_responsibilities_threshold));
+  return 0;
+}
+
+
+static int PyBobLearnMiscGMMBaseTrainer_init(PyBobLearnMiscGMMBaseTrainerObject* self, PyObject* args, PyObject* kwargs) {
+  BOB_TRY
+  // tp_init dispatcher: picks default, copy or flag-based construction from the arguments.
+  int nargs = (args?PyTuple_Size(args):0) + (kwargs?PyDict_Size(kwargs):0);
+
+  if (nargs==0){ //default initializer ()
+    self->cxx.reset(new bob::learn::misc::GMMBaseTrainer());
+    return 0;
+  }
+  else{
+    // Peek at the first argument (positional if any, otherwise the first keyword value)
+    PyObject* arg = 0;
+    if (PyTuple_Size(args))
+      arg = PyTuple_GET_ITEM(args, 0);
+    else {
+      PyObject* tmp = PyDict_Values(kwargs);
+      auto tmp_ = make_safe(tmp);
+      arg = PyList_GET_ITEM(tmp, 0);
+    }
+
+    // If the constructor input is GMMBaseTrainer object
+    if (PyBobLearnMiscGMMBaseTrainer_Check(arg))
+      return PyBobLearnMiscGMMBaseTrainer_init_copy(self, args, kwargs);
+    else
+      return PyBobLearnMiscGMMBaseTrainer_init_bool(self, args, kwargs);
+  }
+  // tp_init must report failure with -1; returning 0 would signal success with an exception pending.
+  BOB_CATCH_MEMBER("cannot create GMMBaseTrainer", -1)
+  return 0;
+}
+
+
+static void PyBobLearnMiscGMMBaseTrainer_delete(PyBobLearnMiscGMMBaseTrainerObject* self) {
+  self->cxx.reset();  // drop this wrapper's reference to the C++ trainer first
+  Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));  // then hand the object back to its type's tp_free
+}
+
+
+int PyBobLearnMiscGMMBaseTrainer_Check(PyObject* o) {
+  return PyObject_IsInstance(o, (PyObject*)&PyBobLearnMiscGMMBaseTrainer_Type);  /* forwards PyObject_IsInstance's result unchanged */
+}
+
+
+static PyObject* PyBobLearnMiscGMMBaseTrainer_RichCompare(PyBobLearnMiscGMMBaseTrainerObject* self, PyObject* other, int op) {
+  BOB_TRY
+  // Only trainer-to-trainer comparisons are supported; anything else raises TypeError.
+  if (!PyBobLearnMiscGMMBaseTrainer_Check(other)) {
+    PyErr_Format(PyExc_TypeError, "cannot compare `%s' with `%s'", Py_TYPE(self)->tp_name, Py_TYPE(other)->tp_name);
+    return 0;
+  }
+  // Every operator other than == and != is left to Python (NotImplemented).
+  if (op != Py_EQ && op != Py_NE) {
+    Py_INCREF(Py_NotImplemented);
+    return Py_NotImplemented;
+  }
+  auto rhs = reinterpret_cast<PyBobLearnMiscGMMBaseTrainerObject*>(other);
+  const bool same = (*self->cxx == *rhs->cxx);
+  PyObject* answer = ((op == Py_EQ) == same) ? Py_True : Py_False;  // != is the negated equality test
+  Py_INCREF(answer);
+  return answer;
+  BOB_CATCH_MEMBER("cannot compare GMMBaseTrainer objects", 0)
+}
+
+
+/******************************************************************/
+/************ Variables Section ***********************************/
+/******************************************************************/
+
+
+/***** gmm_stats *****/
+static auto gmm_stats = bob::extension::VariableDoc(
+  "gmm_stats",
+  ":py:class:`bob.learn.misc.GMMStats`",
+  "Get/Set GMMStats",
+  ""
+);
+PyObject* PyBobLearnMiscGMMBaseTrainer_getGMMStats(PyBobLearnMiscGMMBaseTrainerObject* self, void*){
+  BOB_TRY
+
+  // Hand out a snapshot: copy the trainer's current statistics into a heap
+  // object that the returned Python wrapper shares ownership of.
+  boost::shared_ptr<bob::learn::misc::GMMStats> stats_shared(
+    new bob::learn::misc::GMMStats(self->cxx->getGMMStats()));
+
+  // Allocate the corresponding Python object
+  PyBobLearnMiscGMMStatsObject* retval =
+    (PyBobLearnMiscGMMStatsObject*)PyBobLearnMiscGMMStats_Type.tp_alloc(&PyBobLearnMiscGMMStats_Type, 0);
+  if (!retval) return 0;  // allocation failed
+  retval->cxx = stats_shared;
+  // retval is already a new reference; Py_BuildValue("O", ...) would add a second one and leak it.
+  return reinterpret_cast<PyObject*>(retval);
+  BOB_CATCH_MEMBER("GMMStats could not be read", 0)
+}
+int PyBobLearnMiscGMMBaseTrainer_setGMMStats(PyBobLearnMiscGMMBaseTrainerObject* self, PyObject* value, void*){
+  BOB_TRY
+  // value is NULL when the attribute is being deleted; reject that like any other non-GMMStats input.
+  if (!value || !PyBobLearnMiscGMMStats_Check(value)){
+    PyErr_Format(PyExc_RuntimeError, "%s %s expects a :py:class:`bob.learn.misc.GMMStats`", Py_TYPE(self)->tp_name, gmm_stats.name());
+    return -1;
+  }
+
+  PyBobLearnMiscGMMStatsObject* stats = 0;
+  // Bail out if the conversion fails instead of dereferencing a null pointer below.
+  if (!PyArg_Parse(value, "O!", &PyBobLearnMiscGMMStats_Type,&stats)) return -1;
+
+  self->cxx->setGMMStats(*stats->cxx);
+  return 0;
+  BOB_CATCH_MEMBER("gmm_stats could not be set", -1)
+}
+
+
+
+static PyGetSetDef PyBobLearnMiscGMMBaseTrainer_getseters[] = { // attribute table installed as tp_getset
+  {
+    gmm_stats.name(), // attribute name taken from the VariableDoc
+    (getter)PyBobLearnMiscGMMBaseTrainer_getGMMStats,
+    (setter)PyBobLearnMiscGMMBaseTrainer_setGMMStats,
+    gmm_stats.doc(), // attribute docstring
+    0 // no closure data
+  },
+  {0}  // Sentinel
+};
+
+
+/******************************************************************/
+/************ Functions Section ***********************************/
+/******************************************************************/
+
+/*** initialize ***/
+static auto initialize = bob::extension::FunctionDoc(
+  "initialize",
+  "Initialization before the EM steps",
+  "Instanciate :py:class:`bob.learn.misc.GMMStats`",
+  true
+)
+.add_prototype("gmm_machine")
+.add_parameter("gmm_machine", ":py:class:`bob.learn.misc.GMMMachine`", "GMMMachine Object");
+static PyObject* PyBobLearnMiscGMMBaseTrainer_initialize(PyBobLearnMiscGMMBaseTrainerObject* self, PyObject* args, PyObject* kwargs) {
+  BOB_TRY
+  /* Parses input arguments in a single shot */
+  char** kwlist = initialize.kwlist(0);
+
+  PyBobLearnMiscGMMMachineObject* gmm_machine = 0;
+
+  // On a parse failure an exception is already set, so signal the error with NULL rather than None.
+  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!", kwlist, &PyBobLearnMiscGMMMachine_Type, &gmm_machine)) return 0;
+
+  self->cxx->initialize(*gmm_machine->cxx);  // sizes the trainer's statistics to match the machine
+
+  BOB_CATCH_MEMBER("cannot perform the initialize method", 0)
+
+  Py_RETURN_NONE;
+}
+
+
+
+/*** eStep ***/
+static auto eStep = bob::extension::FunctionDoc(
+  "eStep",
+  "Calculates and saves statistics across the dataset, "
+  "and saves these as m_ss. ",
+
+  "Calculates the average log likelihood of the observations given the GMM, "
+  "and returns this in average_log_likelihood. "
+  "The statistics, m_ss, will be used in the mStep() that follows.",
+
+  true
+)
+.add_prototype("gmm_machine,data")
+.add_parameter("gmm_machine", ":py:class:`bob.learn.misc.GMMMachine`", "GMMMachine Object")
+.add_parameter("data", "array_like <float, 2D>", "Input data");
+static PyObject* PyBobLearnMiscGMMBaseTrainer_eStep(PyBobLearnMiscGMMBaseTrainerObject* self, PyObject* args, PyObject* kwargs) {
+  BOB_TRY
+  /* Parses input arguments in a single shot */
+  char** kwlist = eStep.kwlist(0);
+
+  PyBobLearnMiscGMMMachineObject* gmm_machine;
+  PyBlitzArrayObject* data = 0;
+  // On a parse failure an exception is already set, so signal the error with NULL rather than None.
+  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!O&", kwlist, &PyBobLearnMiscGMMMachine_Type, &gmm_machine,
+                                                                 &PyBlitzArray_Converter, &data)) return 0;
+  auto data_ = make_safe(data);
+  // NOTE(review): nothing here checks that `data` is a 2D float64 array before the cast below - confirm the converter guarantees it.
+  self->cxx->eStep(*gmm_machine->cxx, *PyBlitzArrayCxx_AsBlitz<double,2>(data));
+
+  BOB_CATCH_MEMBER("cannot perform the eStep method", 0)
+
+  Py_RETURN_NONE;
+}
+
+
+/*** computeLikelihood ***/
+static auto compute_likelihood = bob::extension::FunctionDoc(
+  "compute_likelihood",
+  "This function returns the average log-likelihood of the observations, computed from the statistics accumulated by the last call to eStep",
+  0,
+  true
+)
+.add_prototype("gmm_machine")
+.add_parameter("gmm_machine", ":py:class:`bob.learn.misc.GMMMachine`", "GMMMachine Object");
+static PyObject* PyBobLearnMiscGMMBaseTrainer_compute_likelihood(PyBobLearnMiscGMMBaseTrainerObject* self, PyObject* args, PyObject* kwargs) {
+  BOB_TRY
+
+  /* Parses input arguments in a single shot */
+  char** kwlist = compute_likelihood.kwlist(0);
+  PyBobLearnMiscGMMMachineObject* gmm_machine;
+  // On a parse failure an exception is already set, so signal the error with NULL rather than None.
+  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!", kwlist, &PyBobLearnMiscGMMMachine_Type, &gmm_machine)) return 0;
+
+  double value = self->cxx->computeLikelihood(*gmm_machine->cxx);
+  return Py_BuildValue("d", value);
+
+  BOB_CATCH_MEMBER("cannot perform the computeLikelihood method", 0)
+}
+
+
+static PyMethodDef PyBobLearnMiscGMMBaseTrainer_methods[] = { // method table installed as tp_methods
+  {
+    initialize.name(), // Python-visible name comes from the FunctionDoc
+    (PyCFunction)PyBobLearnMiscGMMBaseTrainer_initialize,
+    METH_VARARGS|METH_KEYWORDS, // accepts positional and keyword arguments
+    initialize.doc()
+  },
+  {
+    eStep.name(),
+    (PyCFunction)PyBobLearnMiscGMMBaseTrainer_eStep,
+    METH_VARARGS|METH_KEYWORDS,
+    eStep.doc()
+  },
+  {
+    compute_likelihood.name(),
+    (PyCFunction)PyBobLearnMiscGMMBaseTrainer_compute_likelihood, // also installed as tp_call by the init function
+    METH_VARARGS|METH_KEYWORDS,
+    compute_likelihood.doc()
+  },
+  {0} /* Sentinel */
+};
+
+
+/******************************************************************/
+/************ Module Section **************************************/
+/******************************************************************/
+
+// Define the GMMBaseTrainer type struct; its slots are filled in by init_BobLearnMiscGMMBaseTrainer
+PyTypeObject PyBobLearnMiscGMMBaseTrainer_Type = {
+  PyVarObject_HEAD_INIT(0,0)
+  0
+};
+
+bool init_BobLearnMiscGMMBaseTrainer(PyObject* module)
+{
+  PyTypeObject& trainer_type = PyBobLearnMiscGMMBaseTrainer_Type;
+  // name, size, flags and documentation of the Python type
+  trainer_type.tp_name      = GMMBaseTrainer_doc.name();
+  trainer_type.tp_basicsize = sizeof(PyBobLearnMiscGMMBaseTrainerObject);
+  trainer_type.tp_flags     = Py_TPFLAGS_DEFAULT;
+  trainer_type.tp_doc       = GMMBaseTrainer_doc.doc();
+
+  // slots: construction/destruction, comparison, methods, attributes, call
+  trainer_type.tp_new         = PyType_GenericNew;
+  trainer_type.tp_init        = reinterpret_cast<initproc>(PyBobLearnMiscGMMBaseTrainer_init);
+  trainer_type.tp_dealloc     = reinterpret_cast<destructor>(PyBobLearnMiscGMMBaseTrainer_delete);
+  trainer_type.tp_richcompare = reinterpret_cast<richcmpfunc>(PyBobLearnMiscGMMBaseTrainer_RichCompare);
+  trainer_type.tp_methods     = PyBobLearnMiscGMMBaseTrainer_methods;
+  trainer_type.tp_getset      = PyBobLearnMiscGMMBaseTrainer_getseters;
+  trainer_type.tp_call        = reinterpret_cast<ternaryfunc>(PyBobLearnMiscGMMBaseTrainer_compute_likelihood);
+
+  // finalise the type; give up if that fails
+  if (PyType_Ready(&trainer_type) < 0) return false;
+
+  // take a reference on the type, then publish it in the module under "GMMBaseTrainer"
+  Py_INCREF(&trainer_type);
+  return PyModule_AddObject(module, "GMMBaseTrainer", (PyObject*)&trainer_type) >= 0;
+}
+
diff --git a/bob/learn/misc/include/bob.learn.misc/GMMTrainer.h b/bob/learn/misc/include/bob.learn.misc/GMMBaseTrainer.h
similarity index 76%
rename from bob/learn/misc/include/bob.learn.misc/GMMTrainer.h
rename to bob/learn/misc/include/bob.learn.misc/GMMBaseTrainer.h
index fbbef1b..1cb0788 100644
--- a/bob/learn/misc/include/bob.learn.misc/GMMTrainer.h
+++ b/bob/learn/misc/include/bob.learn.misc/GMMBaseTrainer.h
@@ -8,10 +8,9 @@
  * Copyright (C) Idiap Research Institute, Martigny, Switzerland
  */
 
-#ifndef BOB_LEARN_MISC_GMMTRAINER_H
-#define BOB_LEARN_MISC_GMMTRAINER_H
+#ifndef BOB_LEARN_MISC_GMMBASETRAINER_H
+#define BOB_LEARN_MISC_GMMBASETRAINER_H
 
-#include <bob.learn.misc/EMTrainer.h>
 #include <bob.learn.misc/GMMMachine.h>
 #include <bob.learn.misc/GMMStats.h>
 #include <limits>
@@ -24,13 +23,13 @@ namespace bob { namespace learn { namespace misc {
  * @details See Section 9.2.2 of Bishop,
  *   "Pattern recognition and machine learning", 2006
  */
-class GMMTrainer: public EMTrainer<bob::learn::misc::GMMMachine, blitz::Array<double,2> >
+class GMMBaseTrainer
 {
   public:
     /**
      * @brief Default constructor
      */
-    GMMTrainer(const bool update_means=true,
+    GMMBaseTrainer(const bool update_means=true,
       const bool update_variances=false, const bool update_weights=false,
       const double mean_var_update_responsibilities_threshold =
         std::numeric_limits<double>::epsilon());
@@ -38,18 +37,17 @@ class GMMTrainer: public EMTrainer<bob::learn::misc::GMMMachine, blitz::Array<do
     /**
      * @brief Copy constructor
      */
-    GMMTrainer(const GMMTrainer& other);
+    GMMBaseTrainer(const GMMBaseTrainer& other);
 
     /**
      * @brief Destructor
      */
-    virtual ~GMMTrainer();
+    virtual ~GMMBaseTrainer();
 
     /**
      * @brief Initialization before the EM steps
      */
-    virtual void initialize(bob::learn::misc::GMMMachine& gmm,
-      const blitz::Array<double,2>& data);
+    virtual void initialize(bob::learn::misc::GMMMachine& gmm);
 
     /**
      * @brief Calculates and saves statistics across the dataset,
@@ -69,38 +67,33 @@ class GMMTrainer: public EMTrainer<bob::learn::misc::GMMMachine, blitz::Array<do
      */
     virtual double computeLikelihood(bob::learn::misc::GMMMachine& gmm);
 
-    /**
-     * @brief Finalization after the EM steps
-     */
-    virtual void finalize(bob::learn::misc::GMMMachine& gmm,
-      const blitz::Array<double,2>& data);
 
     /**
-     * @brief Assigns from a different GMMTrainer
+     * @brief Assigns from a different GMMBaseTrainer
      */
-    GMMTrainer& operator=(const GMMTrainer &other);
+    GMMBaseTrainer& operator=(const GMMBaseTrainer &other);
 
     /**
      * @brief Equal to
      */
-    bool operator==(const GMMTrainer& b) const;
+    bool operator==(const GMMBaseTrainer& b) const;
 
     /**
      * @brief Not equal to
      */
-    bool operator!=(const GMMTrainer& b) const;
+    bool operator!=(const GMMBaseTrainer& b) const;
 
     /**
      * @brief Similar to
      */
-    bool is_similar_to(const GMMTrainer& b, const double r_epsilon=1e-5,
+    bool is_similar_to(const GMMBaseTrainer& b, const double r_epsilon=1e-5,
       const double a_epsilon=1e-8) const;
 
     /**
      * @brief Returns the internal GMM statistics. Useful to parallelize the
      * E-step
      */
-    const bob::learn::misc::GMMStats& getGMMStats() const
+    const bob::learn::misc::GMMStats getGMMStats() const
     { return m_ss; }
 
     /**
@@ -109,7 +102,7 @@ class GMMTrainer: public EMTrainer<bob::learn::misc::GMMMachine, blitz::Array<do
      */
     void setGMMStats(const bob::learn::misc::GMMStats& stats);
 
-  protected:
+  private:
     /**
      * These are the sufficient statistics, calculated during the
      * E-step and used during the M-step
@@ -142,4 +135,4 @@ class GMMTrainer: public EMTrainer<bob::learn::misc::GMMMachine, blitz::Array<do
 
 } } } // namespaces
 
-#endif // BOB_LEARN_MISC_GMMTRAINER_H
+#endif // BOB_LEARN_MISC_GMMBASETRAINER_H
diff --git a/bob/learn/misc/kmeans_trainer.cpp b/bob/learn/misc/kmeans_trainer.cpp
index 2c1111b..cbaac80 100644
--- a/bob/learn/misc/kmeans_trainer.cpp
+++ b/bob/learn/misc/kmeans_trainer.cpp
@@ -540,7 +540,7 @@ bool init_BobLearnMiscKMeansTrainer(PyObject* module)
   PyBobLearnMiscKMeansTrainer_Type.tp_richcompare = reinterpret_cast<richcmpfunc>(PyBobLearnMiscKMeansTrainer_RichCompare);
   PyBobLearnMiscKMeansTrainer_Type.tp_methods = PyBobLearnMiscKMeansTrainer_methods;
   PyBobLearnMiscKMeansTrainer_Type.tp_getset = PyBobLearnMiscKMeansTrainer_getseters;
-  PyBobLearnMiscGMMMachine_Type.tp_call = reinterpret_cast<ternaryfunc>(PyBobLearnMiscKMeansTrainer_compute_likelihood);
+  PyBobLearnMiscKMeansTrainer_Type.tp_call = reinterpret_cast<ternaryfunc>(PyBobLearnMiscKMeansTrainer_compute_likelihood);
 
 
   // check that everything is fine
diff --git a/bob/learn/misc/main.cpp b/bob/learn/misc/main.cpp
index fb8b2a5..0431832 100644
--- a/bob/learn/misc/main.cpp
+++ b/bob/learn/misc/main.cpp
@@ -43,8 +43,9 @@ static PyObject* create_module (void) {
   if (!init_BobLearnMiscGaussian(module)) return 0;
   if (!init_BobLearnMiscGMMStats(module)) return 0;
   if (!init_BobLearnMiscGMMMachine(module)) return 0;
-  if (!init_BobLearnMiscKMeansMachine(module)) return 0; 
-  if (!init_BobLearnMiscKMeansTrainer(module)) return 0;  
+  if (!init_BobLearnMiscKMeansMachine(module)) return 0;
+  if (!init_BobLearnMiscKMeansTrainer(module)) return 0;
+  if (!init_BobLearnMiscGMMBaseTrainer(module)) return 0;  
 
 
   static void* PyBobLearnMisc_API[PyBobLearnMisc_API_pointers];
diff --git a/bob/learn/misc/main.h b/bob/learn/misc/main.h
index a6efaa7..cffea0e 100644
--- a/bob/learn/misc/main.h
+++ b/bob/learn/misc/main.h
@@ -23,6 +23,9 @@
 #include <bob.learn.misc/KMeansMachine.h>
 #include <bob.learn.misc/KMeansTrainer.h>
 
+#include <bob.learn.misc/GMMBaseTrainer.h>
+
+
 #if PY_VERSION_HEX >= 0x03000000
 #define PyInt_Check PyLong_Check
 #define PyInt_AS_LONG PyLong_AS_LONG
@@ -116,5 +119,16 @@ bool init_BobLearnMiscKMeansTrainer(PyObject* module);
 int PyBobLearnMiscKMeansTrainer_Check(PyObject* o);
 
 
+// GMMBaseTrainer: Python object layout and the symbols defined in gmm_base_trainer.cpp
+typedef struct {
+  PyObject_HEAD
+  boost::shared_ptr<bob::learn::misc::GMMBaseTrainer> cxx;  // wrapped C++ trainer
+} PyBobLearnMiscGMMBaseTrainerObject;
+
+extern PyTypeObject PyBobLearnMiscGMMBaseTrainer_Type;  // slots are filled in by init_BobLearnMiscGMMBaseTrainer
+bool init_BobLearnMiscGMMBaseTrainer(PyObject* module);  // readies the type and adds it to `module`; false on failure
+int PyBobLearnMiscGMMBaseTrainer_Check(PyObject* o);  // result of PyObject_IsInstance against the type above
+
+
 
 #endif // BOB_LEARN_EM_MAIN_H
diff --git a/setup.py b/setup.py
index 57f2a77..9bc5e23 100644
--- a/setup.py
+++ b/setup.py
@@ -65,7 +65,7 @@ setup(
           #"bob/learn/misc/cpp/ZTNorm.cpp",
 
           #"bob/learn/misc/cpp/EMPCATrainer.cpp",
-          #"bob/learn/misc/cpp/GMMTrainer.cpp",
+          "bob/learn/misc/cpp/GMMBaseTrainer.cpp",
           #"bob/learn/misc/cpp/IVectorTrainer.cpp",
           #"bob/learn/misc/cpp/JFATrainer.cpp",
           "bob/learn/misc/cpp/KMeansTrainer.cpp",
@@ -105,6 +105,7 @@ setup(
           "bob/learn/misc/gmm_machine.cpp",
           "bob/learn/misc/kmeans_machine.cpp",
           "bob/learn/misc/kmeans_trainer.cpp",
+          "bob/learn/misc/gmm_base_trainer.cpp",
 
           "bob/learn/misc/main.cpp",
         ],
-- 
GitLab