From 115b521bf578bfb545929876675b64489206cf43 Mon Sep 17 00:00:00 2001 From: Tiago Freitas Pereira <tiagofrepereira@gmail.com> Date: Wed, 4 Feb 2015 00:40:43 +0100 Subject: [PATCH] Binding EMPCATrainer --- bob/learn/misc/cpp/EMPCATrainer.cpp | 41 +- bob/learn/misc/empca_trainer.cpp | 377 ++++++++++++++++++ .../include/bob.learn.misc/EMPCATrainer.h | 28 +- bob/learn/misc/main.cpp | 3 +- bob/learn/misc/main.h | 14 + setup.py | 5 +- 6 files changed, 434 insertions(+), 34 deletions(-) create mode 100644 bob/learn/misc/empca_trainer.cpp diff --git a/bob/learn/misc/cpp/EMPCATrainer.cpp b/bob/learn/misc/cpp/EMPCATrainer.cpp index 6149393..60dffc7 100644 --- a/bob/learn/misc/cpp/EMPCATrainer.cpp +++ b/bob/learn/misc/cpp/EMPCATrainer.cpp @@ -18,10 +18,9 @@ #include <bob.math/inv.h> #include <bob.math/stats.h> -bob::learn::misc::EMPCATrainer::EMPCATrainer(double convergence_threshold, - size_t max_iterations, bool compute_likelihood): - EMTrainer<bob::learn::linear::Machine, blitz::Array<double,2> >(convergence_threshold, - max_iterations, compute_likelihood), +bob::learn::misc::EMPCATrainer::EMPCATrainer(bool compute_likelihood): + m_compute_likelihood(compute_likelihood), + m_rng(new boost::mt19937()), m_S(0,0), m_z_first_order(0,0), m_z_second_order(0,0,0), m_inW(0,0), m_invM(0,0), m_sigma2(0), m_f_log2pi(0), @@ -33,8 +32,8 @@ bob::learn::misc::EMPCATrainer::EMPCATrainer(double convergence_threshold, } bob::learn::misc::EMPCATrainer::EMPCATrainer(const bob::learn::misc::EMPCATrainer& other): - EMTrainer<bob::learn::linear::Machine, blitz::Array<double,2> >(other.m_convergence_threshold, - other.m_max_iterations, other.m_compute_likelihood), + m_compute_likelihood(other.m_compute_likelihood), + m_rng(other.m_rng), m_S(bob::core::array::ccopy(other.m_S)), m_z_first_order(bob::core::array::ccopy(other.m_z_first_order)), m_z_second_order(bob::core::array::ccopy(other.m_z_second_order)), @@ -62,8 +61,8 @@ bob::learn::misc::EMPCATrainer& bob::learn::misc::EMPCATrainer::operator= { if (this != &other) { - bob::learn::misc::EMTrainer<bob::learn::linear::Machine, - blitz::Array<double,2> >::operator=(other); + m_rng = other.m_rng; + m_compute_likelihood = other.m_compute_likelihood; m_S = bob::core::array::ccopy(other.m_S); m_z_first_order = bob::core::array::ccopy(other.m_z_first_order); m_z_second_order = bob::core::array::ccopy(other.m_z_second_order); @@ -87,8 +86,8 @@ bob::learn::misc::EMPCATrainer& bob::learn::misc::EMPCATrainer::operator= bool bob::learn::misc::EMPCATrainer::operator== (const bob::learn::misc::EMPCATrainer &other) const { - return bob::learn::misc::EMTrainer<bob::learn::linear::Machine, - blitz::Array<double,2> >::operator==(other) && + return m_compute_likelihood == other.m_compute_likelihood && + m_rng == other.m_rng && bob::core::array::isEqual(m_S, other.m_S) && bob::core::array::isEqual(m_z_first_order, other.m_z_first_order) && bob::core::array::isEqual(m_z_second_order, other.m_z_second_order) && @@ -108,15 +107,15 @@ bool bob::learn::misc::EMPCATrainer::is_similar_to (const bob::learn::misc::EMPCATrainer &other, const double r_epsilon, const double a_epsilon) const { - return bob::learn::misc::EMTrainer<bob::learn::linear::Machine, - blitz::Array<double,2> >::is_similar_to(other, r_epsilon, a_epsilon) && - bob::core::array::isClose(m_S, other.m_S, r_epsilon, a_epsilon) && - bob::core::array::isClose(m_z_first_order, other.m_z_first_order, r_epsilon, a_epsilon) && - bob::core::array::isClose(m_z_second_order, other.m_z_second_order, r_epsilon, a_epsilon) && - bob::core::array::isClose(m_inW, other.m_inW, r_epsilon, a_epsilon) && - bob::core::array::isClose(m_invM, other.m_invM, r_epsilon, a_epsilon) && - bob::core::isClose(m_sigma2, other.m_sigma2, r_epsilon, a_epsilon) && - bob::core::isClose(m_f_log2pi, other.m_f_log2pi, r_epsilon, a_epsilon); + return m_compute_likelihood == other.m_compute_likelihood && + m_rng == other.m_rng && + bob::core::array::isClose(m_S, other.m_S, r_epsilon, a_epsilon) && + bob::core::array::isClose(m_z_first_order, other.m_z_first_order, r_epsilon, a_epsilon) && + bob::core::array::isClose(m_z_second_order, other.m_z_second_order, r_epsilon, a_epsilon) && + bob::core::array::isClose(m_inW, other.m_inW, r_epsilon, a_epsilon) && + bob::core::array::isClose(m_invM, other.m_invM, r_epsilon, a_epsilon) && + bob::core::isClose(m_sigma2, other.m_sigma2, r_epsilon, a_epsilon) && + bob::core::isClose(m_f_log2pi, other.m_f_log2pi, r_epsilon, a_epsilon); } void bob::learn::misc::EMPCATrainer::initialize(bob::learn::linear::Machine& machine, @@ -137,10 +136,6 @@ void bob::learn::misc::EMPCATrainer::initialize(bob::learn::linear::Machine& mac computeInvM(); } -void bob::learn::misc::EMPCATrainer::finalize(bob::learn::linear::Machine& machine, - const blitz::Array<double,2>& ar) -{ -} void bob::learn::misc::EMPCATrainer::initMembers( const bob::learn::linear::Machine& machine, diff --git a/bob/learn/misc/empca_trainer.cpp b/bob/learn/misc/empca_trainer.cpp new file mode 100644 index 0000000..6a34239 --- /dev/null +++ b/bob/learn/misc/empca_trainer.cpp @@ -0,0 +1,377 @@ +/** + * @author Tiago de Freitas Pereira <tiago.pereira@idiap.ch> + * @date Tue 03 Fev 11:22:00 2015 + * + * @brief Python API for bob::learn::em + * + * Copyright (C) 2011-2014 Idiap Research Institute, Martigny, Switzerland + */ + +#include "main.h" + +/******************************************************************/ +/************ Constructor Section *********************************/ +/******************************************************************/ + +static auto EMPCATrainer_doc = bob::extension::ClassDoc( + BOB_EXT_MODULE_PREFIX "._EMPCATrainer", + "" + +).add_constructor( + bob::extension::FunctionDoc( + "__init__", + "Creates a EMPCATrainer", + "", + true + ) + .add_prototype("compute_likelihood","") + .add_prototype("other","") + .add_prototype("","") + + .add_parameter("other", ":py:class:`bob.learn.misc.EMPCATrainer`", "A EMPCATrainer object to be copied.") + +); + + +static int PyBobLearnMiscEMPCATrainer_init_copy(PyBobLearnMiscEMPCATrainerObject* self, PyObject* args, PyObject* kwargs) { + + char** kwlist = EMPCATrainer_doc.kwlist(1); + PyBobLearnMiscEMPCATrainerObject* tt; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!", kwlist, &PyBobLearnMiscEMPCATrainer_Type, &tt)){ + EMPCATrainer_doc.print_usage(); + return -1; + } + + self->cxx.reset(new bob::learn::misc::EMPCATrainer(*tt->cxx)); + return 0; +} + +static int PyBobLearnMiscEMPCATrainer_init_number(PyBobLearnMiscEMPCATrainerObject* self, PyObject* args, PyObject* kwargs) { + + char** kwlist = EMPCATrainer_doc.kwlist(0); + double convergence_threshold = 0.0001; + //Parsing the input argments + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "d", kwlist, &convergence_threshold)) + return -1; + + if(convergence_threshold < 0){ + PyErr_Format(PyExc_TypeError, "convergence_threshold argument must be greater than to zero"); + return -1; + } + + self->cxx.reset(new bob::learn::misc::EMPCATrainer(convergence_threshold)); + return 0; +} + +static int PyBobLearnMiscEMPCATrainer_init(PyBobLearnMiscEMPCATrainerObject* self, PyObject* args, PyObject* kwargs) { + BOB_TRY + + int nargs = (args?PyTuple_Size(args):0) + (kwargs?PyDict_Size(kwargs):0); + + switch (nargs) { + + case 0:{ //default initializer () + self->cxx.reset(new bob::learn::misc::EMPCATrainer()); + return 0; + } + case 1:{ + //Reading the input argument + PyObject* arg = 0; + if (PyTuple_Size(args)) + arg = PyTuple_GET_ITEM(args, 0); + else { + PyObject* tmp = PyDict_Values(kwargs); + auto tmp_ = make_safe(tmp); + arg = PyList_GET_ITEM(tmp, 0); + } + + // If the constructor input is EMPCATrainer object + if (PyBobLearnMiscEMPCATrainer_Check(arg)) + return PyBobLearnMiscEMPCATrainer_init_copy(self, args, kwargs); + else if(PyString_Check(arg)) + return PyBobLearnMiscEMPCATrainer_init_number(self, args, kwargs); + } + default:{ + PyErr_Format(PyExc_RuntimeError, "number of arguments mismatch - %s requires 0 or 1 arguments, but you provided %d (see help)", Py_TYPE(self)->tp_name, nargs); + EMPCATrainer_doc.print_usage(); + return -1; + } + } + BOB_CATCH_MEMBER("cannot create EMPCATrainer", 0) + return 0; +} + + +static void PyBobLearnMiscEMPCATrainer_delete(PyBobLearnMiscEMPCATrainerObject* self) { + self->cxx.reset(); + Py_TYPE(self)->tp_free((PyObject*)self); +} + + +int PyBobLearnMiscEMPCATrainer_Check(PyObject* o) { + return PyObject_IsInstance(o, reinterpret_cast<PyObject*>(&PyBobLearnMiscEMPCATrainer_Type)); +} + + +static PyObject* PyBobLearnMiscEMPCATrainer_RichCompare(PyBobLearnMiscEMPCATrainerObject* self, PyObject* other, int op) { + BOB_TRY + + if (!PyBobLearnMiscEMPCATrainer_Check(other)) { + PyErr_Format(PyExc_TypeError, "cannot compare `%s' with `%s'", Py_TYPE(self)->tp_name, Py_TYPE(other)->tp_name); + return 0; + } + auto other_ = reinterpret_cast<PyBobLearnMiscEMPCATrainerObject*>(other); + switch (op) { + case Py_EQ: + if (*self->cxx==*other_->cxx) Py_RETURN_TRUE; else Py_RETURN_FALSE; + case Py_NE: + if (*self->cxx==*other_->cxx) Py_RETURN_FALSE; else Py_RETURN_TRUE; + default: + Py_INCREF(Py_NotImplemented); + return Py_NotImplemented; + } + BOB_CATCH_MEMBER("cannot compare EMPCATrainer objects", 0) +} + + +/******************************************************************/ +/************ Variables Section ***********************************/ +/******************************************************************/ + + +/***** rng *****/ +static auto rng = bob::extension::VariableDoc( + "rng", + "str", + "The Mersenne Twister mt19937 random generator used for the initialization of subspaces/arrays before the EM loop.", + "" +); +PyObject* PyBobLearnMiscEMPCATrainer_getRng(PyBobLearnMiscEMPCATrainerObject* self, void*) { + BOB_TRY + //Allocating the correspondent python object + + PyBoostMt19937Object* retval = + (PyBoostMt19937Object*)PyBoostMt19937_Type.tp_alloc(&PyBoostMt19937_Type, 0); + + retval->rng = self->cxx->getRng().get(); + return Py_BuildValue("O", retval); + BOB_CATCH_MEMBER("Rng method could not be read", 0) +} +int PyBobLearnMiscEMPCATrainer_setRng(PyBobLearnMiscEMPCATrainerObject* self, PyObject* value, void*) { + BOB_TRY + + if (!PyBoostMt19937_Check(value)){ + PyErr_Format(PyExc_RuntimeError, "%s %s expects an PyBoostMt19937_Check", Py_TYPE(self)->tp_name, rng.name()); + return -1; + } + + PyBoostMt19937Object* boostObject = 0; + PyBoostMt19937_Converter(value, &boostObject); + self->cxx->setRng((boost::shared_ptr<boost::mt19937>)boostObject->rng); + + return 0; + BOB_CATCH_MEMBER("Rng could not be set", 0) +} + + + +static PyGetSetDef PyBobLearnMiscEMPCATrainer_getseters[] = { + { + rng.name(), + (getter)PyBobLearnMiscEMPCATrainer_getRng, + (setter)PyBobLearnMiscEMPCATrainer_setRng, + rng.doc(), + 0 + }, + {0} // Sentinel +}; + + +/******************************************************************/ +/************ Functions Section ***********************************/ +/******************************************************************/ + +/*** initialize ***/ +static auto initialize = bob::extension::FunctionDoc( + "initialize", + "", + "", + true +) +.add_prototype("linear_machine,data") +.add_parameter("linear_machine", ":py:class:`bob.learn.linear.Machine`", "LinearMachine Object") +.add_parameter("data", "array_like <float, 2D>", "Input data"); +static PyObject* PyBobLearnMiscEMPCATrainer_initialize(PyBobLearnMiscEMPCATrainerObject* self, PyObject* args, PyObject* kwargs) { + BOB_TRY + + /* Parses input arguments in a single shot */ + char** kwlist = initialize.kwlist(0); + + PyBobLearnLinearMachineObject* linear_machine = 0; + PyBlitzArrayObject* data = 0; + + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!O&", kwlist, &PyBobLearnLinearMachine_Type, &linear_machine, + &PyBlitzArray_Converter, &data)) Py_RETURN_NONE; + auto data_ = make_safe(data); + + self->cxx->initialize(*linear_machine->cxx, *PyBlitzArrayCxx_AsBlitz<double,2>(data)); + + BOB_CATCH_MEMBER("cannot perform the initialize method", 0) + + Py_RETURN_NONE; +} + + +/*** eStep ***/ +static auto eStep = bob::extension::FunctionDoc( + "eStep", + "", + "", + true +) +.add_prototype("linear_machine,data") +.add_parameter("linear_machine", ":py:class:`bob.learn.linear.Machine`", "LinearMachine Object") +.add_parameter("data", "array_like <float, 2D>", "Input data"); +static PyObject* PyBobLearnMiscEMPCATrainer_eStep(PyBobLearnMiscEMPCATrainerObject* self, PyObject* args, PyObject* kwargs) { + BOB_TRY + + /* Parses input arguments in a single shot */ + char** kwlist = eStep.kwlist(0); + + PyBobLearnLinearMachineObject* linear_machine; + PyBlitzArrayObject* data = 0; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!O&", kwlist, &PyBobLearnLinearMachine_Type, &linear_machine, + &PyBlitzArray_Converter, &data)) Py_RETURN_NONE; + auto data_ = make_safe(data); + + self->cxx->eStep(*linear_machine->cxx, *PyBlitzArrayCxx_AsBlitz<double,2>(data)); + + + BOB_CATCH_MEMBER("cannot perform the eStep method", 0) + + Py_RETURN_NONE; +} + + +/*** mStep ***/ +static auto mStep = bob::extension::FunctionDoc( + "mStep", + "", + 0, + true +) +.add_prototype("linear_machine,data") +.add_parameter("linear_machine", ":py:class:`bob.learn.misc.LinearMachine`", "LinearMachine Object") +.add_parameter("data", "array_like <float, 2D>", "Input data"); +static PyObject* PyBobLearnMiscEMPCATrainer_mStep(PyBobLearnMiscEMPCATrainerObject* self, PyObject* args, PyObject* kwargs) { + BOB_TRY + + /* Parses input arguments in a single shot */ + char** kwlist = mStep.kwlist(0); + + PyBobLearnLinearMachineObject* linear_machine; + PyBlitzArrayObject* data = 0; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!O&", kwlist, &PyBobLearnLinearMachine_Type, &linear_machine, + &PyBlitzArray_Converter, &data)) Py_RETURN_NONE; + auto data_ = make_safe(data); + + self->cxx->mStep(*linear_machine->cxx, *PyBlitzArrayCxx_AsBlitz<double,2>(data)); + + + BOB_CATCH_MEMBER("cannot perform the mStep method", 0) + + Py_RETURN_NONE; +} + + +/*** computeLikelihood ***/ +static auto compute_likelihood = bob::extension::FunctionDoc( + "compute_likelihood", + "", + 0, + true +) +.add_prototype("linear_machine,data") +.add_parameter("linear_machine", ":py:class:`bob.learn.misc.LinearMachine`", "LinearMachine Object"); +static PyObject* PyBobLearnMiscEMPCATrainer_compute_likelihood(PyBobLearnMiscEMPCATrainerObject* self, PyObject* args, PyObject* kwargs) { + BOB_TRY + + /* Parses input arguments in a single shot */ + char** kwlist = compute_likelihood.kwlist(0); + + PyBobLearnLinearMachineObject* linear_machine; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!", kwlist, &PyBobLearnLinearMachine_Type, &linear_machine)) Py_RETURN_NONE; + + double value = self->cxx->computeLikelihood(*linear_machine->cxx); + return Py_BuildValue("d", value); + + BOB_CATCH_MEMBER("cannot perform the computeLikelihood method", 0) +} + + + +static PyMethodDef PyBobLearnMiscEMPCATrainer_methods[] = { + { + initialize.name(), + (PyCFunction)PyBobLearnMiscEMPCATrainer_initialize, + METH_VARARGS|METH_KEYWORDS, + initialize.doc() + }, + { + eStep.name(), + (PyCFunction)PyBobLearnMiscEMPCATrainer_eStep, + METH_VARARGS|METH_KEYWORDS, + eStep.doc() + }, + { + mStep.name(), + (PyCFunction)PyBobLearnMiscEMPCATrainer_mStep, + METH_VARARGS|METH_KEYWORDS, + mStep.doc() + }, + { + compute_likelihood.name(), + (PyCFunction)PyBobLearnMiscEMPCATrainer_compute_likelihood, + METH_VARARGS|METH_KEYWORDS, + compute_likelihood.doc() + }, + {0} /* Sentinel */ +}; + + +/******************************************************************/ +/************ Module Section **************************************/ +/******************************************************************/ + +// Define the Gaussian type struct; will be initialized later +PyTypeObject PyBobLearnMiscEMPCATrainer_Type = { + PyVarObject_HEAD_INIT(0,0) + 0 +}; + +bool init_BobLearnMiscEMPCATrainer(PyObject* module) +{ + // initialize the type struct + PyBobLearnMiscEMPCATrainer_Type.tp_name = EMPCATrainer_doc.name(); + PyBobLearnMiscEMPCATrainer_Type.tp_basicsize = sizeof(PyBobLearnMiscEMPCATrainerObject); + PyBobLearnMiscEMPCATrainer_Type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;//Enable the class inheritance + PyBobLearnMiscEMPCATrainer_Type.tp_doc = EMPCATrainer_doc.doc(); + + // set the functions + PyBobLearnMiscEMPCATrainer_Type.tp_new = PyType_GenericNew; + PyBobLearnMiscEMPCATrainer_Type.tp_init = reinterpret_cast<initproc>(PyBobLearnMiscEMPCATrainer_init); + PyBobLearnMiscEMPCATrainer_Type.tp_dealloc = reinterpret_cast<destructor>(PyBobLearnMiscEMPCATrainer_delete); + PyBobLearnMiscEMPCATrainer_Type.tp_richcompare = reinterpret_cast<richcmpfunc>(PyBobLearnMiscEMPCATrainer_RichCompare); + PyBobLearnMiscEMPCATrainer_Type.tp_methods = PyBobLearnMiscEMPCATrainer_methods; + PyBobLearnMiscEMPCATrainer_Type.tp_getset = PyBobLearnMiscEMPCATrainer_getseters; + PyBobLearnMiscEMPCATrainer_Type.tp_call = reinterpret_cast<ternaryfunc>(PyBobLearnMiscEMPCATrainer_compute_likelihood); + + + // check that everything is fine + if (PyType_Ready(&PyBobLearnMiscEMPCATrainer_Type) < 0) return false; + + // add the type to the module + Py_INCREF(&PyBobLearnMiscEMPCATrainer_Type); + return PyModule_AddObject(module, "_EMPCATrainer", (PyObject*)&PyBobLearnMiscEMPCATrainer_Type) >= 0; +} + diff --git a/bob/learn/misc/include/bob.learn.misc/EMPCATrainer.h b/bob/learn/misc/include/bob.learn.misc/EMPCATrainer.h index 119968e..753c78c 100644 --- a/bob/learn/misc/include/bob.learn.misc/EMPCATrainer.h +++ b/bob/learn/misc/include/bob.learn.misc/EMPCATrainer.h @@ -11,7 +11,6 @@ #ifndef BOB_LEARN_MISC_EMPCA_TRAINER_H #define BOB_LEARN_MISC_EMPCA_TRAINER_H -#include <bob.learn.misc/EMTrainer.h> #include <bob.learn.linear/machine.h> #include <blitz/array.h> @@ -38,7 +37,7 @@ namespace bob { namespace learn { namespace misc { * - \f$\epsilon\f$ is the noise of the data (dimension \f$f\f$) * Gaussian with zero-mean and covariance matrix \f$\sigma^2 Id\f$ */ -class EMPCATrainer: public EMTrainer<bob::learn::linear::Machine, blitz::Array<double,2> > +class EMPCATrainer { public: //api /** @@ -46,8 +45,7 @@ class EMPCATrainer: public EMTrainer<bob::learn::linear::Machine, blitz::Array<d * resulting components in the linear machine and set it up to * extract the variable means automatically. */ - EMPCATrainer(double convergence_threshold=0.001, - size_t max_iterations=10, bool compute_likelihood=true); + EMPCATrainer(bool compute_likelihood=true); /** * @brief Copy constructor @@ -85,11 +83,6 @@ class EMPCATrainer: public EMTrainer<bob::learn::linear::Machine, blitz::Array<d */ virtual void initialize(bob::learn::linear::Machine& machine, const blitz::Array<double,2>& ar); - /** - * @brief This methods performs some actions after the EM loop. - */ - virtual void finalize(bob::learn::linear::Machine& machine, - const blitz::Array<double,2>& ar); /** * @brief Calculates and saves statistics across the dataset, and saves @@ -123,7 +116,24 @@ class EMPCATrainer: public EMTrainer<bob::learn::linear::Machine, blitz::Array<d */ double getSigma2() const { return m_sigma2; } + /** + * @brief Sets the Random Number Generator + */ + void setRng(const boost::shared_ptr<boost::mt19937> rng) + { m_rng = rng; } + + /** + * @brief Gets the Random Number Generator + */ + const boost::shared_ptr<boost::mt19937> getRng() const + { return m_rng; } + + private: //representation + + bool m_compute_likelihood; + boost::shared_ptr<boost::mt19937> m_rng; + blitz::Array<double,2> m_S; /// Covariance of the training data (required only if we need to compute the log likelihood) blitz::Array<double,2> m_z_first_order; /// Current mean of the \f$z_{n}\f$ latent variable blitz::Array<double,3> m_z_second_order; /// Current covariance of the \f$z_{n}\f$ latent variable diff --git a/bob/learn/misc/main.cpp b/bob/learn/misc/main.cpp index c4875f8..a56592c 100644 --- a/bob/learn/misc/main.cpp +++ b/bob/learn/misc/main.cpp @@ -81,8 +81,9 @@ static PyObject* create_module (void) { if (!init_BobLearnMiscIVectorTrainer(module)) return 0; if (!init_BobLearnMiscPLDABase(module)) return 0; - if (!init_BobLearnMiscPLDAMachine(module)) return 0; + if (!init_BobLearnMiscPLDAMachine(module)) return 0; + if (!init_BobLearnMiscEMPCATrainer(module)) return 0; static void* PyBobLearnMisc_API[PyBobLearnMisc_API_pointers]; diff --git a/bob/learn/misc/main.h b/bob/learn/misc/main.h index 830cd3e..c3f9d3d 100644 --- a/bob/learn/misc/main.h +++ b/bob/learn/misc/main.h @@ -12,6 +12,7 @@ #include <bob.blitz/cleanup.h> #include <bob.core/random_api.h> #include <bob.io.base/api.h> +#include <bob.learn.linear/api.h> #include <bob.extension/documentation.h> #define BOB_LEARN_EM_MODULE @@ -39,6 +40,8 @@ #include <bob.learn.misc/IVectorMachine.h> #include <bob.learn.misc/IVectorTrainer.h> +#include <bob.learn.misc/EMPCATrainer.h> + #include <bob.learn.misc/PLDAMachine.h> #include <bob.learn.misc/ZTNorm.h> @@ -279,5 +282,16 @@ bool init_BobLearnMiscPLDAMachine(PyObject* module); int PyBobLearnMiscPLDAMachine_Check(PyObject* o); +// EMPCATrainer +typedef struct { + PyObject_HEAD + boost::shared_ptr<bob::learn::misc::EMPCATrainer> cxx; +} PyBobLearnMiscEMPCATrainerObject; + +extern PyTypeObject PyBobLearnMiscEMPCATrainer_Type; +bool init_BobLearnMiscEMPCATrainer(PyObject* module); +int PyBobLearnMiscEMPCATrainer_Check(PyObject* o); + + #endif // BOB_LEARN_EM_MAIN_H diff --git a/setup.py b/setup.py index 2090450..e084e8a 100644 --- a/setup.py +++ b/setup.py @@ -73,7 +73,7 @@ setup( "bob/learn/misc/cpp/JFATrainer.cpp", "bob/learn/misc/cpp/ISVTrainer.cpp", - #"bob/learn/misc/cpp/EMPCATrainer.cpp", + "bob/learn/misc/cpp/EMPCATrainer.cpp", "bob/learn/misc/cpp/GMMBaseTrainer.cpp", "bob/learn/misc/cpp/IVectorTrainer.cpp", "bob/learn/misc/cpp/KMeansTrainer.cpp", @@ -130,6 +130,9 @@ setup( "bob/learn/misc/plda_base.cpp", "bob/learn/misc/plda_machine.cpp", + + "bob/learn/misc/empca_trainer.cpp", + "bob/learn/misc/ztnorm.cpp", "bob/learn/misc/main.cpp", -- GitLab