From d9093beed5093f0892c7049acec3ae1c09d7691f Mon Sep 17 00:00:00 2001
From: Tiago Freitas Pereira <tiagofrepereira@gmail.com>
Date: Wed, 4 Feb 2015 15:30:20 +0100
Subject: [PATCH] Binding PLDATrainer

---
 bob/learn/misc/__empca_trainer__.py           |  77 +++
 bob/learn/misc/cpp/PLDATrainer.cpp            |  37 +-
 bob/learn/misc/empca_trainer.cpp              |   3 +-
 .../misc/include/bob.learn.misc/PLDATrainer.h |  31 +-
 bob/learn/misc/main.cpp                       |   1 +
 bob/learn/misc/main.h                         |  16 +
 bob/learn/misc/plda_trainer.cpp               | 458 ++++++++++++++++++
 setup.py                                      |   4 +-
 8 files changed, 588 insertions(+), 39 deletions(-)
 create mode 100644 bob/learn/misc/__empca_trainer__.py
 create mode 100644 bob/learn/misc/plda_trainer.cpp

diff --git a/bob/learn/misc/__empca_trainer__.py b/bob/learn/misc/__empca_trainer__.py
new file mode 100644
index 0000000..28f4c9d
--- /dev/null
+++ b/bob/learn/misc/__empca_trainer__.py
@@ -0,0 +1,77 @@
+#!/usr/bin/env python
+# vim: set fileencoding=utf-8 :
+# Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
+# Wed Fev 04 13:35:10 2015 +0200
+#
+# Copyright (C) 2011-2015 Idiap Research Institute, Martigny, Switzerland
+
+from ._library import _EMPCATrainer
+import numpy
+
+# define the class
+class EMPCATrainer (_EMPCATrainer):
+
+  def __init__(self, convergence_threshold=0.001, max_iterations=10, compute_likelihood=True):
+    """
+    :py:class:`bob.learn.misc.EMPCATrainer` constructor
+
+    Keyword Parameters:
+      convergence_threshold
+        Convergence threshold
+      max_iterations
+        Number of maximum iterations
+      compute_likelihood
+        
+    """
+
+    _EMPCATrainer.__init__(self,convergence_threshold)
+    self._max_iterations        = max_iterations
+    self._compute_likelihood    = compute_likelihood
+
+
+  def train(self, linear_machine, data):
+    """
+    Train the :py:class:bob.learn.misc.LinearMachine using data
+
+    Keyword Parameters:
+      linear_machine
+        The :py:class:bob.learn.misc.LinearMachine class
+      data
+        The data to be trained
+    """
+
+    #Initialization
+    self.initialize(linear_machine, data);
+      
+    #Do the Expectation-Maximization algorithm
+    average_output_previous = 0
+    average_output = -numpy.inf;
+
+    #eStep
+    self.eStep(linear_machine, data);
+
+    if(self._compute_likelihood):
+      average_output = self.compute_likelihood(linear_machine);
+    
+    for i in range(self._max_iterations):
+
+      #saves average output from last iteration
+      average_output_previous = average_output;
+
+      #mStep
+      self.mStep(linear_machine);
+
+      #eStep
+      self.eStep(linear_machine, data);
+
+      #Computes log likelihood if required
+      if(self._compute_likelihood):
+        average_output = self.compute_likelihood(linear_machine);
+
+        #Terminates if converged (and likelihood computation is set)
+        if abs((average_output_previous - average_output)/average_output_previous) <= self._convergence_threshold:
+          break
+
+
+# copy the documentation from the base class
+__doc__ = _EMPCATrainer.__doc__
diff --git a/bob/learn/misc/cpp/PLDATrainer.cpp b/bob/learn/misc/cpp/PLDATrainer.cpp
index 725dd45..af50140 100644
--- a/bob/learn/misc/cpp/PLDATrainer.cpp
+++ b/bob/learn/misc/cpp/PLDATrainer.cpp
@@ -7,20 +7,26 @@
  * Copyright (C) Idiap Research Institute, Martigny, Switzerland
  */
 
+
 #include <bob.learn.misc/PLDATrainer.h>
+#include <bob.core/check.h>
 #include <bob.core/array_copy.h>
 #include <bob.core/array_random.h>
-#include <bob.math/linear.h>
 #include <bob.math/inv.h>
 #include <bob.math/svd.h>
+#include <bob.core/check.h>
+#include <bob.core/array_repmat.h>
 #include <algorithm>
-#include <vector>
 #include <limits>
+#include <vector>
+
+#include <bob.math/linear.h>
+#include <bob.math/linsolve.h>
+
+
 
-bob::learn::misc::PLDATrainer::PLDATrainer(const size_t max_iterations,
-    const bool use_sum_second_order):
-  EMTrainer<bob::learn::misc::PLDABase, std::vector<blitz::Array<double,2> > >
-    (0.001, max_iterations, false),
+bob::learn::misc::PLDATrainer::PLDATrainer(const bool use_sum_second_order):
+  m_rng(new boost::mt19937()),
   m_dim_d(0), m_dim_f(0), m_dim_g(0),
   m_use_sum_second_order(use_sum_second_order),
   m_initF_method(bob::learn::misc::PLDATrainer::RANDOM_F), m_initF_ratio(1.),
@@ -38,9 +44,7 @@ bob::learn::misc::PLDATrainer::PLDATrainer(const size_t max_iterations,
 }
 
 bob::learn::misc::PLDATrainer::PLDATrainer(const bob::learn::misc::PLDATrainer& other):
-  EMTrainer<bob::learn::misc::PLDABase, std::vector<blitz::Array<double,2> > >
-    (other.m_convergence_threshold, other.m_max_iterations,
-     other.m_compute_likelihood),
+  m_rng(other.m_rng),
   m_dim_d(other.m_dim_d), m_dim_f(other.m_dim_f), m_dim_g(other.m_dim_g),
   m_use_sum_second_order(other.m_use_sum_second_order),
   m_initF_method(other.m_initF_method), m_initF_ratio(other.m_initF_ratio),
@@ -71,8 +75,7 @@ bob::learn::misc::PLDATrainer& bob::learn::misc::PLDATrainer::operator=
 {
   if(this != &other)
   {
-    bob::learn::misc::EMTrainer<bob::learn::misc::PLDABase,
-      std::vector<blitz::Array<double,2> > >::operator=(other);
+    m_rng = m_rng,
     m_dim_d = other.m_dim_d;
     m_dim_f = other.m_dim_f;
     m_dim_g = other.m_dim_g;
@@ -102,8 +105,7 @@ bob::learn::misc::PLDATrainer& bob::learn::misc::PLDATrainer::operator=
 bool bob::learn::misc::PLDATrainer::operator==
   (const bob::learn::misc::PLDATrainer& other) const
 {
-  return bob::learn::misc::EMTrainer<bob::learn::misc::PLDABase,
-           std::vector<blitz::Array<double,2> > >::operator==(other) &&
+  return m_rng == m_rng &&
          m_dim_d == other.m_dim_d &&
          m_dim_f == other.m_dim_f &&
          m_dim_g == other.m_dim_g &&
@@ -138,8 +140,7 @@ bool bob::learn::misc::PLDATrainer::is_similar_to
   (const bob::learn::misc::PLDATrainer &other, const double r_epsilon,
    const double a_epsilon) const
 {
-  return bob::learn::misc::EMTrainer<bob::learn::misc::PLDABase,
-           std::vector<blitz::Array<double,2> > >::is_similar_to(other, r_epsilon, a_epsilon) &&
+  return m_rng == m_rng &&
          m_dim_d == other.m_dim_d &&
          m_dim_f == other.m_dim_f &&
          m_dim_g == other.m_dim_g &&
@@ -745,12 +746,6 @@ void bob::learn::misc::PLDATrainer::updateSigma(bob::learn::misc::PLDABase& mach
   machine.applyVarianceThreshold();
 }
 
-double bob::learn::misc::PLDATrainer::computeLikelihood(bob::learn::misc::PLDABase& machine)
-{
-  double llh = 0.;
-  // TODO: implement log likelihood computation
-  return llh;
-}
 
 void bob::learn::misc::PLDATrainer::enrol(bob::learn::misc::PLDAMachine& plda_machine,
   const blitz::Array<double,2>& ar) const
diff --git a/bob/learn/misc/empca_trainer.cpp b/bob/learn/misc/empca_trainer.cpp
index 6a34239..1b7a696 100644
--- a/bob/learn/misc/empca_trainer.cpp
+++ b/bob/learn/misc/empca_trainer.cpp
@@ -24,11 +24,12 @@ static auto EMPCATrainer_doc = bob::extension::ClassDoc(
     "",
     true
   )
-  .add_prototype("compute_likelihood","")
+  .add_prototype("convergence_threshold","")
   .add_prototype("other","")
   .add_prototype("","")
 
   .add_parameter("other", ":py:class:`bob.learn.misc.EMPCATrainer`", "A EMPCATrainer object to be copied.")
+  .add_parameter("convergence_threshold", "double", "")
 
 );
 
diff --git a/bob/learn/misc/include/bob.learn.misc/PLDATrainer.h b/bob/learn/misc/include/bob.learn.misc/PLDATrainer.h
index d80372e..3439490 100644
--- a/bob/learn/misc/include/bob.learn.misc/PLDATrainer.h
+++ b/bob/learn/misc/include/bob.learn.misc/PLDATrainer.h
@@ -11,11 +11,13 @@
 #ifndef BOB_LEARN_MISC_PLDA_TRAINER_H
 #define BOB_LEARN_MISC_PLDA_TRAINER_H
 
-#include <bob.learn.misc/EMTrainer.h>
 #include <bob.learn.misc/PLDAMachine.h>
-#include <blitz/array.h>
-#include <map>
+#include <boost/shared_ptr.hpp>
 #include <vector>
+#include <map>
+#include <bob.core/array_copy.h>
+#include <boost/random.hpp>
+#include <boost/random/mersenne_twister.hpp>
 
 namespace bob { namespace learn { namespace misc {
 
@@ -31,8 +33,7 @@ namespace bob { namespace learn { namespace misc {
  * 3. 'Probabilistic Models for Inference about Identity', Li, Fu, Mohammed,
  *     Elder and Prince, TPAMI'2012
  */
-class PLDATrainer: public EMTrainer<bob::learn::misc::PLDABase,
-                                        std::vector<blitz::Array<double,2> > >
+class PLDATrainer
 {
   public: //api
     /**
@@ -40,7 +41,7 @@ class PLDATrainer: public EMTrainer<bob::learn::misc::PLDABase,
      * training stage will place the resulting components in the
      * PLDABase.
      */
-    PLDATrainer(const size_t max_iterations=100, const bool use_sum_second_order=true);
+    PLDATrainer(const bool use_sum_second_order);
 
     /**
      * @brief Copy constructor
@@ -70,18 +71,18 @@ class PLDATrainer: public EMTrainer<bob::learn::misc::PLDABase,
     /**
      * @brief Similarity operator
      */
-    virtual bool is_similar_to(const PLDATrainer& b,
+    bool is_similar_to(const PLDATrainer& b,
       const double r_epsilon=1e-5, const double a_epsilon=1e-8) const;
 
     /**
      * @brief Performs some initialization before the E- and M-steps.
      */
-    virtual void initialize(bob::learn::misc::PLDABase& machine,
+    void initialize(bob::learn::misc::PLDABase& machine,
       const std::vector<blitz::Array<double,2> >& v_ar);
     /**
      * @brief Performs some actions after the end of the E- and M-steps.
       */
-    virtual void finalize(bob::learn::misc::PLDABase& machine,
+    void finalize(bob::learn::misc::PLDABase& machine,
       const std::vector<blitz::Array<double,2> >& v_ar);
 
     /**
@@ -89,21 +90,16 @@ class PLDATrainer: public EMTrainer<bob::learn::misc::PLDABase,
      * these as m_z_{first,second}_order.
      * The statistics will be used in the mStep() that follows.
      */
-    virtual void eStep(bob::learn::misc::PLDABase& machine,
+    void eStep(bob::learn::misc::PLDABase& machine,
       const std::vector<blitz::Array<double,2> >& v_ar);
 
     /**
      * @brief Performs a maximization step to update the parameters of the
      * PLDABase
      */
-    virtual void mStep(bob::learn::misc::PLDABase& machine,
+    void mStep(bob::learn::misc::PLDABase& machine,
        const std::vector<blitz::Array<double,2> >& v_ar);
 
-    /**
-     * @brief Computes the average log likelihood using the current estimates
-     * of the latent variables.
-     */
-    virtual double computeLikelihood(bob::learn::misc::PLDABase& machine);
 
     /**
      * @brief Sets whether the second order statistics are stored during the
@@ -223,6 +219,9 @@ class PLDATrainer: public EMTrainer<bob::learn::misc::PLDABase,
       const blitz::Array<double,2>& ar) const;
 
   private:
+  
+    boost::shared_ptr<boost::mt19937> m_rng;
+  
     //representation
     size_t m_dim_d; ///< Dimensionality of the input features
     size_t m_dim_f; ///< Size/rank of the \f$F\f$ subspace
diff --git a/bob/learn/misc/main.cpp b/bob/learn/misc/main.cpp
index a56592c..c2171ab 100644
--- a/bob/learn/misc/main.cpp
+++ b/bob/learn/misc/main.cpp
@@ -119,6 +119,7 @@ static PyObject* create_module (void) {
   if (import_bob_blitz() < 0) return 0;
   if (import_bob_core_random() < 0) return 0;
   if (import_bob_io_base() < 0) return 0;
+  //if (import_bob_learn_linear() < 0) return 0;
 
   Py_INCREF(module);
   return module;
diff --git a/bob/learn/misc/main.h b/bob/learn/misc/main.h
index c3f9d3d..36b2eed 100644
--- a/bob/learn/misc/main.h
+++ b/bob/learn/misc/main.h
@@ -12,7 +12,9 @@
 #include <bob.blitz/cleanup.h>
 #include <bob.core/random_api.h>
 #include <bob.io.base/api.h>
+
 #include <bob.learn.linear/api.h>
+
 #include <bob.extension/documentation.h>
 
 #define BOB_LEARN_EM_MODULE
@@ -43,6 +45,8 @@
 #include <bob.learn.misc/EMPCATrainer.h>
 
 #include <bob.learn.misc/PLDAMachine.h>
+#include <bob.learn.misc/PLDATrainer.h>
+
 #include <bob.learn.misc/ZTNorm.h>
 
 
@@ -282,6 +286,18 @@ bool init_BobLearnMiscPLDAMachine(PyObject* module);
 int PyBobLearnMiscPLDAMachine_Check(PyObject* o);
 
 
+// PLDATrainer
+typedef struct {
+  PyObject_HEAD
+  boost::shared_ptr<bob::learn::misc::PLDATrainer> cxx;
+} PyBobLearnMiscPLDATrainerObject;
+
+extern PyTypeObject PyBobLearnMiscPLDATrainer_Type;
+bool init_BobLearnMiscPLDATrainer(PyObject* module);
+int PyBobLearnMiscPLDATrainer_Check(PyObject* o);
+
+
+
 // EMPCATrainer
 typedef struct {
   PyObject_HEAD
diff --git a/bob/learn/misc/plda_trainer.cpp b/bob/learn/misc/plda_trainer.cpp
new file mode 100644
index 0000000..80010d1
--- /dev/null
+++ b/bob/learn/misc/plda_trainer.cpp
@@ -0,0 +1,458 @@
+/**
+ * @author Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
+ * @date Wed 04 Feb 14:15:00 2015
+ *
+ * @brief Python API for bob::learn::em
+ *
+ * Copyright (C) 2011-2014 Idiap Research Institute, Martigny, Switzerland
+ */
+
+#include "main.h"
+#include <boost/make_shared.hpp>
+
+/******************************************************************/
+/************ Constructor Section *********************************/
+/******************************************************************/
+
+static inline bool f(PyObject* o){return o != 0 && PyObject_IsTrue(o) > 0;}  /* converts PyObject to bool and returns false if object is NULL */
+
+template <int N>
+int list_as_vector(PyObject* list, std::vector<blitz::Array<double,N> >& vec)
+{
+  for (int i=0; i<PyList_GET_SIZE(list); i++)
+  {
+    PyBlitzArrayObject* blitz_object; 
+    if (!PyArg_Parse(PyList_GetItem(list, i), "O&", &PyBlitzArray_Converter, &blitz_object)){
+      PyErr_Format(PyExc_RuntimeError, "Expected numpy array object");
+      return -1;
+    }
+    auto blitz_object_ = make_safe(blitz_object);
+    vec.push_back(*PyBlitzArrayCxx_AsBlitz<double,N>(blitz_object));
+  }
+  return 0;
+}
+
+
+static auto PLDATrainer_doc = bob::extension::ClassDoc(
+  BOB_EXT_MODULE_PREFIX ".PLDATrainer",
+  "This class can be used to train the :math:`$F$`, :math:`$G$ and "
+  " :math:`$\\Sigma$` matrices and the mean vector :math:`$\\mu$` of a PLDA model.",
+  "References: [ElShafey2014,PrinceElder2007,LiFu2012]",
+).add_constructor(
+  bob::extension::FunctionDoc(
+    "__init__",
+    "Default constructor.\n Initializes a new PLDA trainer. The "
+    "training stage will place the resulting components in the "
+    "PLDABase.",
+    "",
+    true
+  )
+  .add_prototype("use_sum_second_order","")
+  .add_prototype("other","")
+  .add_prototype("","")
+
+  .add_parameter("other", ":py:class:`bob.learn.misc.PLDATrainer`", "A PLDATrainer object to be copied.")
+  .add_parameter("use_sum_second_order", "bool", "")
+);
+
+static int PyBobLearnMiscPLDATrainer_init_copy(PyBobLearnMiscPLDATrainerObject* self, PyObject* args, PyObject* kwargs) {
+
+  char** kwlist = PLDATrainer_doc.kwlist(1);
+  PyBobLearnMiscPLDATrainerObject* o;
+  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!", kwlist, &PyBobLearnMiscPLDATrainer_Type, &o)){
+    PLDATrainer_doc.print_usage();
+    return -1;
+  }
+
+  self->cxx.reset(new bob::learn::misc::PLDATrainer(*o->cxx));
+  return 0;
+}
+
+
+static int PyBobLearnMiscPLDATrainer_init_bool(PyBobLearnMiscPLDATrainerObject* self, PyObject* args, PyObject* kwargs) {
+
+  char** kwlist = PLDATrainer_doc.kwlist(0);
+  PyObject* use_sum_second_order;
+
+  //Parsing the input argments
+  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!", kwlist, &PyBool_Type, &use_sum_second_order))
+    return -1;
+
+  self->cxx.reset(new bob::learn::misc::PLDATrainer(f(use_sum_second_order)));
+  return 0;
+}
+
+
+static int PyBobLearnMiscPLDATrainer_init(PyBobLearnMiscPLDATrainerObject* self, PyObject* args, PyObject* kwargs) {
+  BOB_TRY
+
+  // get the number of command line arguments
+  int nargs = (args?PyTuple_Size(args):0) + (kwargs?PyDict_Size(kwargs):0);
+
+  switch(nargs){
+    case 0:{
+      self->cxx.reset(new bob::learn::misc::PLDATrainer());
+      return 0;
+    }
+    case 1:{
+      //Reading the input argument
+      PyObject* arg = 0;
+      if (PyTuple_Size(args))
+        arg = PyTuple_GET_ITEM(args, 0);
+      else {
+        PyObject* tmp = PyDict_Values(kwargs);
+        auto tmp_ = make_safe(tmp);
+        arg = PyList_GET_ITEM(tmp, 0);
+      }
+      
+      if(PyBobLearnMiscPLDATrainer_Check(arg))
+        // If the constructor input is PLDATrainer object
+        return PyBobLearnMiscPLDATrainer_init_copy(self, args, kwargs);
+      else
+        return PyBobLearnMiscPLDATrainer_init_bool(self, args, kwargs);
+
+    }
+    default:{
+      PyErr_Format(PyExc_RuntimeError, "number of arguments mismatch - %s requires only 0 or 1 argument, but you provided %d (see help)", Py_TYPE(self)->tp_name, nargs);
+      PLDATrainer_doc.print_usage();
+      return -1;
+    }
+  }
+  BOB_CATCH_MEMBER("cannot create PLDATrainer", 0)
+  return 0;
+}
+
+
+static void PyBobLearnMiscPLDATrainer_delete(PyBobLearnMiscPLDATrainerObject* self) {
+  self->cxx.reset();
+  Py_TYPE(self)->tp_free((PyObject*)self);
+}
+
+
+int PyBobLearnMiscPLDATrainer_Check(PyObject* o) {
+  return PyObject_IsInstance(o, reinterpret_cast<PyObject*>(&PyBobLearnMiscPLDATrainer_Type));
+}
+
+
+static PyObject* PyBobLearnMiscPLDATrainer_RichCompare(PyBobLearnMiscPLDATrainerObject* self, PyObject* other, int op) {
+  BOB_TRY
+
+  if (!PyBobLearnMiscPLDATrainer_Check(other)) {
+    PyErr_Format(PyExc_TypeError, "cannot compare `%s' with `%s'", Py_TYPE(self)->tp_name, Py_TYPE(other)->tp_name);
+    return 0;
+  }
+  auto other_ = reinterpret_cast<PyBobLearnMiscPLDATrainerObject*>(other);
+  switch (op) {
+    case Py_EQ:
+      if (*self->cxx==*other_->cxx) Py_RETURN_TRUE; else Py_RETURN_FALSE;
+    case Py_NE:
+      if (*self->cxx==*other_->cxx) Py_RETURN_FALSE; else Py_RETURN_TRUE;
+    default:
+      Py_INCREF(Py_NotImplemented);
+      return Py_NotImplemented;
+  }
+  BOB_CATCH_MEMBER("cannot compare PLDATrainer objects", 0)
+}
+
+
+/******************************************************************/
+/************ Variables Section ***********************************/
+/******************************************************************/
+
+static auto z_second_order = bob::extension::VariableDoc(
+  "z_second_order",
+  "array_like <float, 3D>",
+  "",
+  ""
+);
+PyObject* PyBobLearnMiscPLDATrainer_get_z_second_order(PyBobLearnMiscPLDATrainerObject* self, void*){
+  BOB_TRY
+  return PyBlitzArrayCxx_AsConstNumpy(self->cxx->getZSecondOrder());
+  BOB_CATCH_MEMBER("z_second_order could not be read", 0)
+}
+
+
+static auto z_second_order_sum = bob::extension::VariableDoc(
+  "z_second_order_sum",
+  "array_like <float, 2D>",
+  "",
+  ""
+);
+PyObject* PyBobLearnMiscPLDATrainer_get_z_second_order_sum(PyBobLearnMiscPLDATrainerObject* self, void*){
+  BOB_TRY
+  return PyBlitzArrayCxx_AsConstNumpy(self->cxx->getZSecondOrderSum());
+  BOB_CATCH_MEMBER("z_second_order_sum could not be read", 0)
+}
+
+
+static auto z_first_order = bob::extension::VariableDoc(
+  "z_first_order",
+  "array_like <float, 2D>",
+  "",
+  ""
+);
+PyObject* PyBobLearnMiscPLDATrainer_get_z_first_order(PyBobLearnMiscPLDATrainerObject* self, void*){
+  BOB_TRY
+  return PyBlitzArrayCxx_AsConstNumpy(self->cxx->getZFirstOrder());
+  BOB_CATCH_MEMBER("z_first_order could not be read", 0)
+}
+
+
+
+
+static PyGetSetDef PyBobLearnMiscPLDATrainer_getseters[] = { 
+  {
+   z_first_order.name(),
+   (getter)PyBobLearnMiscPLDATrainer_get_z_first_order,
+   0,
+   z_first_order.doc(),
+   0
+  },
+  {
+   z_second_order_sum.name(),
+   (getter)PyBobLearnMiscPLDATrainer_get_z_second_order_sum,
+   0,
+   z_second_order_sum.doc(),
+   0
+  },
+  {
+   z_second_order.name(),
+   (getter)PyBobLearnMiscPLDATrainer_get_z_second_order,
+   0,
+   z_second_order.doc(),
+   0
+  },
+
+  {0}  // Sentinel
+};
+
+
+/******************************************************************/
+/************ Functions Section ***********************************/
+/******************************************************************/
+
+/*** initialize ***/
+static auto initialize = bob::extension::FunctionDoc(
+  "initialize",
+  "Initialization before the EM steps",
+  "",
+  true
+)
+.add_prototype("plda_base,data")
+.add_parameter("plda_base", ":py:class:`bob.learn.misc.PLDABase`", "PLDAMachine Object")
+.add_parameter("data", "list", "");
+static PyObject* PyBobLearnMiscPLDATrainer_initialize(PyBobLearnMiscPLDATrainerObject* self, PyObject* args, PyObject* kwargs) {
+  BOB_TRY
+
+  /* Parses input arguments in a single shot */
+  char** kwlist = initialize.kwlist(0);
+
+  PyBobLearnMiscPLDABaseObject* plda_base = 0;
+  PyObject* data = 0;
+
+  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!O!", kwlist, &PyBobLearnMiscPLDABase_Type, &plda_base,
+                                                                 &PyList_Type, &data)) Py_RETURN_NONE;
+
+  std::vector<blitz::Array<double,2> > data_vector;
+  if(list_as_vector(data ,data_vector)==0)
+    self->cxx->initialize(*plda_machine->cxx, data_vector);
+
+  BOB_CATCH_MEMBER("cannot perform the initialize method", 0)
+
+  Py_RETURN_NONE;
+}
+
+
+/*** e_step ***/
+static auto e_step = bob::extension::FunctionDoc(
+  "e_step",
+  "e_step before the EM steps",
+  "",
+  true
+)
+.add_prototype("plda_base,data")
+.add_parameter("plda_base", ":py:class:`bob.learn.misc.PLDABase`", "PLDAMachine Object")
+.add_parameter("data", "list", "");
+static PyObject* PyBobLearnMiscPLDATrainer_e_step(PyBobLearnMiscPLDATrainerObject* self, PyObject* args, PyObject* kwargs) {
+  BOB_TRY
+
+  /* Parses input arguments in a single shot */
+  char** kwlist = e_step.kwlist(0);
+
+  PyBobLearnMiscPLDABaseObject* plda_base = 0;
+  PyObject* data = 0;
+
+  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!O!", kwlist, &PyBobLearnMiscPLDABase_Type, &plda_base,
+                                                                 &PyList_Type, &data)) Py_RETURN_NONE;
+
+  std::vector<blitz::Array<double,2> > data_vector;
+  if(list_as_vector(data ,data_vector)==0)
+    self->cxx->e_step(*plda_machine->cxx, data_vector);
+
+  BOB_CATCH_MEMBER("cannot perform the e_step method", 0)
+
+  Py_RETURN_NONE;
+}
+
+
+/*** m_step ***/
+static auto m_step = bob::extension::FunctionDoc(
+  "m_step",
+  "m_step before the EM steps",
+  "",
+  true
+)
+.add_prototype("plda_base,data")
+.add_parameter("plda_base", ":py:class:`bob.learn.misc.PLDABase`", "PLDAMachine Object")
+.add_parameter("data", "list", "");
+static PyObject* PyBobLearnMiscPLDATrainer_m_step(PyBobLearnMiscPLDATrainerObject* self, PyObject* args, PyObject* kwargs) {
+  BOB_TRY
+
+  /* Parses input arguments in a single shot */
+  char** kwlist = m_step.kwlist(0);
+
+  PyBobLearnMiscPLDABaseObject* plda_base = 0;
+  PyObject* data = 0;
+
+  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!O!", kwlist, &PyBobLearnMiscPLDABase_Type, &plda_base,
+                                                                 &PyList_Type, &data)) Py_RETURN_NONE;
+
+  std::vector<blitz::Array<double,2> > data_vector;
+  if(list_as_vector(data ,data_vector)==0)
+    self->cxx->m_step(*plda_machine->cxx, data_vector);
+
+  BOB_CATCH_MEMBER("cannot perform the m_step method", 0)
+
+  Py_RETURN_NONE;
+}
+
+
+/*** finalize ***/
+static auto finalize = bob::extension::FunctionDoc(
+  "finalize",
+  "finalize before the EM steps",
+  "",
+  true
+)
+.add_prototype("plda_base,data")
+.add_parameter("plda_base", ":py:class:`bob.learn.misc.PLDABase`", "PLDAMachine Object")
+.add_parameter("data", "list", "");
+static PyObject* PyBobLearnMiscPLDATrainer_finalize(PyBobLearnMiscPLDATrainerObject* self, PyObject* args, PyObject* kwargs) {
+  BOB_TRY
+
+  /* Parses input arguments in a single shot */
+  char** kwlist = finalize.kwlist(0);
+
+  PyBobLearnMiscPLDABaseObject* plda_base = 0;
+  PyObject* data = 0;
+
+  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!O!", kwlist, &PyBobLearnMiscPLDABase_Type, &plda_base,
+                                                                 &PyList_Type, &data)) Py_RETURN_NONE;
+
+  std::vector<blitz::Array<double,2> > data_vector;
+  if(list_as_vector(data ,data_vector)==0)
+    self->cxx->finalize(*plda_machine->cxx, data_vector);
+
+  BOB_CATCH_MEMBER("cannot perform the finalize method", 0)
+
+  Py_RETURN_NONE;
+}
+
+
+
+/*** enrol ***/
+static auto enrol = bob::extension::FunctionDoc(
+  "enrol",
+  "Main procedure for enrolling a PLDAMachine",
+  "",
+  true
+)
+.add_prototype("plda_machine,data")
+.add_parameter("plda_machine", ":py:class:`bob.learn.misc.PLDAMachine`", "PLDAMachine Object")
+.add_parameter("data", "list", "");
+static PyObject* PyBobLearnMiscPLDATrainer_finalize(PyBobLearnMiscPLDATrainerObject* self, PyObject* args, PyObject* kwargs) {
+  BOB_TRY
+
+  /* Parses input arguments in a single shot */
+  char** kwlist = finalize.kwlist(0);
+
+  PyBobLearnMiscPLDAMachineObject* plda_machine = 0;
+  PyBlitzArrayObject* data = 0;
+
+  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!O&", kwlist, &PyBobLearnMiscPLDAMachine_Type, &plda_machine,
+                                                                 &PyBlitzArray_Converter, &data)) Py_RETURN_NONE;
+
+  auto data_ = make_safe(data);
+  self->cxx->enrol(*plda_machine->cxx, *PyBlitzArrayCxx_AsBlitz<double,2>(data));
+
+  BOB_CATCH_MEMBER("cannot perform the enrol method", 0)
+
+  Py_RETURN_NONE;
+}
+
+
+static PyMethodDef PyBobLearnMiscPLDATrainer_methods[] = {
+  {
+    initialize.name(),
+    (PyCFunction)PyBobLearnMiscPLDATrainer_initialize,
+    METH_VARARGS|METH_KEYWORDS,
+    initialize.doc()
+  },
+  {
+    e_step.name(),
+    (PyCFunction)PyBobLearnMiscPLDATrainer_e_step,
+    METH_VARARGS|METH_KEYWORDS,
+    e_step.doc()
+  },
+  {
+    m_step.name(),
+    (PyCFunction)PyBobLearnMiscPLDATrainer_m_step,
+    METH_VARARGS|METH_KEYWORDS,
+    m_step.doc()
+  },
+  {
+    enrol.name(),
+    (PyCFunction)PyBobLearnMiscPLDATrainer_enrol,
+    METH_VARARGS|METH_KEYWORDS,
+    enrol.doc()
+  },
+  {0} /* Sentinel */
+};
+
+
+/******************************************************************/
+/************ Module Section **************************************/
+/******************************************************************/
+
+// Define the Gaussian type struct; will be initialized later
+PyTypeObject PyBobLearnMiscPLDATrainer_Type = {
+  PyVarObject_HEAD_INIT(0,0)
+  0
+};
+
+bool init_BobLearnMiscPLDATrainer(PyObject* module)
+{
+  // initialize the type struct
+  PyBobLearnMiscPLDATrainer_Type.tp_name      = PLDATrainer_doc.name();
+  PyBobLearnMiscPLDATrainer_Type.tp_basicsize = sizeof(PyBobLearnMiscPLDATrainerObject);
+  PyBobLearnMiscPLDATrainer_Type.tp_flags     = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;//Enable the class inheritance;
+  PyBobLearnMiscPLDATrainer_Type.tp_doc       = PLDATrainer_doc.doc();
+
+  // set the functions
+  PyBobLearnMiscPLDATrainer_Type.tp_new          = PyType_GenericNew;
+  PyBobLearnMiscPLDATrainer_Type.tp_init         = reinterpret_cast<initproc>(PyBobLearnMiscPLDATrainer_init);
+  PyBobLearnMiscPLDATrainer_Type.tp_dealloc      = reinterpret_cast<destructor>(PyBobLearnMiscPLDATrainer_delete);
+  PyBobLearnMiscPLDATrainer_Type.tp_richcompare = reinterpret_cast<richcmpfunc>(PyBobLearnMiscPLDATrainer_RichCompare);
+  PyBobLearnMiscPLDATrainer_Type.tp_methods      = PyBobLearnMiscPLDATrainer_methods;
+  PyBobLearnMiscPLDATrainer_Type.tp_getset       = PyBobLearnMiscPLDATrainer_getseters;
+  //PyBobLearnMiscPLDATrainer_Type.tp_call         = reinterpret_cast<ternaryfunc>(PyBobLearnMiscPLDATrainer_compute_likelihood);
+
+
+  // check that everything is fine
+  if (PyType_Ready(&PyBobLearnMiscPLDATrainer_Type) < 0) return false;
+
+  // add the type to the module
+  Py_INCREF(&PyBobLearnMiscPLDATrainer_Type);
+  return PyModule_AddObject(module, "_PLDATrainer", (PyObject*)&PyBobLearnMiscPLDATrainer_Type) >= 0;
+}
+
diff --git a/setup.py b/setup.py
index e084e8a..51e44e4 100644
--- a/setup.py
+++ b/setup.py
@@ -79,7 +79,7 @@ setup(
           "bob/learn/misc/cpp/KMeansTrainer.cpp",
           "bob/learn/misc/cpp/MAP_GMMTrainer.cpp",
           "bob/learn/misc/cpp/ML_GMMTrainer.cpp",
-          #"bob/learn/misc/cpp/PLDATrainer.cpp",
+          "bob/learn/misc/cpp/PLDATrainer.cpp",
         ],
         bob_packages = bob_packages,
         packages = packages,
@@ -133,6 +133,8 @@ setup(
 
           "bob/learn/misc/empca_trainer.cpp",
 
+          "bob/learn/misc/plda_trainer.cpp",
+
           "bob/learn/misc/ztnorm.cpp",
 
           "bob/learn/misc/main.cpp",
-- 
GitLab