From 8bcad55d9889b0bb6498a3d891a63d16b72e9d45 Mon Sep 17 00:00:00 2001
From: Tiago Freitas Pereira <tiagofrepereira@gmail.com>
Date: Thu, 11 Dec 2014 16:53:46 +0100
Subject: [PATCH] Modifications in the load and save bind in order to support
 keyword arguments

---
 bob/learn/misc/gmm_stats.cpp | 162 ++++++++++++++++++-----------------
 1 file changed, 84 insertions(+), 78 deletions(-)

diff --git a/bob/learn/misc/gmm_stats.cpp b/bob/learn/misc/gmm_stats.cpp
index a5c7e5b..9a9b754 100644
--- a/bob/learn/misc/gmm_stats.cpp
+++ b/bob/learn/misc/gmm_stats.cpp
@@ -15,12 +15,11 @@
 
 static auto GMMStats_doc = bob::extension::ClassDoc(
   BOB_EXT_MODULE_PREFIX ".GMMStats",
-  "A container for GMM statistics.\n",
-  "With respect to Reynolds, \"Speaker Verification Using Adapted "
-  "Gaussian Mixture Models\", DSP, 2000:\n"
-  "Eq (8) is n(i)\n"
-  "Eq (9) is sumPx(i) / n(i)\n"
-  "Eq (10) is sumPxx(i) / n(i)\n"
+  "A container for GMM statistics",
+  "With respect to [Reynolds2000]_ the class computes: \n\n"
+  "* Eq (8) is :math:`n_i=\\sum\\limits_{t=1}^T Pr(i | x_t)`\n\n"
+  "* Eq (9) is :math:`sumPx=E_i(x)=\\frac{1}{n(i)}\\sum\\limits_{t=1}^T Pr(i | x_t)x_t`\n\n"
+  "* Eq (10) is :math:`sumPxx=E_i(x^2)=\\frac{1}{n(i)}\\sum\\limits_{t=1}^T Pr(i | x_t)x_t^2`\n\n"
 ).add_constructor(
   bob::extension::FunctionDoc(
     "__init__",
@@ -52,11 +51,13 @@ static int PyBobLearnMiscGMMStats_init_number(PyBobLearnMiscGMMStatsObject* self
 
   if(n_gaussians < 0){
     PyErr_Format(PyExc_TypeError, "gaussians argument must be greater than or equal to zero");
+    GMMStats_doc.print_usage();
     return -1;
   }
 
   if(n_inputs < 0){
     PyErr_Format(PyExc_TypeError, "input argument must be greater than or equal to zero");
+    GMMStats_doc.print_usage();
     return -1;
    }
 
@@ -69,7 +70,10 @@ static int PyBobLearnMiscGMMStats_init_copy(PyBobLearnMiscGMMStatsObject* self,
 
   char** kwlist = GMMStats_doc.kwlist(1);
   PyBobLearnMiscGMMStatsObject* tt;
-  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!", kwlist, &PyBobLearnMiscGMMStats_Type, &tt)) return -1;
+  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!", kwlist, &PyBobLearnMiscGMMStats_Type, &tt)){
+    GMMStats_doc.print_usage();
+    return -1;
+  }
 
   self->cxx.reset(new bob::learn::misc::GMMStats(*tt->cxx));
   return 0;
@@ -81,21 +85,13 @@ static int PyBobLearnMiscGMMStats_init_hdf5(PyBobLearnMiscGMMStatsObject* self,
   char** kwlist = GMMStats_doc.kwlist(2);
 
   PyBobIoHDF5FileObject* config = 0;
-  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O&", kwlist, &PyBobIoHDF5File_Converter, &config))
-    return -1;
-
-  try {
-    self->cxx.reset(new bob::learn::misc::GMMStats(*(config->f)));
-  }
-  catch (std::exception& ex) {
-    PyErr_SetString(PyExc_RuntimeError, ex.what());
-    return -1;
-  }
-  catch (...) {
-    PyErr_Format(PyExc_RuntimeError, "cannot create new object of type `%s' - unknown exception thrown", Py_TYPE(self)->tp_name);
+  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O&", kwlist, &PyBobIoHDF5File_Converter, &config)){
+    GMMStats_doc.print_usage();
     return -1;
   }
 
+  self->cxx.reset(new bob::learn::misc::GMMStats(*(config->f)));
+
   return 0;
 }
 
@@ -105,12 +101,13 @@ static int PyBobLearnMiscGMMStats_init(PyBobLearnMiscGMMStatsObject* self, PyObj
   BOB_TRY
 
   // get the number of command line arguments
-  Py_ssize_t nargs = (args?PyTuple_Size(args):0) + (kwargs?PyDict_Size(kwargs):0);
+  int nargs = (args?PyTuple_Size(args):0) + (kwargs?PyDict_Size(kwargs):0);
 
   switch (nargs) {
 
     case 0: //default initializer ()
       self->cxx.reset(new bob::learn::misc::GMMStats());
+      return 0;
 
     case 1:{
       //Reading the input argument
@@ -133,7 +130,8 @@ static int PyBobLearnMiscGMMStats_init(PyBobLearnMiscGMMStatsObject* self, PyObj
     case 2:
       return PyBobLearnMiscGMMStats_init_number(self, args, kwargs);
     default:
-      PyErr_Format(PyExc_RuntimeError, "number of arguments mismatch - %s requires 0, 1 or 2 arguments, but you provided %" PY_FORMAT_SIZE_T "d (see help)", Py_TYPE(self)->tp_name, nargs);
+      PyErr_Format(PyExc_RuntimeError, "number of arguments mismatch - %s requires 0, 1 or 2 arguments, but you provided %d (see help)", Py_TYPE(self)->tp_name, nargs);
+      GMMStats_doc.print_usage();
       return -1;
   }
   BOB_CATCH_MEMBER("cannot create GMMStats", 0)
@@ -179,8 +177,8 @@ int PyBobLearnMiscGMMStats_Check(PyObject* o) {
 /***** n *****/
 static auto n = bob::extension::VariableDoc(
   "n",
-  "array_like <double, 1D> ",
-  "For each Gaussian, the accumulated sum of responsibilities, i.e. the sum of P(gaussian_i|x)"
+  "array_like <double, 1D>",
+  "For each Gaussian, the accumulated sum of responsibilities, i.e. the sum of :math:`P(gaussian_i|x)`"
 );
 PyObject* PyBobLearnMiscGMMStats_getN(PyBobLearnMiscGMMStatsObject* self, void*){
   BOB_TRY
@@ -260,8 +258,8 @@ int PyBobLearnMiscGMMStats_setSum_pxx(PyBobLearnMiscGMMStatsObject* self, PyObje
 /***** t *****/
 static auto t = bob::extension::VariableDoc(
   "t",
-  "int ",
-  "The accumulated number of samples"
+  "int",
+  "The number of samples"
 );
 PyObject* PyBobLearnMiscGMMStats_getT(PyBobLearnMiscGMMStatsObject* self, void*){
   BOB_TRY
@@ -312,6 +310,21 @@ int PyBobLearnMiscGMMStats_setLog_likelihood(PyBobLearnMiscGMMStatsObject* self,
 }
 
 
+/***** shape *****/
+static auto shape = bob::extension::VariableDoc(
+  "shape",
+  "(int,int)",
+  "A tuple that represents the number of gaussians and dimensionality of each Gaussian ``(n_gaussians, dim)``.",
+  ""
+);
+PyObject* PyBobLearnMiscGMMStats_getShape(PyBobLearnMiscGMMStatsObject* self, void*) {
+  BOB_TRY
+  return Py_BuildValue("(i,i)", self->cxx->sumPx.shape()[0], self->cxx->sumPx.shape()[1]);
+  BOB_CATCH_MEMBER("shape could not be read", 0)
+}
+
+
+
 static PyGetSetDef PyBobLearnMiscGMMStats_getseters[] = {
   {
     n.name(),
@@ -347,8 +360,16 @@ static PyGetSetDef PyBobLearnMiscGMMStats_getseters[] = {
     (setter)PyBobLearnMiscGMMStats_setLog_likelihood,
     log_likelihood.doc(),
     0
+  },  
+  {
+   shape.name(),
+   (getter)PyBobLearnMiscGMMStats_getShape,
+   0,
+   shape.doc(),
+   0
   },
 
+
   {0}  // Sentinel
 };
 
@@ -364,31 +385,20 @@ static auto save = bob::extension::FunctionDoc(
   "Save the configuration of the GMMStats to a given HDF5 file"
 )
 .add_prototype("hdf5")
-.add_parameter("hdf5", ":py:class:`bob.io.base.HDF5File`", "An HDF5 file open for writing")
-;
-static PyObject* PyBobLearnMiscGMMStats_Save(PyBobLearnMiscGMMStatsObject* self, PyObject* arg) {
+.add_parameter("hdf5", ":py:class:`bob.io.base.HDF5File`", "An HDF5 file open for writing");
+static PyObject* PyBobLearnMiscGMMStats_Save(PyBobLearnMiscGMMStatsObject* self,  PyObject* args, PyObject* kwargs) {
 
+  BOB_TRY
+  
   // get list of arguments
-  if (!PyBobIoHDF5File_Check(arg)) {
-    PyErr_Format(PyExc_TypeError, "`%s' cannot write itself to `%s', only to an HDF5 file", Py_TYPE(self)->tp_name, Py_TYPE(arg)->tp_name);
-    return 0;
-  }
+  char** kwlist = save.kwlist(0);  
+  PyBobIoHDF5FileObject* hdf5;
+  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O&", kwlist, PyBobIoHDF5File_Converter, &hdf5)) return 0;
 
-  auto hdf5 = reinterpret_cast<PyBobIoHDF5FileObject*>(arg);
-
-  try {
-    self->cxx->save(*hdf5->f);
-  }
-  catch (std::exception& e) {
-    PyErr_SetString(PyExc_RuntimeError, e.what());
-    return 0;
-  }
-  catch (...) {
-    PyErr_Format(PyExc_RuntimeError, "`%s' cannot write data to file `%s' (at group `%s'): unknown exception caught", Py_TYPE(self)->tp_name,
-        hdf5->f->filename().c_str(), hdf5->f->cwd().c_str());
-    return 0;
-  }
+  auto hdf5_ = make_safe(hdf5);
+  self->cxx->save(*hdf5->f);
 
+  BOB_CATCH_MEMBER("cannot save the data", 0)
   Py_RETURN_NONE;
 }
 
@@ -399,27 +409,17 @@ static auto load = bob::extension::FunctionDoc(
 )
 .add_prototype("hdf5")
 .add_parameter("hdf5", ":py:class:`bob.io.base.HDF5File`", "An HDF5 file open for reading");
-static PyObject* PyBobLearnMiscGMMStats_Load(PyBobLearnMiscGMMStatsObject* self, PyObject* f) {
-
-  if (!PyBobIoHDF5File_Check(f)) {
-    PyErr_Format(PyExc_TypeError, "`%s' cannot load itself from `%s', only from an HDF5 file", Py_TYPE(self)->tp_name, Py_TYPE(f)->tp_name);
-    return 0;
-  }
-
-  auto h5f = reinterpret_cast<PyBobIoHDF5FileObject*>(f);
+static PyObject* PyBobLearnMiscGMMStats_Load(PyBobLearnMiscGMMStatsObject* self, PyObject* args, PyObject* kwargs) {
+  BOB_TRY
+  
+  char** kwlist = load.kwlist(0);  
+  PyBobIoHDF5FileObject* hdf5;
+  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O&", kwlist, PyBobIoHDF5File_Converter, &hdf5)) return 0;
+  
+  auto hdf5_ = make_safe(hdf5);  
+  self->cxx->load(*hdf5->f);
 
-  try {
-    self->cxx->load(*h5f->f);
-  }
-  catch (std::exception& e) {
-    PyErr_SetString(PyExc_RuntimeError, e.what());
-    return 0;
-  }
-  catch (...) {
-    PyErr_Format(PyExc_RuntimeError, "cannot read data from file `%s' (at group `%s'): unknown exception caught", h5f->f->filename().c_str(),
-        h5f->f->cwd().c_str());
-    return 0;
-  }
+  BOB_CATCH_MEMBER("cannot load the data", 0)
   Py_RETURN_NONE;
 }
 
@@ -429,14 +429,14 @@ static auto is_similar_to = bob::extension::FunctionDoc(
   "is_similar_to",
   
   "Compares this GMMStats with the ``other`` one to be approximately the same.",
-  "The optional values ``r_epsilon`` and ``a_epsilon`` refer to the"
-  "relative and absolute precision for the ``weights``, ``biases``"
+  "The optional values ``r_epsilon`` and ``a_epsilon`` refer to the "
+  "relative and absolute precision for the ``weights``, ``biases`` "
   "and any other values internal to this machine."
 )
 .add_prototype("other, [r_epsilon], [a_epsilon]","output")
 .add_parameter("other", ":py:class:`bob.learn.misc.GMMStats`", "A GMMStats object to be compared.")
-.add_parameter("[r_epsilon]", "float", "Relative precision.")
-.add_parameter("[a_epsilon]", "float", "Absolute precision.")
+.add_parameter("r_epsilon", "float", "Relative precision.")
+.add_parameter("a_epsilon", "float", "Absolute precision.")
 .add_return("output","bool","True if it is similar, otherwise false.");
 static PyObject* PyBobLearnMiscGMMStats_IsSimilarTo(PyBobLearnMiscGMMStatsObject* self, PyObject* args, PyObject* kwds) {
 
@@ -450,9 +450,11 @@ static PyObject* PyBobLearnMiscGMMStats_IsSimilarTo(PyBobLearnMiscGMMStatsObject
 
   if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!|dd", kwlist,
         &PyBobLearnMiscGMMStats_Type, &other,
-        &r_epsilon, &a_epsilon)) return 0;
+        &r_epsilon, &a_epsilon)){
 
-  //auto other_ = reinterpret_cast<PyBobLearnMiscGMMStatsObject*>(other);
+        is_similar_to.print_usage(); 
+        return 0;        
+  }
 
   if (self->cxx->is_similar_to(*other->cxx, r_epsilon, a_epsilon))
     Py_RETURN_TRUE;
@@ -464,9 +466,11 @@ static PyObject* PyBobLearnMiscGMMStats_IsSimilarTo(PyBobLearnMiscGMMStatsObject
 /*** resize ***/
 static auto resize = bob::extension::FunctionDoc(
   "resize",
-  " Allocates space for the statistics and resets to zero."
+  "Allocates space for the statistics and resets to zero.",
+  0,
+  true
 )
-.add_prototype("n_gaussians,n_inputs","")
+.add_prototype("n_gaussians,n_inputs")
 .add_parameter("n_gaussians", "int", "Number of gaussians")
 .add_parameter("n_inputs", "int", "Dimensionality of the feature vector");
 static PyObject* PyBobLearnMiscGMMStats_resize(PyBobLearnMiscGMMStatsObject* self, PyObject* args, PyObject* kwargs) {
@@ -482,10 +486,12 @@ static PyObject* PyBobLearnMiscGMMStats_resize(PyBobLearnMiscGMMStatsObject* sel
 
   if (n_gaussians <= 0){
     PyErr_Format(PyExc_TypeError, "n_gaussians must be greater than zero");
+    resize.print_usage();
     return 0;
   }
   if (n_inputs <= 0){
     PyErr_Format(PyExc_TypeError, "n_inputs must be greater than zero");
+    resize.print_usage();
     return 0;
   }
 
@@ -503,7 +509,7 @@ static auto init = bob::extension::FunctionDoc(
   "init",
   " Resets statistics to zero."
 )
-.add_prototype("","");
+.add_prototype("");
 static PyObject* PyBobLearnMiscGMMStats_init_method(PyBobLearnMiscGMMStatsObject* self) {
   BOB_TRY
 
@@ -520,13 +526,13 @@ static PyMethodDef PyBobLearnMiscGMMStats_methods[] = {
   {
     save.name(),
     (PyCFunction)PyBobLearnMiscGMMStats_Save,
-    METH_O,
+    METH_VARARGS|METH_KEYWORDS,
     save.doc()
   },
   {
     load.name(),
     (PyCFunction)PyBobLearnMiscGMMStats_Load,
-    METH_O,
+    METH_VARARGS|METH_KEYWORDS,
     load.doc()
   },
   {
@@ -601,7 +607,7 @@ bool init_BobLearnMiscGMMStats(PyObject* module)
   PyBobLearnMiscGMMStats_Type.tp_richcompare = reinterpret_cast<richcmpfunc>(PyBobLearnMiscGMMStats_RichCompare);
   PyBobLearnMiscGMMStats_Type.tp_methods = PyBobLearnMiscGMMStats_methods;
   PyBobLearnMiscGMMStats_Type.tp_getset = PyBobLearnMiscGMMStats_getseters;
-  //PyBobLearnMiscGMMStats_Type.tp_call = reinterpret_cast<ternaryfunc>(PyBobLearnMiscGMMStats_loglikelihood);
+  PyBobLearnMiscGMMStats_Type.tp_call = 0;
   PyBobLearnMiscGMMStats_Type.tp_as_number = &PyBobLearnMiscGMMStats_operators;
 
   //set operators
-- 
GitLab