Skip to content
Snippets Groups Projects
Commit e530ba1f authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Started the binding

parent c9f8ab18
No related branches found
No related tags found
No related merge requests found
...@@ -13,12 +13,13 @@ ...@@ -13,12 +13,13 @@
/************ Constructor Section *********************************/ /************ Constructor Section *********************************/
/******************************************************************/ /******************************************************************/
static auto KMeansTrainer_doc = bob::extension::ClassDoc( static auto KMeansTrainer_doc = bob::extension::ClassDoc(
BOB_EXT_MODULE_PREFIX ".KMeansTrainer", BOB_EXT_MODULE_PREFIX ".KMeansTrainer",
"Trains a KMeans machine.\n" "Trains a KMeans machine."
"This class implements the expectation-maximization algorithm for a k-means machine.\n" "This class implements the expectation-maximization algorithm for a k-means machine."
"See Section 9.1 of Bishop, \"Pattern recognition and machine learning\", 2006\n" "See Section 9.1 of Bishop, \"Pattern recognition and machine learning\", 2006"
"It uses a random initialization of the means followed by the expectation-maximization algorithm", "It uses a random initialization of the means followed by the expectation-maximization algorithm"
).add_constructor( ).add_constructor(
bob::extension::FunctionDoc( bob::extension::FunctionDoc(
"__init__", "__init__",
...@@ -36,33 +37,41 @@ static auto KMeansTrainer_doc = bob::extension::ClassDoc( ...@@ -36,33 +37,41 @@ static auto KMeansTrainer_doc = bob::extension::ClassDoc(
); );
static int PyBobLearnMiscKMeansTrainer_init_copy(PyBobLearnMiscKMeansMachineObject* self, PyObject* args, PyObject* kwargs) { static int PyBobLearnMiscKMeansTrainer_init_copy(PyBobLearnMiscKMeansTrainerObject* self, PyObject* args, PyObject* kwargs) {
char** kwlist = KMeansMachine_doc.kwlist(1); char** kwlist = KMeansTrainer_doc.kwlist(1);
PyBobLearnMiscKMeansMachineObject* tt; PyBobLearnMiscKMeansTrainerObject* tt;
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!", kwlist, &PyBobLearnMiscKMeansMachine_Type, &tt)){ if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!", kwlist, &PyBobLearnMiscKMeansTrainer_Type, &tt)){
KMeansMachine_doc.print_usage(); KMeansTrainer_doc.print_usage();
return -1; return -1;
} }
self->cxx.reset(new bob::learn::misc::KMeansMachine(*tt->cxx)); self->cxx.reset(new bob::learn::misc::KMeansTrainer(*tt->cxx));
return 0; return 0;
} }
static int PyBobLearnMiscKMeansTrainer_init(PyBobLearnMiscKMeansTrainerObject* self, PyObject* args, PyObject* kwargs) {
BOB_TRY
// get the number of command line arguments
//int nargs = (args?PyTuple_Size(args):0) + (kwargs?PyDict_Size(kwargs):0);
BOB_CATCH_MEMBER("cannot create KMeansTrainer", 0)
return 0;
}
static void PyBobLearnMiscKMeansMachine_delete(PyBobLearnMiscKMeansMachineObject* self) { static void PyBobLearnMiscKMeansTrainer_delete(PyBobLearnMiscKMeansTrainerObject* self) {
self->cxx.reset(); self->cxx.reset();
Py_TYPE(self)->tp_free((PyObject*)self); Py_TYPE(self)->tp_free((PyObject*)self);
} }
int PyBobLearnMiscKMeansMachine_Check(PyObject* o) { int PyBobLearnMiscKMeansTrainer_Check(PyObject* o) {
return PyObject_IsInstance(o, reinterpret_cast<PyObject*>(&PyBobLearnMiscKMeansMachine_Type)); return PyObject_IsInstance(o, reinterpret_cast<PyObject*>(&PyBobLearnMiscKMeansTrainer_Type));
} }
/******************************************************************/ /******************************************************************/
/************ Variables Section ***********************************/ /************ Variables Section ***********************************/
/******************************************************************/ /******************************************************************/
...@@ -80,7 +89,6 @@ static PyGetSetDef PyBobLearnMiscKMeansTrainer_getseters[] = { ...@@ -80,7 +89,6 @@ static PyGetSetDef PyBobLearnMiscKMeansTrainer_getseters[] = {
static PyMethodDef PyBobLearnMiscKMeansTrainer_methods[] = { static PyMethodDef PyBobLearnMiscKMeansTrainer_methods[] = {
{0} /* Sentinel */ {0} /* Sentinel */
}; };
...@@ -107,7 +115,7 @@ bool init_BobLearnMiscKMeansTrainer(PyObject* module) ...@@ -107,7 +115,7 @@ bool init_BobLearnMiscKMeansTrainer(PyObject* module)
// set the functions // set the functions
PyBobLearnMiscKMeansTrainer_Type.tp_new = PyType_GenericNew; PyBobLearnMiscKMeansTrainer_Type.tp_new = PyType_GenericNew;
PyBobLearnMiscKMeansTrainer_Type.tp_init = reinterpret_cast<initproc>(PyBobLearnMiscKMeansTrainer_init); PyBobLearnMiscKMeansTrainer_Type.tp_init = reinterpret_cast<initproc>(PyBobLearnMiscKMeansTrainer_init);
PyBobLearnMiscKMeansTrainer_Type.tp_dealloc = reinterpret_cast<destructor>(PyBobLearnMiscKTrainer_delete); PyBobLearnMiscKMeansTrainer_Type.tp_dealloc = reinterpret_cast<destructor>(PyBobLearnMiscKMeansTrainer_delete);
//PyBobLearnMiscKMeansTrainer_Type.tp_richcompare = reinterpret_cast<richcmpfunc>(PyBobLearnMiscKMeansTrainer_RichCompare); //PyBobLearnMiscKMeansTrainer_Type.tp_richcompare = reinterpret_cast<richcmpfunc>(PyBobLearnMiscKMeansTrainer_RichCompare);
PyBobLearnMiscKMeansTrainer_Type.tp_methods = PyBobLearnMiscKMeansTrainer_methods; PyBobLearnMiscKMeansTrainer_Type.tp_methods = PyBobLearnMiscKMeansTrainer_methods;
PyBobLearnMiscKMeansTrainer_Type.tp_getset = PyBobLearnMiscKMeansTrainer_getseters; PyBobLearnMiscKMeansTrainer_Type.tp_getset = PyBobLearnMiscKMeansTrainer_getseters;
...@@ -119,6 +127,6 @@ bool init_BobLearnMiscKMeansTrainer(PyObject* module) ...@@ -119,6 +127,6 @@ bool init_BobLearnMiscKMeansTrainer(PyObject* module)
// add the type to the module // add the type to the module
Py_INCREF(&PyBobLearnMiscKMeansTrainer_Type); Py_INCREF(&PyBobLearnMiscKMeansTrainer_Type);
return PyModule_AddObject(module, "_KMeansTrainer", (PyObject*)&PyBobLearnMiscKMeansTrainer_Type) >= 0; return PyModule_AddObject(module, "KMeansTrainer", (PyObject*)&PyBobLearnMiscKMeansTrainer_Type) >= 0;
} }
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
#include <bob.learn.misc/GMMStats.h> #include <bob.learn.misc/GMMStats.h>
#include <bob.learn.misc/GMMMachine.h> #include <bob.learn.misc/GMMMachine.h>
#include <bob.learn.misc/KMeansMachine.h> #include <bob.learn.misc/KMeansMachine.h>
#include <bob.learn.misc/KMeansTrainer.h>
#if PY_VERSION_HEX >= 0x03000000 #if PY_VERSION_HEX >= 0x03000000
#define PyInt_Check PyLong_Check #define PyInt_Check PyLong_Check
......
...@@ -104,6 +104,7 @@ setup( ...@@ -104,6 +104,7 @@ setup(
"bob/learn/misc/gmm_stats.cpp", "bob/learn/misc/gmm_stats.cpp",
"bob/learn/misc/gmm_machine.cpp", "bob/learn/misc/gmm_machine.cpp",
"bob/learn/misc/kmeans_machine.cpp", "bob/learn/misc/kmeans_machine.cpp",
"bob/learn/misc/kmeans_trainer.cpp",
"bob/learn/misc/main.cpp", "bob/learn/misc/main.cpp",
], ],
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment