Commit ed7ab280 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Binding JFATrainer

parent 53f4259a
......@@ -14,6 +14,7 @@ from .version import module as __version__
from .__kmeans_trainer__ import *
from .__ML_gmm_trainer__ import *
from .__MAP_gmm_trainer__ import *
from .__jfa_trainer__ import *
def ztnorm_same_value(vect_a, vect_b):
......
......@@ -35,21 +35,21 @@ class JFATrainer (_JFATrainer):
The data to be trained
"""
#V Subspace
for i in self._max_iterations:
self.eStep1(jfa_base, data)
self.mStep1(jfa_base, data)
for i in range(self._max_iterations):
self.e_step1(jfa_base, data)
self.m_step1(jfa_base, data)
self.finalize1(jfa_base, data)
#U subspace
for i in self._max_iterations:
self.eStep2(jfa_base, data)
self.mStep2(jfa_base, data)
for i in range(self._max_iterations):
self.e_step2(jfa_base, data)
self.m_step2(jfa_base, data)
self.finalize2(jfa_base, data)
# d subspace
for i in self._max_iterations:
self.eStep3(jfa_base, data)
self.mStep3(jfa_base, data)
for i in range(self._max_iterations):
self.e_step3(jfa_base, data)
self.m_step3(jfa_base, data)
self.finalize3(jfa_base, data)
......@@ -67,7 +67,5 @@ class JFATrainer (_JFATrainer):
self.train_loop(jfa_base, data)
def enrol(self, jfa_base, data):
# copy the documentation from the base class
__doc__ = _KMeansTrainer.__doc__
__doc__ = _JFATrainer.__doc__
......@@ -174,7 +174,6 @@ void bob::learn::misc::JFATrainer::train(bob::learn::misc::JFABase& machine,
}
*/
/*
void bob::learn::misc::JFATrainer::enrol(bob::learn::misc::JFAMachine& machine,
const std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> >& ar,
const size_t n_iter)
......@@ -198,5 +197,4 @@ void bob::learn::misc::JFATrainer::enrol(bob::learn::misc::JFAMachine& machine,
machine.setY(y);
machine.setZ(z);
}
*/
......@@ -277,7 +277,20 @@ PyObject* PyBobLearnMiscGMMMachine_getVarianceSupervector(PyBobLearnMiscGMMMachi
return PyBlitzArrayCxx_AsConstNumpy(self->cxx->getVarianceSupervector());
BOB_CATCH_MEMBER("variance_supervector could not be read", 0)
}
int PyBobLearnMiscGMMMachine_setVarianceSupervector(PyBobLearnMiscGMMMachineObject* self, PyObject* value, void*){
BOB_TRY
PyBlitzArrayObject* o;
if (!PyBlitzArray_Converter(value, &o)){
PyErr_Format(PyExc_RuntimeError, "%s %s expects a 1D array of floats", Py_TYPE(self)->tp_name, variance_supervector.name());
return -1;
}
auto o_ = make_safe(o);
auto b = PyBlitzArrayCxx_AsBlitz<double,1>(o, "variance_supervector");
if (!b) return -1;
self->cxx->setVarianceSupervector(*b);
return 0;
BOB_CATCH_MEMBER("variance_supervector could not be set", -1)
}
/***** mean_supervector *****/
static auto mean_supervector = bob::extension::VariableDoc(
......@@ -291,6 +304,21 @@ PyObject* PyBobLearnMiscGMMMachine_getMeanSupervector(PyBobLearnMiscGMMMachineOb
return PyBlitzArrayCxx_AsConstNumpy(self->cxx->getMeanSupervector());
BOB_CATCH_MEMBER("mean_supervector could not be read", 0)
}
int PyBobLearnMiscGMMMachine_setMeanSupervector(PyBobLearnMiscGMMMachineObject* self, PyObject* value, void*){
BOB_TRY
PyBlitzArrayObject* o;
if (!PyBlitzArray_Converter(value, &o)){
PyErr_Format(PyExc_RuntimeError, "%s %s expects a 1D array of floats", Py_TYPE(self)->tp_name, mean_supervector.name());
return -1;
}
auto o_ = make_safe(o);
auto b = PyBlitzArrayCxx_AsBlitz<double,1>(o, "mean_supervector");
if (!b) return -1;
self->cxx->setMeanSupervector(*b);
return 0;
BOB_CATCH_MEMBER("mean_supervector could not be set", -1)
}
/***** variance_thresholds *****/
......@@ -362,7 +390,7 @@ static PyGetSetDef PyBobLearnMiscGMMMachine_getseters[] = {
{
variance_supervector.name(),
(getter)PyBobLearnMiscGMMMachine_getVarianceSupervector,
0,
(setter)PyBobLearnMiscGMMMachine_setVarianceSupervector,
variance_supervector.doc(),
0
},
......@@ -370,7 +398,7 @@ static PyGetSetDef PyBobLearnMiscGMMMachine_getseters[] = {
{
mean_supervector.name(),
(getter)PyBobLearnMiscGMMMachine_getMeanSupervector,
0,
(setter)PyBobLearnMiscGMMMachine_setMeanSupervector,
mean_supervector.doc(),
0
},
......
......@@ -14,17 +14,34 @@
/************ Constructor Section *********************************/
/******************************************************************/
static int extract_GMMStats(PyObject *list,
std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > >& training_data)
static int extract_GMMStats_1d(PyObject *list,
std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> >& training_data)
{
std::cout << " #################### " << std::endl;
std::cout << PyList_GET_SIZE(list) << std::endl;
for (int i=0; i<PyList_GET_SIZE(list); i++){
PyBobLearnMiscGMMStatsObject* stats;
if (!PyArg_Parse(PyList_GetItem(list, i), "O!", &PyBobLearnMiscGMMStats_Type, &stats)){
PyErr_Format(PyExc_RuntimeError, "Expected GMMStats objects");
return -1;
}
training_data.push_back(stats->cxx);
}
return 0;
}
for (size_t i=0; i<PyList_GET_SIZE(list); i++)
static int extract_GMMStats_2d(PyObject *list,
std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > >& training_data)
{
for (int i=0; i<PyList_GET_SIZE(list); i++)
{
PyObject* another_list;
PyArg_Parse(PyList_GetItem(list, i), "O!", &PyList_Type, &another_list);
std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > another_training_data;
for (size_t j=0; j<PyList_GET_SIZE(another_list); j++){
for (int j=0; j<PyList_GET_SIZE(another_list); j++){
PyBobLearnMiscGMMStatsObject* stats;
if (!PyArg_Parse(PyList_GetItem(another_list, j), "O!", &PyBobLearnMiscGMMStats_Type, &stats)){
......@@ -50,6 +67,22 @@ static PyObject* vector_as_list(const std::vector<blitz::Array<double,N> >& vec)
return list;
}
template <int N>
int list_as_vector(PyObject* list, std::vector<blitz::Array<double,N> >& vec)
{
for (int i=0; i<PyList_GET_SIZE(list); i++)
{
PyBlitzArrayObject* blitz_object;
if (!PyArg_Parse(PyList_GetItem(list, i), "O&", &PyBlitzArray_Converter, &blitz_object)){
PyErr_Format(PyExc_RuntimeError, "Expected numpy array object");
return -1;
}
auto blitz_object_ = make_safe(blitz_object);
vec.push_back(*PyBlitzArrayCxx_AsBlitz<double,N>(blitz_object));
}
return 0;
}
static auto JFATrainer_doc = bob::extension::ClassDoc(
......@@ -319,6 +352,24 @@ PyObject* PyBobLearnMiscJFATrainer_get_X(PyBobLearnMiscJFATrainerObject* self, v
return vector_as_list(self->cxx->getX());
BOB_CATCH_MEMBER("__X__ could not be read", 0)
}
int PyBobLearnMiscJFATrainer_set_X(PyBobLearnMiscJFATrainerObject* self, PyObject* value, void*){
BOB_TRY
// Parses input arguments in a single shot
if (!PyList_Check(value)){
PyErr_Format(PyExc_TypeError, "Expected a list in `%s'", __X__.name());
return -1;
}
std::vector<blitz::Array<double,2> > data;
if(list_as_vector(value ,data)==0){
self->cxx->setX(data);
}
return 0;
BOB_CATCH_MEMBER("__X__ could not be written", 0)
}
static auto __Y__ = bob::extension::VariableDoc(
......@@ -332,6 +383,24 @@ PyObject* PyBobLearnMiscJFATrainer_get_Y(PyBobLearnMiscJFATrainerObject* self, v
return vector_as_list(self->cxx->getY());
BOB_CATCH_MEMBER("__Y__ could not be read", 0)
}
int PyBobLearnMiscJFATrainer_set_Y(PyBobLearnMiscJFATrainerObject* self, PyObject* value, void*){
BOB_TRY
// Parses input arguments in a single shot
if (!PyList_Check(value)){
PyErr_Format(PyExc_TypeError, "Expected a list in `%s'", __Y__.name());
return -1;
}
std::vector<blitz::Array<double,1> > data;
if(list_as_vector(value ,data)==0){
self->cxx->setY(data);
}
return 0;
BOB_CATCH_MEMBER("__Y__ could not be written", 0)
}
static auto __Z__ = bob::extension::VariableDoc(
......@@ -345,6 +414,23 @@ PyObject* PyBobLearnMiscJFATrainer_get_Z(PyBobLearnMiscJFATrainerObject* self, v
return vector_as_list(self->cxx->getZ());
BOB_CATCH_MEMBER("__Z__ could not be read", 0)
}
int PyBobLearnMiscJFATrainer_set_Z(PyBobLearnMiscJFATrainerObject* self, PyObject* value, void*){
BOB_TRY
// Parses input arguments in a single shot
if (!PyList_Check(value)){
PyErr_Format(PyExc_TypeError, "Expected a list in `%s'", __Z__.name());
return -1;
}
std::vector<blitz::Array<double,1> > data;
if(list_as_vector(value ,data)==0){
self->cxx->setZ(data);
}
return 0;
BOB_CATCH_MEMBER("__Z__ could not be written", 0)
}
......@@ -435,21 +521,21 @@ static PyGetSetDef PyBobLearnMiscJFATrainer_getseters[] = {
{
__X__.name(),
(getter)PyBobLearnMiscJFATrainer_get_X,
0,
(setter)PyBobLearnMiscJFATrainer_set_X,
__X__.doc(),
0
},
{
__Y__.name(),
(getter)PyBobLearnMiscJFATrainer_get_Y,
0,
(setter)PyBobLearnMiscJFATrainer_set_Y,
__Y__.doc(),
0
},
{
__Z__.name(),
(getter)PyBobLearnMiscJFATrainer_get_Z,
0,
(setter)PyBobLearnMiscJFATrainer_set_Z,
__Z__.doc(),
0
},
......@@ -487,7 +573,7 @@ static PyObject* PyBobLearnMiscJFATrainer_initialize(PyBobLearnMiscJFATrainerObj
&PyList_Type, &stats)) Py_RETURN_NONE;
std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > > training_data;
if(extract_GMMStats(stats ,training_data)==0)
if(extract_GMMStats_2d(stats ,training_data)==0)
self->cxx->initialize(*jfa_base->cxx, training_data);
BOB_CATCH_MEMBER("cannot perform the initialize method", 0)
......@@ -519,7 +605,7 @@ static PyObject* PyBobLearnMiscJFATrainer_e_step1(PyBobLearnMiscJFATrainerObject
&PyList_Type, &stats)) Py_RETURN_NONE;
std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > > training_data;
if(extract_GMMStats(stats ,training_data)==0)
if(extract_GMMStats_2d(stats ,training_data)==0)
self->cxx->eStep1(*jfa_base->cxx, training_data);
......@@ -552,7 +638,7 @@ static PyObject* PyBobLearnMiscJFATrainer_m_step1(PyBobLearnMiscJFATrainerObject
&PyList_Type, &stats)) Py_RETURN_NONE;
std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > > training_data;
if(extract_GMMStats(stats ,training_data)==0)
if(extract_GMMStats_2d(stats ,training_data)==0)
self->cxx->mStep1(*jfa_base->cxx, training_data);
BOB_CATCH_MEMBER("cannot perform the m_step1 method", 0)
......@@ -584,7 +670,7 @@ static PyObject* PyBobLearnMiscJFATrainer_finalize1(PyBobLearnMiscJFATrainerObje
&PyList_Type, &stats)) Py_RETURN_NONE;
std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > > training_data;
if(extract_GMMStats(stats ,training_data)==0)
if(extract_GMMStats_2d(stats ,training_data)==0)
self->cxx->finalize1(*jfa_base->cxx, training_data);
BOB_CATCH_MEMBER("cannot perform the finalize1 method", 0)
......@@ -616,7 +702,7 @@ static PyObject* PyBobLearnMiscJFATrainer_e_step2(PyBobLearnMiscJFATrainerObject
&PyList_Type, &stats)) Py_RETURN_NONE;
std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > > training_data;
if(extract_GMMStats(stats ,training_data)==0)
if(extract_GMMStats_2d(stats ,training_data)==0)
self->cxx->eStep2(*jfa_base->cxx, training_data);
BOB_CATCH_MEMBER("cannot perform the e_step2 method", 0)
......@@ -648,7 +734,7 @@ static PyObject* PyBobLearnMiscJFATrainer_m_step2(PyBobLearnMiscJFATrainerObject
&PyList_Type, &stats)) Py_RETURN_NONE;
std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > > training_data;
if(extract_GMMStats(stats ,training_data)==0)
if(extract_GMMStats_2d(stats ,training_data)==0)
self->cxx->mStep2(*jfa_base->cxx, training_data);
BOB_CATCH_MEMBER("cannot perform the m_step2 method", 0)
......@@ -680,7 +766,7 @@ static PyObject* PyBobLearnMiscJFATrainer_finalize2(PyBobLearnMiscJFATrainerObje
&PyList_Type, &stats)) Py_RETURN_NONE;
std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > > training_data;
if(extract_GMMStats(stats ,training_data)==0)
if(extract_GMMStats_2d(stats ,training_data)==0)
self->cxx->finalize2(*jfa_base->cxx, training_data);
BOB_CATCH_MEMBER("cannot perform the finalize2 method", 0)
......@@ -712,7 +798,7 @@ static PyObject* PyBobLearnMiscJFATrainer_e_step3(PyBobLearnMiscJFATrainerObject
&PyList_Type, &stats)) Py_RETURN_NONE;
std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > > training_data;
if(extract_GMMStats(stats ,training_data)==0)
if(extract_GMMStats_2d(stats ,training_data)==0)
self->cxx->eStep3(*jfa_base->cxx, training_data);
BOB_CATCH_MEMBER("cannot perform the e_step3 method", 0)
......@@ -744,7 +830,7 @@ static PyObject* PyBobLearnMiscJFATrainer_m_step3(PyBobLearnMiscJFATrainerObject
&PyList_Type, &stats)) Py_RETURN_NONE;
std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > > training_data;
if(extract_GMMStats(stats ,training_data)==0)
if(extract_GMMStats_2d(stats ,training_data)==0)
self->cxx->mStep3(*jfa_base->cxx, training_data);
BOB_CATCH_MEMBER("cannot perform the m_step3 method", 0)
......@@ -776,7 +862,7 @@ static PyObject* PyBobLearnMiscJFATrainer_finalize3(PyBobLearnMiscJFATrainerObje
&PyList_Type, &stats)) Py_RETURN_NONE;
std::vector<std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > > training_data;
if(extract_GMMStats(stats ,training_data)==0)
if(extract_GMMStats_2d(stats ,training_data)==0)
self->cxx->finalize3(*jfa_base->cxx, training_data);
BOB_CATCH_MEMBER("cannot perform the finalize3 method", 0)
......@@ -785,6 +871,39 @@ static PyObject* PyBobLearnMiscJFATrainer_finalize3(PyBobLearnMiscJFATrainerObje
}
/*** enrol ***/
static auto enrol = bob::extension::FunctionDoc(
"enrol",
"",
"",
true
)
.add_prototype("jfa_machine,features, n_inter")
.add_parameter("jfa_machine", ":py:class:`bob.learn.misc.JFAMachine`", "JFAMachine Object")
.add_parameter("features", "list(:py:class:`bob.learn.misc.GMMStats`)`", "")
.add_parameter("n_iter", "int", "Number of iterations");
static PyObject* PyBobLearnMiscJFATrainer_enrol(PyBobLearnMiscJFATrainerObject* self, PyObject* args, PyObject* kwargs) {
BOB_TRY
// Parses input arguments in a single shot
char** kwlist = enrol.kwlist(0);
PyBobLearnMiscJFAMachineObject* jfa_machine = 0;
PyObject* stats = 0;
int n_iter = 0;
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!O!d", kwlist, &PyBobLearnMiscJFAMachine_Type, &jfa_machine,
&PyList_Type, &stats, &n_iter)) Py_RETURN_NONE;
std::vector<boost::shared_ptr<bob::learn::misc::GMMStats> > training_data;
if(extract_GMMStats_1d(stats ,training_data)==0)
self->cxx->enrol(*jfa_machine->cxx, training_data, n_iter);
BOB_CATCH_MEMBER("cannot perform the enrol method", 0)
Py_RETURN_NONE;
}
static PyMethodDef PyBobLearnMiscJFATrainer_methods[] = {
......@@ -830,7 +949,30 @@ static PyMethodDef PyBobLearnMiscJFATrainer_methods[] = {
METH_VARARGS|METH_KEYWORDS,
m_step3.doc()
},
{
finalize1.name(),
(PyCFunction)PyBobLearnMiscJFATrainer_finalize1,
METH_VARARGS|METH_KEYWORDS,
finalize1.doc()
},
{
finalize2.name(),
(PyCFunction)PyBobLearnMiscJFATrainer_finalize2,
METH_VARARGS|METH_KEYWORDS,
finalize2.doc()
},
{
finalize3.name(),
(PyCFunction)PyBobLearnMiscJFATrainer_finalize3,
METH_VARARGS|METH_KEYWORDS,
finalize3.doc()
},
{
enrol.name(),
(PyCFunction)PyBobLearnMiscJFATrainer_enrol,
METH_VARARGS|METH_KEYWORDS,
enrol.doc()
},
{0} /* Sentinel */
};
......@@ -850,7 +992,7 @@ bool init_BobLearnMiscJFATrainer(PyObject* module)
// initialize the type struct
PyBobLearnMiscJFATrainer_Type.tp_name = JFATrainer_doc.name();
PyBobLearnMiscJFATrainer_Type.tp_basicsize = sizeof(PyBobLearnMiscJFATrainerObject);
PyBobLearnMiscJFATrainer_Type.tp_flags = Py_TPFLAGS_DEFAULT;
PyBobLearnMiscJFATrainer_Type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;//Enable the class inheritance;
PyBobLearnMiscJFATrainer_Type.tp_doc = JFATrainer_doc.doc();
// set the functions
......
......@@ -13,8 +13,8 @@ import numpy.linalg
import bob.core.random
from . import GMMStats, GMMMachine, JFABase, JFAMachine, ISVBase, ISVMachine, \
JFATrainer, ISVTrainer
from . import GMMStats, GMMMachine, JFABase, JFAMachine, ISVBase, ISVMachine, JFATrainer
#, ISVTrainer
def equals(x, y, epsilon):
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment