Skip to content
Snippets Groups Projects
Commit 8bcad55d authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Modifications in the load and save bind in order to support keyword arguments

parent 6e1006fe
No related branches found
No related tags found
No related merge requests found
......@@ -15,12 +15,11 @@
static auto GMMStats_doc = bob::extension::ClassDoc(
BOB_EXT_MODULE_PREFIX ".GMMStats",
"A container for GMM statistics.\n",
"With respect to Reynolds, \"Speaker Verification Using Adapted "
"Gaussian Mixture Models\", DSP, 2000:\n"
"Eq (8) is n(i)\n"
"Eq (9) is sumPx(i) / n(i)\n"
"Eq (10) is sumPxx(i) / n(i)\n"
"A container for GMM statistics",
"With respect to [Reynolds2000]_ the class computes: \n\n"
"* Eq (8) is :math:`n_i=\\sum\\limits_{t=1}^T Pr(i | x_t)`\n\n"
"* Eq (9) is :math:`sumPx=E_i(x)=\\frac{1}{n(i)}\\sum\\limits_{t=1}^T Pr(i | x_t)x_t`\n\n"
"* Eq (10) is :math:`sumPxx=E_i(x^2)=\\frac{1}{n(i)}\\sum\\limits_{t=1}^T Pr(i | x_t)x_t^2`\n\n"
).add_constructor(
bob::extension::FunctionDoc(
"__init__",
......@@ -52,11 +51,13 @@ static int PyBobLearnMiscGMMStats_init_number(PyBobLearnMiscGMMStatsObject* self
if(n_gaussians < 0){
PyErr_Format(PyExc_TypeError, "gaussians argument must be greater than or equal to zero");
GMMStats_doc.print_usage();
return -1;
}
if(n_inputs < 0){
PyErr_Format(PyExc_TypeError, "input argument must be greater than or equal to zero");
GMMStats_doc.print_usage();
return -1;
}
......@@ -69,7 +70,10 @@ static int PyBobLearnMiscGMMStats_init_copy(PyBobLearnMiscGMMStatsObject* self,
char** kwlist = GMMStats_doc.kwlist(1);
PyBobLearnMiscGMMStatsObject* tt;
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!", kwlist, &PyBobLearnMiscGMMStats_Type, &tt)) return -1;
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!", kwlist, &PyBobLearnMiscGMMStats_Type, &tt)){
GMMStats_doc.print_usage();
return -1;
}
self->cxx.reset(new bob::learn::misc::GMMStats(*tt->cxx));
return 0;
......@@ -81,21 +85,13 @@ static int PyBobLearnMiscGMMStats_init_hdf5(PyBobLearnMiscGMMStatsObject* self,
char** kwlist = GMMStats_doc.kwlist(2);
PyBobIoHDF5FileObject* config = 0;
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O&", kwlist, &PyBobIoHDF5File_Converter, &config))
return -1;
try {
self->cxx.reset(new bob::learn::misc::GMMStats(*(config->f)));
}
catch (std::exception& ex) {
PyErr_SetString(PyExc_RuntimeError, ex.what());
return -1;
}
catch (...) {
PyErr_Format(PyExc_RuntimeError, "cannot create new object of type `%s' - unknown exception thrown", Py_TYPE(self)->tp_name);
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O&", kwlist, &PyBobIoHDF5File_Converter, &config)){
GMMStats_doc.print_usage();
return -1;
}
self->cxx.reset(new bob::learn::misc::GMMStats(*(config->f)));
return 0;
}
......@@ -105,12 +101,13 @@ static int PyBobLearnMiscGMMStats_init(PyBobLearnMiscGMMStatsObject* self, PyObj
BOB_TRY
// get the number of command line arguments
Py_ssize_t nargs = (args?PyTuple_Size(args):0) + (kwargs?PyDict_Size(kwargs):0);
int nargs = (args?PyTuple_Size(args):0) + (kwargs?PyDict_Size(kwargs):0);
switch (nargs) {
case 0: //default initializer ()
self->cxx.reset(new bob::learn::misc::GMMStats());
return 0;
case 1:{
//Reading the input argument
......@@ -133,7 +130,8 @@ static int PyBobLearnMiscGMMStats_init(PyBobLearnMiscGMMStatsObject* self, PyObj
case 2:
return PyBobLearnMiscGMMStats_init_number(self, args, kwargs);
default:
PyErr_Format(PyExc_RuntimeError, "number of arguments mismatch - %s requires 0, 1 or 2 arguments, but you provided %" PY_FORMAT_SIZE_T "d (see help)", Py_TYPE(self)->tp_name, nargs);
PyErr_Format(PyExc_RuntimeError, "number of arguments mismatch - %s requires 0, 1 or 2 arguments, but you provided %d (see help)", Py_TYPE(self)->tp_name, nargs);
GMMStats_doc.print_usage();
return -1;
}
BOB_CATCH_MEMBER("cannot create GMMStats", 0)
......@@ -179,8 +177,8 @@ int PyBobLearnMiscGMMStats_Check(PyObject* o) {
/***** n *****/
static auto n = bob::extension::VariableDoc(
"n",
"array_like <double, 1D> ",
"For each Gaussian, the accumulated sum of responsibilities, i.e. the sum of P(gaussian_i|x)"
"array_like <double, 1D>",
"For each Gaussian, the accumulated sum of responsibilities, i.e. the sum of :math:`P(gaussian_i|x)`"
);
PyObject* PyBobLearnMiscGMMStats_getN(PyBobLearnMiscGMMStatsObject* self, void*){
BOB_TRY
......@@ -260,8 +258,8 @@ int PyBobLearnMiscGMMStats_setSum_pxx(PyBobLearnMiscGMMStatsObject* self, PyObje
/***** t *****/
static auto t = bob::extension::VariableDoc(
"t",
"int ",
"The accumulated number of samples"
"int",
"The number of samples"
);
PyObject* PyBobLearnMiscGMMStats_getT(PyBobLearnMiscGMMStatsObject* self, void*){
BOB_TRY
......@@ -312,6 +310,21 @@ int PyBobLearnMiscGMMStats_setLog_likelihood(PyBobLearnMiscGMMStatsObject* self,
}
/***** shape *****/
static auto shape = bob::extension::VariableDoc(
"shape",
"(int,int)",
"A tuple that represents the number of gaussians and dimensionality of each Gaussian ``(n_gaussians, dim)``.",
""
);
PyObject* PyBobLearnMiscGMMStats_getShape(PyBobLearnMiscGMMStatsObject* self, void*) {
BOB_TRY
return Py_BuildValue("(i,i)", self->cxx->sumPx.shape()[0], self->cxx->sumPx.shape()[1]);
BOB_CATCH_MEMBER("shape could not be read", 0)
}
static PyGetSetDef PyBobLearnMiscGMMStats_getseters[] = {
{
n.name(),
......@@ -347,8 +360,16 @@ static PyGetSetDef PyBobLearnMiscGMMStats_getseters[] = {
(setter)PyBobLearnMiscGMMStats_setLog_likelihood,
log_likelihood.doc(),
0
},
{
shape.name(),
(getter)PyBobLearnMiscGMMStats_getShape,
0,
shape.doc(),
0
},
{0} // Sentinel
};
......@@ -364,31 +385,20 @@ static auto save = bob::extension::FunctionDoc(
"Save the configuration of the GMMStats to a given HDF5 file"
)
.add_prototype("hdf5")
.add_parameter("hdf5", ":py:class:`bob.io.base.HDF5File`", "An HDF5 file open for writing")
;
static PyObject* PyBobLearnMiscGMMStats_Save(PyBobLearnMiscGMMStatsObject* self, PyObject* arg) {
.add_parameter("hdf5", ":py:class:`bob.io.base.HDF5File`", "An HDF5 file open for writing");
static PyObject* PyBobLearnMiscGMMStats_Save(PyBobLearnMiscGMMStatsObject* self, PyObject* args, PyObject* kwargs) {
BOB_TRY
// get list of arguments
if (!PyBobIoHDF5File_Check(arg)) {
PyErr_Format(PyExc_TypeError, "`%s' cannot write itself to `%s', only to an HDF5 file", Py_TYPE(self)->tp_name, Py_TYPE(arg)->tp_name);
return 0;
}
char** kwlist = save.kwlist(0);
PyBobIoHDF5FileObject* hdf5;
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O&", kwlist, PyBobIoHDF5File_Converter, &hdf5)) return 0;
auto hdf5 = reinterpret_cast<PyBobIoHDF5FileObject*>(arg);
try {
self->cxx->save(*hdf5->f);
}
catch (std::exception& e) {
PyErr_SetString(PyExc_RuntimeError, e.what());
return 0;
}
catch (...) {
PyErr_Format(PyExc_RuntimeError, "`%s' cannot write data to file `%s' (at group `%s'): unknown exception caught", Py_TYPE(self)->tp_name,
hdf5->f->filename().c_str(), hdf5->f->cwd().c_str());
return 0;
}
auto hdf5_ = make_safe(hdf5);
self->cxx->save(*hdf5->f);
BOB_CATCH_MEMBER("cannot save the data", 0)
Py_RETURN_NONE;
}
......@@ -399,27 +409,17 @@ static auto load = bob::extension::FunctionDoc(
)
.add_prototype("hdf5")
.add_parameter("hdf5", ":py:class:`bob.io.base.HDF5File`", "An HDF5 file open for reading");
static PyObject* PyBobLearnMiscGMMStats_Load(PyBobLearnMiscGMMStatsObject* self, PyObject* f) {
if (!PyBobIoHDF5File_Check(f)) {
PyErr_Format(PyExc_TypeError, "`%s' cannot load itself from `%s', only from an HDF5 file", Py_TYPE(self)->tp_name, Py_TYPE(f)->tp_name);
return 0;
}
auto h5f = reinterpret_cast<PyBobIoHDF5FileObject*>(f);
static PyObject* PyBobLearnMiscGMMStats_Load(PyBobLearnMiscGMMStatsObject* self, PyObject* args, PyObject* kwargs) {
BOB_TRY
char** kwlist = load.kwlist(0);
PyBobIoHDF5FileObject* hdf5;
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O&", kwlist, PyBobIoHDF5File_Converter, &hdf5)) return 0;
auto hdf5_ = make_safe(hdf5);
self->cxx->load(*hdf5->f);
try {
self->cxx->load(*h5f->f);
}
catch (std::exception& e) {
PyErr_SetString(PyExc_RuntimeError, e.what());
return 0;
}
catch (...) {
PyErr_Format(PyExc_RuntimeError, "cannot read data from file `%s' (at group `%s'): unknown exception caught", h5f->f->filename().c_str(),
h5f->f->cwd().c_str());
return 0;
}
BOB_CATCH_MEMBER("cannot load the data", 0)
Py_RETURN_NONE;
}
......@@ -429,14 +429,14 @@ static auto is_similar_to = bob::extension::FunctionDoc(
"is_similar_to",
"Compares this GMMStats with the ``other`` one to be approximately the same.",
"The optional values ``r_epsilon`` and ``a_epsilon`` refer to the"
"relative and absolute precision for the ``weights``, ``biases``"
"The optional values ``r_epsilon`` and ``a_epsilon`` refer to the "
"relative and absolute precision for the ``weights``, ``biases`` "
"and any other values internal to this machine."
)
.add_prototype("other, [r_epsilon], [a_epsilon]","output")
.add_parameter("other", ":py:class:`bob.learn.misc.GMMStats`", "A GMMStats object to be compared.")
.add_parameter("[r_epsilon]", "float", "Relative precision.")
.add_parameter("[a_epsilon]", "float", "Absolute precision.")
.add_parameter("r_epsilon", "float", "Relative precision.")
.add_parameter("a_epsilon", "float", "Absolute precision.")
.add_return("output","bool","True if it is similar, otherwise false.");
static PyObject* PyBobLearnMiscGMMStats_IsSimilarTo(PyBobLearnMiscGMMStatsObject* self, PyObject* args, PyObject* kwds) {
......@@ -450,9 +450,11 @@ static PyObject* PyBobLearnMiscGMMStats_IsSimilarTo(PyBobLearnMiscGMMStatsObject
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!|dd", kwlist,
&PyBobLearnMiscGMMStats_Type, &other,
&r_epsilon, &a_epsilon)) return 0;
&r_epsilon, &a_epsilon)){
//auto other_ = reinterpret_cast<PyBobLearnMiscGMMStatsObject*>(other);
is_similar_to.print_usage();
return 0;
}
if (self->cxx->is_similar_to(*other->cxx, r_epsilon, a_epsilon))
Py_RETURN_TRUE;
......@@ -464,9 +466,11 @@ static PyObject* PyBobLearnMiscGMMStats_IsSimilarTo(PyBobLearnMiscGMMStatsObject
/*** resize ***/
static auto resize = bob::extension::FunctionDoc(
"resize",
" Allocates space for the statistics and resets to zero."
"Allocates space for the statistics and resets to zero.",
0,
true
)
.add_prototype("n_gaussians,n_inputs","")
.add_prototype("n_gaussians,n_inputs")
.add_parameter("n_gaussians", "int", "Number of gaussians")
.add_parameter("n_inputs", "int", "Dimensionality of the feature vector");
static PyObject* PyBobLearnMiscGMMStats_resize(PyBobLearnMiscGMMStatsObject* self, PyObject* args, PyObject* kwargs) {
......@@ -482,10 +486,12 @@ static PyObject* PyBobLearnMiscGMMStats_resize(PyBobLearnMiscGMMStatsObject* sel
if (n_gaussians <= 0){
PyErr_Format(PyExc_TypeError, "n_gaussians must be greater than zero");
resize.print_usage();
return 0;
}
if (n_inputs <= 0){
PyErr_Format(PyExc_TypeError, "n_inputs must be greater than zero");
resize.print_usage();
return 0;
}
......@@ -503,7 +509,7 @@ static auto init = bob::extension::FunctionDoc(
"init",
" Resets statistics to zero."
)
.add_prototype("","");
.add_prototype("");
static PyObject* PyBobLearnMiscGMMStats_init_method(PyBobLearnMiscGMMStatsObject* self) {
BOB_TRY
......@@ -520,13 +526,13 @@ static PyMethodDef PyBobLearnMiscGMMStats_methods[] = {
{
save.name(),
(PyCFunction)PyBobLearnMiscGMMStats_Save,
METH_O,
METH_VARARGS|METH_KEYWORDS,
save.doc()
},
{
load.name(),
(PyCFunction)PyBobLearnMiscGMMStats_Load,
METH_O,
METH_VARARGS|METH_KEYWORDS,
load.doc()
},
{
......@@ -601,7 +607,7 @@ bool init_BobLearnMiscGMMStats(PyObject* module)
PyBobLearnMiscGMMStats_Type.tp_richcompare = reinterpret_cast<richcmpfunc>(PyBobLearnMiscGMMStats_RichCompare);
PyBobLearnMiscGMMStats_Type.tp_methods = PyBobLearnMiscGMMStats_methods;
PyBobLearnMiscGMMStats_Type.tp_getset = PyBobLearnMiscGMMStats_getseters;
//PyBobLearnMiscGMMStats_Type.tp_call = reinterpret_cast<ternaryfunc>(PyBobLearnMiscGMMStats_loglikelihood);
PyBobLearnMiscGMMStats_Type.tp_call = 0;
PyBobLearnMiscGMMStats_Type.tp_as_number = &PyBobLearnMiscGMMStats_operators;
//set operators
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment