From ed7ab2803a70e30f7a155386df9014196713cc91 Mon Sep 17 00:00:00 2001 From: Tiago Freitas Pereira <tiagofrepereira@gmail.com> Date: Mon, 2 Feb 2015 14:59:29 +0100 Subject: [PATCH] Binding JFATrainer --- bob/learn/misc/__init__.py | 1 + bob/learn/misc/__jfa_trainer__.py | 22 ++-- bob/learn/misc/cpp/JFATrainer.cpp | 2 - bob/learn/misc/gmm_machine.cpp | 34 +++++- bob/learn/misc/jfa_trainer.cpp | 180 ++++++++++++++++++++++++++--- bob/learn/misc/test_jfa_trainer.py | 4 +- 6 files changed, 205 insertions(+), 38 deletions(-) diff --git a/bob/learn/misc/__init__.py b/bob/learn/misc/__init__.py index 5e3a170..877a303 100644 --- a/bob/learn/misc/__init__.py +++ b/bob/learn/misc/__init__.py @@ -14,6 +14,7 @@ from .version import module as __version__ from .__kmeans_trainer__ import * from .__ML_gmm_trainer__ import * from .__MAP_gmm_trainer__ import * +from .__jfa_trainer__ import * def ztnorm_same_value(vect_a, vect_b): diff --git a/bob/learn/misc/__jfa_trainer__.py b/bob/learn/misc/__jfa_trainer__.py index a2e661b..ad803ad 100644 --- a/bob/learn/misc/__jfa_trainer__.py +++ b/bob/learn/misc/__jfa_trainer__.py @@ -35,21 +35,21 @@ class JFATrainer (_JFATrainer): The data to be trained """ #V Subspace - for i in self._max_iterations: - self.eStep1(jfa_base, data) - self.mStep1(jfa_base, data) + for i in range(self._max_iterations): + self.e_step1(jfa_base, data) + self.m_step1(jfa_base, data) self.finalize1(jfa_base, data) #U subspace - for i in self._max_iterations: - self.eStep2(jfa_base, data) - self.mStep2(jfa_base, data) + for i in range(self._max_iterations): + self.e_step2(jfa_base, data) + self.m_step2(jfa_base, data) self.finalize2(jfa_base, data) # d subspace - for i in self._max_iterations: - self.eStep3(jfa_base, data) - self.mStep3(jfa_base, data) + for i in range(self._max_iterations): + self.e_step3(jfa_base, data) + self.m_step3(jfa_base, data) self.finalize3(jfa_base, data) @@ -67,7 +67,5 @@ class JFATrainer (_JFATrainer): self.train_loop(jfa_base, data) - def enrol(self, jfa_base, data): - # copy the documentation from the base class -__doc__ = _KMeansTrainer.__doc__ +__doc__ = _JFATrainer.__doc__ diff --git a/bob/learn/misc/cpp/JFATrainer.cpp b/bob/learn/misc/cpp/JFATrainer.cpp index 13dea3c..6b79da8 100644 --- a/bob/learn/misc/cpp/JFATrainer.cpp +++ b/bob/learn/misc/cpp/JFATrainer.cpp @@ -174,7 +174,6 @@ void bob::learn::misc::JFATrainer::train(bob::learn::misc::JFABase& machine, } */ -/* void bob::learn::misc::JFATrainer::enrol(bob::learn::misc::JFAMachine& machine, const std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> >& ar, const size_t n_iter) @@ -198,5 +197,4 @@ void bob::learn::misc::JFATrainer::enrol(bob::learn::misc::JFAMachine& machine, machine.setY(y); machine.setZ(z); } -*/ diff --git a/bob/learn/misc/gmm_machine.cpp b/bob/learn/misc/gmm_machine.cpp index 9bea2a8..0b21d68 100644 --- a/bob/learn/misc/gmm_machine.cpp +++ b/bob/learn/misc/gmm_machine.cpp @@ -277,7 +277,20 @@ PyObject* PyBobLearnMiscGMMMachine_getVarianceSupervector(PyBobLearnMiscGMMMachi return PyBlitzArrayCxx_AsConstNumpy(self->cxx->getVarianceSupervector()); BOB_CATCH_MEMBER("variance_supervector could not be read", 0) } - +int PyBobLearnMiscGMMMachine_setVarianceSupervector(PyBobLearnMiscGMMMachineObject* self, PyObject* value, void*){ + BOB_TRY + PyBlitzArrayObject* o; + if (!PyBlitzArray_Converter(value, &o)){ + PyErr_Format(PyExc_RuntimeError, "%s %s expects a 1D array of floats", Py_TYPE(self)->tp_name, variance_supervector.name()); + return -1; + } + auto o_ = make_safe(o); + auto b = PyBlitzArrayCxx_AsBlitz<double,1>(o, "variance_supervector"); + if (!b) return -1; + self->cxx->setVarianceSupervector(*b); + return 0; + BOB_CATCH_MEMBER("variance_supervector could not be set", -1) +} /***** mean_supervector *****/ static auto mean_supervector = bob::extension::VariableDoc( @@ -291,6 +304,21 @@ PyObject* PyBobLearnMiscGMMMachine_getMeanSupervector(PyBobLearnMiscGMMMachineOb return PyBlitzArrayCxx_AsConstNumpy(self->cxx->getMeanSupervector()); BOB_CATCH_MEMBER("mean_supervector could not be read", 0) } +int PyBobLearnMiscGMMMachine_setMeanSupervector(PyBobLearnMiscGMMMachineObject* self, PyObject* value, void*){ + BOB_TRY + PyBlitzArrayObject* o; + if (!PyBlitzArray_Converter(value, &o)){ + PyErr_Format(PyExc_RuntimeError, "%s %s expects a 1D array of floats", Py_TYPE(self)->tp_name, mean_supervector.name()); + return -1; + } + auto o_ = make_safe(o); + auto b = PyBlitzArrayCxx_AsBlitz<double,1>(o, "mean_supervector"); + if (!b) return -1; + self->cxx->setMeanSupervector(*b); + return 0; + BOB_CATCH_MEMBER("mean_supervector could not be set", -1) +} + /***** variance_thresholds *****/ @@ -362,7 +390,7 @@ static PyGetSetDef PyBobLearnMiscGMMMachine_getseters[] = { { variance_supervector.name(), (getter)PyBobLearnMiscGMMMachine_getVarianceSupervector, - 0, + (setter)PyBobLearnMiscGMMMachine_setVarianceSupervector, variance_supervector.doc(), 0 }, @@ -370,7 +398,7 @@ static PyGetSetDef PyBobLearnMiscGMMMachine_getseters[] = { { mean_supervector.name(), (getter)PyBobLearnMiscGMMMachine_getMeanSupervector, - 0, + (setter)PyBobLearnMiscGMMMachine_setMeanSupervector, mean_supervector.doc(), 0 }, diff --git a/bob/learn/misc/jfa_trainer.cpp b/bob/learn/misc/jfa_trainer.cpp index 9735aee..cb4dded 100644 --- a/bob/learn/misc/jfa_trainer.cpp +++ b/bob/learn/misc/jfa_trainer.cpp @@ -14,17 +14,34 @@ /************ Constructor Section *********************************/ /******************************************************************/ -static int extract_GMMStats(PyObject *list, - std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > >& training_data) +static int extract_GMMStats_1d(PyObject *list, + std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> >& training_data) { + std::cout << " #################### " << std::endl; + std::cout << PyList_GET_SIZE(list) << std::endl; + + for (int i=0; i<PyList_GET_SIZE(list); i++){ + + PyBobLearnMiscGMMStatsObject* stats; + if (!PyArg_Parse(PyList_GetItem(list, i), "O!", &PyBobLearnMiscGMMStats_Type, &stats)){ + PyErr_Format(PyExc_RuntimeError, "Expected GMMStats objects"); + return -1; + } + training_data.push_back(stats->cxx); + } + return 0; +} - for (size_t i=0; i<PyList_GET_SIZE(list); i++) +static int extract_GMMStats_2d(PyObject *list, + std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > >& training_data) +{ + for (int i=0; i<PyList_GET_SIZE(list); i++) { PyObject* another_list; PyArg_Parse(PyList_GetItem(list, i), "O!", &PyList_Type, &another_list); std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > another_training_data; - for (size_t j=0; j<PyList_GET_SIZE(another_list); j++){ + for (int j=0; j<PyList_GET_SIZE(another_list); j++){ PyBobLearnMiscGMMStatsObject* stats; if (!PyArg_Parse(PyList_GetItem(another_list, j), "O!", &PyBobLearnMiscGMMStats_Type, &stats)){ @@ -50,6 +67,22 @@ static PyObject* vector_as_list(const std::vector<blitz::Array<double,N> >& vec) return list; } +template <int N> +int list_as_vector(PyObject* list, std::vector<blitz::Array<double,N> >& vec) +{ + for (int i=0; i<PyList_GET_SIZE(list); i++) + { + PyBlitzArrayObject* blitz_object; + if (!PyArg_Parse(PyList_GetItem(list, i), "O&", &PyBlitzArray_Converter, &blitz_object)){ + PyErr_Format(PyExc_RuntimeError, "Expected numpy array object"); + return -1; + } + auto blitz_object_ = make_safe(blitz_object); + vec.push_back(*PyBlitzArrayCxx_AsBlitz<double,N>(blitz_object)); + } + return 0; +} + static auto JFATrainer_doc = bob::extension::ClassDoc( @@ -319,6 +352,24 @@ PyObject* PyBobLearnMiscJFATrainer_get_X(PyBobLearnMiscJFATrainerObject* self, v return vector_as_list(self->cxx->getX()); BOB_CATCH_MEMBER("__X__ could not be read", 0) } +int PyBobLearnMiscJFATrainer_set_X(PyBobLearnMiscJFATrainerObject* self, PyObject* value, void*){ + BOB_TRY + + // Parses input arguments in a single shot + if (!PyList_Check(value)){ + PyErr_Format(PyExc_TypeError, "Expected a list in `%s'", __X__.name()); + return -1; + } + + std::vector<blitz::Array<double,2> > data; + if(list_as_vector(value ,data)==0){ + self->cxx->setX(data); + } + + return 0; + BOB_CATCH_MEMBER("__X__ could not be written", 0) +} + static auto __Y__ = bob::extension::VariableDoc( @@ -332,6 +383,24 @@ PyObject* PyBobLearnMiscJFATrainer_get_Y(PyBobLearnMiscJFATrainerObject* self, v return vector_as_list(self->cxx->getY()); BOB_CATCH_MEMBER("__Y__ could not be read", 0) } +int PyBobLearnMiscJFATrainer_set_Y(PyBobLearnMiscJFATrainerObject* self, PyObject* value, void*){ + BOB_TRY + + // Parses input arguments in a single shot + if (!PyList_Check(value)){ + PyErr_Format(PyExc_TypeError, "Expected a list in `%s'", __Y__.name()); + return -1; + } + + std::vector<blitz::Array<double,1> > data; + if(list_as_vector(value ,data)==0){ + self->cxx->setY(data); + } + + return 0; + BOB_CATCH_MEMBER("__Y__ could not be written", 0) +} + static auto __Z__ = bob::extension::VariableDoc( @@ -345,6 +414,23 @@ PyObject* PyBobLearnMiscJFATrainer_get_Z(PyBobLearnMiscJFATrainerObject* self, v return vector_as_list(self->cxx->getZ()); BOB_CATCH_MEMBER("__Z__ could not be read", 0) } +int PyBobLearnMiscJFATrainer_set_Z(PyBobLearnMiscJFATrainerObject* self, PyObject* value, void*){ + BOB_TRY + + // Parses input arguments in a single shot + if (!PyList_Check(value)){ + PyErr_Format(PyExc_TypeError, "Expected a list in `%s'", __Z__.name()); + return -1; + } + + std::vector<blitz::Array<double,1> > data; + if(list_as_vector(value ,data)==0){ + self->cxx->setZ(data); + } + + return 0; + BOB_CATCH_MEMBER("__Z__ could not be written", 0) +} @@ -435,21 +521,21 @@ static PyGetSetDef PyBobLearnMiscJFATrainer_getseters[] = { { __X__.name(), (getter)PyBobLearnMiscJFATrainer_get_X, - 0, + (setter)PyBobLearnMiscJFATrainer_set_X, __X__.doc(), 0 }, { __Y__.name(), (getter)PyBobLearnMiscJFATrainer_get_Y, - 0, + (setter)PyBobLearnMiscJFATrainer_set_Y, __Y__.doc(), 0 }, { __Z__.name(), (getter)PyBobLearnMiscJFATrainer_get_Z, - 0, + (setter)PyBobLearnMiscJFATrainer_set_Z, __Z__.doc(), 0 }, @@ -487,7 +573,7 @@ static PyObject* PyBobLearnMiscJFATrainer_initialize(PyBobLearnMiscJFATrainerObj &PyList_Type, &stats)) Py_RETURN_NONE; std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > > training_data; - if(extract_GMMStats(stats ,training_data)==0) + if(extract_GMMStats_2d(stats ,training_data)==0) self->cxx->initialize(*jfa_base->cxx, training_data); BOB_CATCH_MEMBER("cannot perform the initialize method", 0) @@ -519,7 +605,7 @@ static PyObject* PyBobLearnMiscJFATrainer_e_step1(PyBobLearnMiscJFATrainerObject &PyList_Type, &stats)) Py_RETURN_NONE; std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > > training_data; - if(extract_GMMStats(stats ,training_data)==0) + if(extract_GMMStats_2d(stats ,training_data)==0) self->cxx->eStep1(*jfa_base->cxx, training_data); @@ -552,7 +638,7 @@ static PyObject* PyBobLearnMiscJFATrainer_m_step1(PyBobLearnMiscJFATrainerObject &PyList_Type, &stats)) Py_RETURN_NONE; std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > > training_data; - if(extract_GMMStats(stats ,training_data)==0) + if(extract_GMMStats_2d(stats ,training_data)==0) self->cxx->mStep1(*jfa_base->cxx, training_data); BOB_CATCH_MEMBER("cannot perform the m_step1 method", 0) @@ -584,7 +670,7 @@ static PyObject* PyBobLearnMiscJFATrainer_finalize1(PyBobLearnMiscJFATrainerObje &PyList_Type, &stats)) Py_RETURN_NONE; std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > > training_data; - if(extract_GMMStats(stats ,training_data)==0) + if(extract_GMMStats_2d(stats ,training_data)==0) self->cxx->finalize1(*jfa_base->cxx, training_data); BOB_CATCH_MEMBER("cannot perform the finalize1 method", 0) @@ -616,7 +702,7 @@ static PyObject* PyBobLearnMiscJFATrainer_e_step2(PyBobLearnMiscJFATrainerObject &PyList_Type, &stats)) Py_RETURN_NONE; std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > > training_data; - if(extract_GMMStats(stats ,training_data)==0) + if(extract_GMMStats_2d(stats ,training_data)==0) self->cxx->eStep2(*jfa_base->cxx, training_data); BOB_CATCH_MEMBER("cannot perform the e_step2 method", 0) @@ -648,7 +734,7 @@ static PyObject* PyBobLearnMiscJFATrainer_m_step2(PyBobLearnMiscJFATrainerObject &PyList_Type, &stats)) Py_RETURN_NONE; std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > > training_data; - if(extract_GMMStats(stats ,training_data)==0) + if(extract_GMMStats_2d(stats ,training_data)==0) self->cxx->mStep2(*jfa_base->cxx, training_data); BOB_CATCH_MEMBER("cannot perform the m_step2 method", 0) @@ -680,7 +766,7 @@ static PyObject* PyBobLearnMiscJFATrainer_finalize2(PyBobLearnMiscJFATrainerObje &PyList_Type, &stats)) Py_RETURN_NONE; std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > > training_data; - if(extract_GMMStats(stats ,training_data)==0) + if(extract_GMMStats_2d(stats ,training_data)==0) self->cxx->finalize2(*jfa_base->cxx, training_data); BOB_CATCH_MEMBER("cannot perform the finalize2 method", 0) @@ -712,7 +798,7 @@ static PyObject* PyBobLearnMiscJFATrainer_e_step3(PyBobLearnMiscJFATrainerObject &PyList_Type, &stats)) Py_RETURN_NONE; std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > > training_data; - if(extract_GMMStats(stats ,training_data)==0) + if(extract_GMMStats_2d(stats ,training_data)==0) self->cxx->eStep3(*jfa_base->cxx, training_data); BOB_CATCH_MEMBER("cannot perform the e_step3 method", 0) @@ -744,7 +830,7 @@ static PyObject* PyBobLearnMiscJFATrainer_m_step3(PyBobLearnMiscJFATrainerObject &PyList_Type, &stats)) Py_RETURN_NONE; std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > > training_data; - if(extract_GMMStats(stats ,training_data)==0) + if(extract_GMMStats_2d(stats ,training_data)==0) self->cxx->mStep3(*jfa_base->cxx, training_data); BOB_CATCH_MEMBER("cannot perform the m_step3 method", 0) @@ -776,7 +862,7 @@ static PyObject* PyBobLearnMiscJFATrainer_finalize3(PyBobLearnMiscJFATrainerObje &PyList_Type, &stats)) Py_RETURN_NONE; std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > > training_data; - if(extract_GMMStats(stats ,training_data)==0) + if(extract_GMMStats_2d(stats ,training_data)==0) self->cxx->finalize3(*jfa_base->cxx, training_data); BOB_CATCH_MEMBER("cannot perform the finalize3 method", 0) @@ -785,6 +871,39 @@ static PyObject* PyBobLearnMiscJFATrainer_finalize3(PyBobLearnMiscJFATrainerObje } +/*** enrol ***/ +static auto enrol = bob::extension::FunctionDoc( + "enrol", + "", + "", + true +) +.add_prototype("jfa_machine,features, n_inter") +.add_parameter("jfa_machine", ":py:class:`bob.learn.misc.JFAMachine`", "JFAMachine Object") +.add_parameter("features", "list(:py:class:`bob.learn.misc.GMMStats`)`", "") +.add_parameter("n_iter", "int", "Number of iterations"); +static PyObject* PyBobLearnMiscJFATrainer_enrol(PyBobLearnMiscJFATrainerObject* self, PyObject* args, PyObject* kwargs) { + BOB_TRY + + // Parses input arguments in a single shot + char** kwlist = enrol.kwlist(0); + + PyBobLearnMiscJFAMachineObject* jfa_machine = 0; + PyObject* stats = 0; + int n_iter = 0; + + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!O!d", kwlist, &PyBobLearnMiscJFAMachine_Type, &jfa_machine, + &PyList_Type, &stats, &n_iter)) Py_RETURN_NONE; + + std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > training_data; + if(extract_GMMStats_1d(stats ,training_data)==0) + self->cxx->enrol(*jfa_machine->cxx, training_data, n_iter); + + BOB_CATCH_MEMBER("cannot perform the enrol method", 0) + + Py_RETURN_NONE; +} + static PyMethodDef PyBobLearnMiscJFATrainer_methods[] = { @@ -830,7 +949,30 @@ static PyMethodDef PyBobLearnMiscJFATrainer_methods[] = { METH_VARARGS|METH_KEYWORDS, m_step3.doc() }, - + { + finalize1.name(), + (PyCFunction)PyBobLearnMiscJFATrainer_finalize1, + METH_VARARGS|METH_KEYWORDS, + finalize1.doc() + }, + { + finalize2.name(), + (PyCFunction)PyBobLearnMiscJFATrainer_finalize2, + METH_VARARGS|METH_KEYWORDS, + finalize2.doc() + }, + { + finalize3.name(), + (PyCFunction)PyBobLearnMiscJFATrainer_finalize3, + METH_VARARGS|METH_KEYWORDS, + finalize3.doc() + }, + { + enrol.name(), + (PyCFunction)PyBobLearnMiscJFATrainer_enrol, + METH_VARARGS|METH_KEYWORDS, + enrol.doc() + }, {0} /* Sentinel */ }; @@ -850,7 +992,7 @@ bool init_BobLearnMiscJFATrainer(PyObject* module) // initialize the type struct PyBobLearnMiscJFATrainer_Type.tp_name = JFATrainer_doc.name(); PyBobLearnMiscJFATrainer_Type.tp_basicsize = sizeof(PyBobLearnMiscJFATrainerObject); - PyBobLearnMiscJFATrainer_Type.tp_flags = Py_TPFLAGS_DEFAULT; + PyBobLearnMiscJFATrainer_Type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;//Enable the class inheritance; PyBobLearnMiscJFATrainer_Type.tp_doc = JFATrainer_doc.doc(); // set the functions diff --git a/bob/learn/misc/test_jfa_trainer.py b/bob/learn/misc/test_jfa_trainer.py index 7c3f4d4..c094a6d 100644 --- a/bob/learn/misc/test_jfa_trainer.py +++ b/bob/learn/misc/test_jfa_trainer.py @@ -13,8 +13,8 @@ import numpy.linalg import bob.core.random -from . import GMMStats, GMMMachine, JFABase, JFAMachine, ISVBase, ISVMachine, \ - JFATrainer, ISVTrainer +from . import GMMStats, GMMMachine, JFABase, JFAMachine, ISVBase, ISVMachine, JFATrainer +#, ISVTrainer def equals(x, y, epsilon): -- GitLab