From ed7ab2803a70e30f7a155386df9014196713cc91 Mon Sep 17 00:00:00 2001
From: Tiago Freitas Pereira <tiagofrepereira@gmail.com>
Date: Mon, 2 Feb 2015 14:59:29 +0100
Subject: [PATCH] Binding JFATrainer

---
 bob/learn/misc/__init__.py         |   1 +
 bob/learn/misc/__jfa_trainer__.py  |  22 ++--
 bob/learn/misc/cpp/JFATrainer.cpp  |   2 -
 bob/learn/misc/gmm_machine.cpp     |  34 +++++-
 bob/learn/misc/jfa_trainer.cpp     | 180 ++++++++++++++++++++++++++---
 bob/learn/misc/test_jfa_trainer.py |   4 +-
 6 files changed, 205 insertions(+), 38 deletions(-)

diff --git a/bob/learn/misc/__init__.py b/bob/learn/misc/__init__.py
index 5e3a170..877a303 100644
--- a/bob/learn/misc/__init__.py
+++ b/bob/learn/misc/__init__.py
@@ -14,6 +14,7 @@ from .version import module as __version__
 from .__kmeans_trainer__ import *
 from .__ML_gmm_trainer__ import *
 from .__MAP_gmm_trainer__ import *
+from .__jfa_trainer__ import *
 
 
 def ztnorm_same_value(vect_a, vect_b):
diff --git a/bob/learn/misc/__jfa_trainer__.py b/bob/learn/misc/__jfa_trainer__.py
index a2e661b..ad803ad 100644
--- a/bob/learn/misc/__jfa_trainer__.py
+++ b/bob/learn/misc/__jfa_trainer__.py
@@ -35,21 +35,21 @@ class JFATrainer (_JFATrainer):
         The data to be trained
     """
     #V Subspace
-    for i in self._max_iterations:
-      self.eStep1(jfa_base, data)
-      self.mStep1(jfa_base, data)
+    for i in range(self._max_iterations):
+      self.e_step1(jfa_base, data)
+      self.m_step1(jfa_base, data)
     self.finalize1(jfa_base, data)
 
     #U subspace
-    for i in self._max_iterations:
-      self.eStep2(jfa_base, data)
-      self.mStep2(jfa_base, data)
+    for i in range(self._max_iterations):
+      self.e_step2(jfa_base, data)
+      self.m_step2(jfa_base, data)
     self.finalize2(jfa_base, data)
 
     # d subspace
-    for i in self._max_iterations:
-      self.eStep3(jfa_base, data)
-      self.mStep3(jfa_base, data)
+    for i in range(self._max_iterations):
+      self.e_step3(jfa_base, data)
+      self.m_step3(jfa_base, data)
     self.finalize3(jfa_base, data)
 
 
@@ -67,7 +67,5 @@ class JFATrainer (_JFATrainer):
     self.train_loop(jfa_base, data)
 
 
-  def enrol(self, jfa_base, data):
-
 # copy the documentation from the base class
-__doc__ = _KMeansTrainer.__doc__
+__doc__ = _JFATrainer.__doc__
diff --git a/bob/learn/misc/cpp/JFATrainer.cpp b/bob/learn/misc/cpp/JFATrainer.cpp
index 13dea3c..6b79da8 100644
--- a/bob/learn/misc/cpp/JFATrainer.cpp
+++ b/bob/learn/misc/cpp/JFATrainer.cpp
@@ -174,7 +174,6 @@ void bob::learn::misc::JFATrainer::train(bob::learn::misc::JFABase& machine,
 }
 */
 
-/*
 void bob::learn::misc::JFATrainer::enrol(bob::learn::misc::JFAMachine& machine,
   const std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> >& ar,
   const size_t n_iter)
@@ -198,5 +197,4 @@ void bob::learn::misc::JFATrainer::enrol(bob::learn::misc::JFAMachine& machine,
   machine.setY(y);
   machine.setZ(z);
 }
-*/
 
diff --git a/bob/learn/misc/gmm_machine.cpp b/bob/learn/misc/gmm_machine.cpp
index 9bea2a8..0b21d68 100644
--- a/bob/learn/misc/gmm_machine.cpp
+++ b/bob/learn/misc/gmm_machine.cpp
@@ -277,7 +277,20 @@ PyObject* PyBobLearnMiscGMMMachine_getVarianceSupervector(PyBobLearnMiscGMMMachi
   return PyBlitzArrayCxx_AsConstNumpy(self->cxx->getVarianceSupervector());
   BOB_CATCH_MEMBER("variance_supervector could not be read", 0)
 }
-
+int PyBobLearnMiscGMMMachine_setVarianceSupervector(PyBobLearnMiscGMMMachineObject* self, PyObject* value, void*){
+  BOB_TRY
+  PyBlitzArrayObject* o;
+  if (!PyBlitzArray_Converter(value, &o)){
+    PyErr_Format(PyExc_RuntimeError, "%s %s expects a 1D array of floats", Py_TYPE(self)->tp_name, variance_supervector.name());
+    return -1;
+  }
+  auto o_ = make_safe(o);
+  auto b = PyBlitzArrayCxx_AsBlitz<double,1>(o, "variance_supervector");
+  if (!b) return -1;
+  self->cxx->setVarianceSupervector(*b);
+  return 0;
+  BOB_CATCH_MEMBER("variance_supervector could not be set", -1)
+}
 
 /***** mean_supervector *****/
 static auto mean_supervector = bob::extension::VariableDoc(
@@ -291,6 +304,21 @@ PyObject* PyBobLearnMiscGMMMachine_getMeanSupervector(PyBobLearnMiscGMMMachineOb
   return PyBlitzArrayCxx_AsConstNumpy(self->cxx->getMeanSupervector());
   BOB_CATCH_MEMBER("mean_supervector could not be read", 0)
 }
+int PyBobLearnMiscGMMMachine_setMeanSupervector(PyBobLearnMiscGMMMachineObject* self, PyObject* value, void*){
+  BOB_TRY
+  PyBlitzArrayObject* o;
+  if (!PyBlitzArray_Converter(value, &o)){
+    PyErr_Format(PyExc_RuntimeError, "%s %s expects a 1D array of floats", Py_TYPE(self)->tp_name, mean_supervector.name());
+    return -1;
+  }
+  auto o_ = make_safe(o);
+  auto b = PyBlitzArrayCxx_AsBlitz<double,1>(o, "mean_supervector");
+  if (!b) return -1;
+  self->cxx->setMeanSupervector(*b);
+  return 0;
+  BOB_CATCH_MEMBER("mean_supervector could not be set", -1)
+}
+
 
 
 /***** variance_thresholds *****/
@@ -362,7 +390,7 @@ static PyGetSetDef PyBobLearnMiscGMMMachine_getseters[] = {
   {
    variance_supervector.name(),
    (getter)PyBobLearnMiscGMMMachine_getVarianceSupervector,
-   0,
+   (setter)PyBobLearnMiscGMMMachine_setVarianceSupervector,
    variance_supervector.doc(),
    0
   },
@@ -370,7 +398,7 @@ static PyGetSetDef PyBobLearnMiscGMMMachine_getseters[] = {
   {
    mean_supervector.name(),
    (getter)PyBobLearnMiscGMMMachine_getMeanSupervector,
-   0,
+   (setter)PyBobLearnMiscGMMMachine_setMeanSupervector,
    mean_supervector.doc(),
    0
   },
diff --git a/bob/learn/misc/jfa_trainer.cpp b/bob/learn/misc/jfa_trainer.cpp
index 9735aee..cb4dded 100644
--- a/bob/learn/misc/jfa_trainer.cpp
+++ b/bob/learn/misc/jfa_trainer.cpp
@@ -14,17 +14,34 @@
 /************ Constructor Section *********************************/
 /******************************************************************/
 
-static int extract_GMMStats(PyObject *list,
-                             std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > >& training_data)
+static int extract_GMMStats_1d(PyObject *list,
+                             std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> >& training_data)
 {
+  std::cout << " #################### "  << std::endl;
+  std::cout << PyList_GET_SIZE(list) << std::endl;
+
+  for (int i=0; i<PyList_GET_SIZE(list); i++){
+  
+    PyBobLearnMiscGMMStatsObject* stats;
+    if (!PyArg_Parse(PyList_GetItem(list, i), "O!", &PyBobLearnMiscGMMStats_Type, &stats)){
+      PyErr_Format(PyExc_RuntimeError, "Expected GMMStats objects");
+      return -1;
+    }
+    training_data.push_back(stats->cxx);
+  }
+  return 0;
+}
 
-  for (size_t i=0; i<PyList_GET_SIZE(list); i++)
+static int extract_GMMStats_2d(PyObject *list,
+                             std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > >& training_data)
+{
+  for (int i=0; i<PyList_GET_SIZE(list); i++)
   {
     PyObject* another_list;
     PyArg_Parse(PyList_GetItem(list, i), "O!", &PyList_Type, &another_list);
 
     std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > another_training_data;
-    for (size_t j=0; j<PyList_GET_SIZE(another_list); j++){
+    for (int j=0; j<PyList_GET_SIZE(another_list); j++){
 
       PyBobLearnMiscGMMStatsObject* stats;
       if (!PyArg_Parse(PyList_GetItem(another_list, j), "O!", &PyBobLearnMiscGMMStats_Type, &stats)){
@@ -50,6 +67,22 @@ static PyObject* vector_as_list(const std::vector<blitz::Array<double,N> >& vec)
   return list;
 }
 
+template <int N>
+int list_as_vector(PyObject* list, std::vector<blitz::Array<double,N> >& vec)
+{
+  for (int i=0; i<PyList_GET_SIZE(list); i++)
+  {
+    PyBlitzArrayObject* blitz_object; 
+    if (!PyArg_Parse(PyList_GetItem(list, i), "O&", &PyBlitzArray_Converter, &blitz_object)){
+      PyErr_Format(PyExc_RuntimeError, "Expected numpy array object");
+      return -1;
+    }
+    auto blitz_object_ = make_safe(blitz_object);
+    vec.push_back(*PyBlitzArrayCxx_AsBlitz<double,N>(blitz_object));
+  }
+  return 0;
+}
+
 
 
 static auto JFATrainer_doc = bob::extension::ClassDoc(
@@ -319,6 +352,24 @@ PyObject* PyBobLearnMiscJFATrainer_get_X(PyBobLearnMiscJFATrainerObject* self, v
   return vector_as_list(self->cxx->getX());
   BOB_CATCH_MEMBER("__X__ could not be read", 0)
 }
+int PyBobLearnMiscJFATrainer_set_X(PyBobLearnMiscJFATrainerObject* self, PyObject* value, void*){
+  BOB_TRY
+
+  // Parses input arguments in a single shot
+  if (!PyList_Check(value)){
+    PyErr_Format(PyExc_TypeError, "Expected a list in `%s'", __X__.name());
+    return -1;
+  }
+    
+  std::vector<blitz::Array<double,2> > data;
+  if(list_as_vector(value ,data)==0){
+    self->cxx->setX(data);
+  }
+    
+  return 0;
+  BOB_CATCH_MEMBER("__X__ could not be written", 0)
+}
+
 
 
 static auto __Y__ = bob::extension::VariableDoc(
@@ -332,6 +383,24 @@ PyObject* PyBobLearnMiscJFATrainer_get_Y(PyBobLearnMiscJFATrainerObject* self, v
   return vector_as_list(self->cxx->getY());
   BOB_CATCH_MEMBER("__Y__ could not be read", 0)
 }
+int PyBobLearnMiscJFATrainer_set_Y(PyBobLearnMiscJFATrainerObject* self, PyObject* value, void*){
+  BOB_TRY
+
+  // Parses input arguments in a single shot
+  if (!PyList_Check(value)){
+    PyErr_Format(PyExc_TypeError, "Expected a list in `%s'", __Y__.name());
+    return -1;
+  }
+    
+  std::vector<blitz::Array<double,1> > data;
+  if(list_as_vector(value ,data)==0){
+    self->cxx->setY(data);
+  }
+    
+  return 0;
+  BOB_CATCH_MEMBER("__Y__ could not be written", 0)
+}
+
 
 
 static auto __Z__ = bob::extension::VariableDoc(
@@ -345,6 +414,23 @@ PyObject* PyBobLearnMiscJFATrainer_get_Z(PyBobLearnMiscJFATrainerObject* self, v
   return vector_as_list(self->cxx->getZ());
   BOB_CATCH_MEMBER("__Z__ could not be read", 0)
 }
+int PyBobLearnMiscJFATrainer_set_Z(PyBobLearnMiscJFATrainerObject* self, PyObject* value, void*){
+  BOB_TRY
+
+  // Parses input arguments in a single shot
+  if (!PyList_Check(value)){
+    PyErr_Format(PyExc_TypeError, "Expected a list in `%s'", __Z__.name());
+    return -1;
+  }
+    
+  std::vector<blitz::Array<double,1> > data;
+  if(list_as_vector(value ,data)==0){
+    self->cxx->setZ(data);
+  }
+    
+  return 0;
+  BOB_CATCH_MEMBER("__Z__ could not be written", 0)
+}
 
 
 
@@ -435,21 +521,21 @@ static PyGetSetDef PyBobLearnMiscJFATrainer_getseters[] = {
   {
    __X__.name(),
    (getter)PyBobLearnMiscJFATrainer_get_X,
-   0,
+   (setter)PyBobLearnMiscJFATrainer_set_X,
    __X__.doc(),
    0
   },
   {
    __Y__.name(),
    (getter)PyBobLearnMiscJFATrainer_get_Y,
-   0,
+   (setter)PyBobLearnMiscJFATrainer_set_Y,
    __Y__.doc(),
    0
   },
   {
    __Z__.name(),
    (getter)PyBobLearnMiscJFATrainer_get_Z,
-   0,
+   (setter)PyBobLearnMiscJFATrainer_set_Z,
    __Z__.doc(),
    0
   },
@@ -487,7 +573,7 @@ static PyObject* PyBobLearnMiscJFATrainer_initialize(PyBobLearnMiscJFATrainerObj
                                                                  &PyList_Type, &stats)) Py_RETURN_NONE;
 
   std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > > training_data;
-  if(extract_GMMStats(stats ,training_data)==0)
+  if(extract_GMMStats_2d(stats ,training_data)==0)
     self->cxx->initialize(*jfa_base->cxx, training_data);
 
   BOB_CATCH_MEMBER("cannot perform the initialize method", 0)
@@ -519,7 +605,7 @@ static PyObject* PyBobLearnMiscJFATrainer_e_step1(PyBobLearnMiscJFATrainerObject
                                                                  &PyList_Type, &stats)) Py_RETURN_NONE;
 
   std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > > training_data;
-  if(extract_GMMStats(stats ,training_data)==0)
+  if(extract_GMMStats_2d(stats ,training_data)==0)
     self->cxx->eStep1(*jfa_base->cxx, training_data);
 
 
@@ -552,7 +638,7 @@ static PyObject* PyBobLearnMiscJFATrainer_m_step1(PyBobLearnMiscJFATrainerObject
                                                                  &PyList_Type, &stats)) Py_RETURN_NONE;
 
   std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > > training_data;
-  if(extract_GMMStats(stats ,training_data)==0)
+  if(extract_GMMStats_2d(stats ,training_data)==0)
     self->cxx->mStep1(*jfa_base->cxx, training_data);
 
   BOB_CATCH_MEMBER("cannot perform the m_step1 method", 0)
@@ -584,7 +670,7 @@ static PyObject* PyBobLearnMiscJFATrainer_finalize1(PyBobLearnMiscJFATrainerObje
                                                                  &PyList_Type, &stats)) Py_RETURN_NONE;
 
   std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > > training_data;
-  if(extract_GMMStats(stats ,training_data)==0)
+  if(extract_GMMStats_2d(stats ,training_data)==0)
     self->cxx->finalize1(*jfa_base->cxx, training_data);
 
   BOB_CATCH_MEMBER("cannot perform the finalize1 method", 0)
@@ -616,7 +702,7 @@ static PyObject* PyBobLearnMiscJFATrainer_e_step2(PyBobLearnMiscJFATrainerObject
                                                                  &PyList_Type, &stats)) Py_RETURN_NONE;
 
   std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > > training_data;
-  if(extract_GMMStats(stats ,training_data)==0)
+  if(extract_GMMStats_2d(stats ,training_data)==0)
     self->cxx->eStep2(*jfa_base->cxx, training_data);
 
   BOB_CATCH_MEMBER("cannot perform the e_step2 method", 0)
@@ -648,7 +734,7 @@ static PyObject* PyBobLearnMiscJFATrainer_m_step2(PyBobLearnMiscJFATrainerObject
                                                                  &PyList_Type, &stats)) Py_RETURN_NONE;
 
   std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > > training_data;
-  if(extract_GMMStats(stats ,training_data)==0)
+  if(extract_GMMStats_2d(stats ,training_data)==0)
     self->cxx->mStep2(*jfa_base->cxx, training_data);
 
   BOB_CATCH_MEMBER("cannot perform the m_step2 method", 0)
@@ -680,7 +766,7 @@ static PyObject* PyBobLearnMiscJFATrainer_finalize2(PyBobLearnMiscJFATrainerObje
                                                                  &PyList_Type, &stats)) Py_RETURN_NONE;
 
   std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > > training_data;
-  if(extract_GMMStats(stats ,training_data)==0)
+  if(extract_GMMStats_2d(stats ,training_data)==0)
     self->cxx->finalize2(*jfa_base->cxx, training_data);
 
   BOB_CATCH_MEMBER("cannot perform the finalize2 method", 0)
@@ -712,7 +798,7 @@ static PyObject* PyBobLearnMiscJFATrainer_e_step3(PyBobLearnMiscJFATrainerObject
                                                                  &PyList_Type, &stats)) Py_RETURN_NONE;
 
   std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > > training_data;
-  if(extract_GMMStats(stats ,training_data)==0)
+  if(extract_GMMStats_2d(stats ,training_data)==0)
     self->cxx->eStep3(*jfa_base->cxx, training_data);
 
   BOB_CATCH_MEMBER("cannot perform the e_step3 method", 0)
@@ -744,7 +830,7 @@ static PyObject* PyBobLearnMiscJFATrainer_m_step3(PyBobLearnMiscJFATrainerObject
                                                                  &PyList_Type, &stats)) Py_RETURN_NONE;
 
   std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > > training_data;
-  if(extract_GMMStats(stats ,training_data)==0)
+  if(extract_GMMStats_2d(stats ,training_data)==0)
     self->cxx->mStep3(*jfa_base->cxx, training_data);
 
   BOB_CATCH_MEMBER("cannot perform the m_step3 method", 0)
@@ -776,7 +862,7 @@ static PyObject* PyBobLearnMiscJFATrainer_finalize3(PyBobLearnMiscJFATrainerObje
                                                                  &PyList_Type, &stats)) Py_RETURN_NONE;
 
   std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > > training_data;
-  if(extract_GMMStats(stats ,training_data)==0)
+  if(extract_GMMStats_2d(stats ,training_data)==0)
     self->cxx->finalize3(*jfa_base->cxx, training_data);
 
   BOB_CATCH_MEMBER("cannot perform the finalize3 method", 0)
@@ -785,6 +871,39 @@ static PyObject* PyBobLearnMiscJFATrainer_finalize3(PyBobLearnMiscJFATrainerObje
 }
 
 
+/*** enrol ***/
+static auto enrol = bob::extension::FunctionDoc(
+  "enrol",
+  "",
+  "",
+  true
+)
+.add_prototype("jfa_machine,features, n_inter")
+.add_parameter("jfa_machine", ":py:class:`bob.learn.misc.JFAMachine`", "JFAMachine Object")
+.add_parameter("features", "list(:py:class:`bob.learn.misc.GMMStats`)`", "")
+.add_parameter("n_iter", "int", "Number of iterations");
+static PyObject* PyBobLearnMiscJFATrainer_enrol(PyBobLearnMiscJFATrainerObject* self, PyObject* args, PyObject* kwargs) {
+  BOB_TRY
+
+  // Parses input arguments in a single shot
+  char** kwlist = enrol.kwlist(0);
+
+  PyBobLearnMiscJFAMachineObject* jfa_machine = 0;
+  PyObject* stats = 0;
+  int n_iter = 0;
+
+  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!O!d", kwlist, &PyBobLearnMiscJFAMachine_Type, &jfa_machine,
+                                                                 &PyList_Type, &stats, &n_iter)) Py_RETURN_NONE;
+
+  std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > training_data;
+  if(extract_GMMStats_1d(stats ,training_data)==0)
+    self->cxx->enrol(*jfa_machine->cxx, training_data, n_iter);
+
+  BOB_CATCH_MEMBER("cannot perform the enrol method", 0)
+
+  Py_RETURN_NONE;
+}
+
 
 
 static PyMethodDef PyBobLearnMiscJFATrainer_methods[] = {
@@ -830,7 +949,30 @@ static PyMethodDef PyBobLearnMiscJFATrainer_methods[] = {
     METH_VARARGS|METH_KEYWORDS,
     m_step3.doc()
   },
-
+  {
+    finalize1.name(),
+    (PyCFunction)PyBobLearnMiscJFATrainer_finalize1,
+    METH_VARARGS|METH_KEYWORDS,
+    finalize1.doc()
+  },
+  {
+    finalize2.name(),
+    (PyCFunction)PyBobLearnMiscJFATrainer_finalize2,
+    METH_VARARGS|METH_KEYWORDS,
+    finalize2.doc()
+  },
+  {
+    finalize3.name(),
+    (PyCFunction)PyBobLearnMiscJFATrainer_finalize3,
+    METH_VARARGS|METH_KEYWORDS,
+    finalize3.doc()
+  },
+  {
+    enrol.name(),
+    (PyCFunction)PyBobLearnMiscJFATrainer_enrol,
+    METH_VARARGS|METH_KEYWORDS,
+    enrol.doc()
+  },
   {0} /* Sentinel */
 };
 
@@ -850,7 +992,7 @@ bool init_BobLearnMiscJFATrainer(PyObject* module)
   // initialize the type struct
   PyBobLearnMiscJFATrainer_Type.tp_name      = JFATrainer_doc.name();
   PyBobLearnMiscJFATrainer_Type.tp_basicsize = sizeof(PyBobLearnMiscJFATrainerObject);
-  PyBobLearnMiscJFATrainer_Type.tp_flags     = Py_TPFLAGS_DEFAULT;
+  PyBobLearnMiscJFATrainer_Type.tp_flags     = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;//Enable the class inheritance;
   PyBobLearnMiscJFATrainer_Type.tp_doc       = JFATrainer_doc.doc();
 
   // set the functions
diff --git a/bob/learn/misc/test_jfa_trainer.py b/bob/learn/misc/test_jfa_trainer.py
index 7c3f4d4..c094a6d 100644
--- a/bob/learn/misc/test_jfa_trainer.py
+++ b/bob/learn/misc/test_jfa_trainer.py
@@ -13,8 +13,8 @@ import numpy.linalg
 
 import bob.core.random
 
-from . import GMMStats, GMMMachine, JFABase, JFAMachine, ISVBase, ISVMachine, \
-    JFATrainer, ISVTrainer
+from . import GMMStats, GMMMachine, JFABase, JFAMachine, ISVBase, ISVMachine, JFATrainer
+#, ISVTrainer
 
 
 def equals(x, y, epsilon):
-- 
GitLab