From e530ba1fb7c53d95024083a4cf5dd7400be26c75 Mon Sep 17 00:00:00 2001 From: Tiago Freitas Pereira <tiagofrepereira@gmail.com> Date: Thu, 15 Jan 2015 09:56:48 +0100 Subject: [PATCH] Started the binding --- bob/learn/misc/kmeans_trainer.cpp | 42 ++++++++++++++++++------------- bob/learn/misc/main.h | 2 +- setup.py | 1 + 3 files changed, 27 insertions(+), 18 deletions(-) diff --git a/bob/learn/misc/kmeans_trainer.cpp b/bob/learn/misc/kmeans_trainer.cpp index d143d16..8ce8bc3 100644 --- a/bob/learn/misc/kmeans_trainer.cpp +++ b/bob/learn/misc/kmeans_trainer.cpp @@ -13,12 +13,13 @@ /************ Constructor Section *********************************/ /******************************************************************/ + static auto KMeansTrainer_doc = bob::extension::ClassDoc( BOB_EXT_MODULE_PREFIX ".KMeansTrainer", - "Trains a KMeans machine.\n" - "This class implements the expectation-maximization algorithm for a k-means machine.\n" - "See Section 9.1 of Bishop, \"Pattern recognition and machine learning\", 2006\n" - "It uses a random initialization of the means followed by the expectation-maximization algorithm", + "Trains a KMeans machine." + "This class implements the expectation-maximization algorithm for a k-means machine." + "See Section 9.1 of Bishop, \"Pattern recognition and machine learning\", 2006" + "It uses a random initialization of the means followed by the expectation-maximization algorithm" ).add_constructor( bob::extension::FunctionDoc( "__init__", @@ -36,33 +37,41 @@ static auto KMeansTrainer_doc = bob::extension::ClassDoc( ); -static int PyBobLearnMiscKMeansTrainer_init_copy(PyBobLearnMiscKMeansMachineObject* self, PyObject* args, PyObject* kwargs) { +static int PyBobLearnMiscKMeansTrainer_init_copy(PyBobLearnMiscKMeansTrainerObject* self, PyObject* args, PyObject* kwargs) { - char** kwlist = KMeansMachine_doc.kwlist(1); - PyBobLearnMiscKMeansMachineObject* tt; - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!", kwlist, &PyBobLearnMiscKMeansMachine_Type, &tt)){ - KMeansMachine_doc.print_usage(); + char** kwlist = KMeansTrainer_doc.kwlist(1); + PyBobLearnMiscKMeansTrainerObject* tt; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!", kwlist, &PyBobLearnMiscKMeansTrainer_Type, &tt)){ + KMeansTrainer_doc.print_usage(); return -1; } - self->cxx.reset(new bob::learn::misc::KMeansMachine(*tt->cxx)); + self->cxx.reset(new bob::learn::misc::KMeansTrainer(*tt->cxx)); return 0; } +static int PyBobLearnMiscKMeansTrainer_init(PyBobLearnMiscKMeansTrainerObject* self, PyObject* args, PyObject* kwargs) { + BOB_TRY + + // get the number of command line arguments + //int nargs = (args?PyTuple_Size(args):0) + (kwargs?PyDict_Size(kwargs):0); + + BOB_CATCH_MEMBER("cannot create KMeansTrainer", 0) + return 0; +} -static void PyBobLearnMiscKMeansMachine_delete(PyBobLearnMiscKMeansMachineObject* self) { +static void PyBobLearnMiscKMeansTrainer_delete(PyBobLearnMiscKMeansTrainerObject* self) { self->cxx.reset(); Py_TYPE(self)->tp_free((PyObject*)self); } -int PyBobLearnMiscKMeansMachine_Check(PyObject* o) { - return PyObject_IsInstance(o, reinterpret_cast<PyObject*>(&PyBobLearnMiscKMeansMachine_Type)); +int PyBobLearnMiscKMeansTrainer_Check(PyObject* o) { + return PyObject_IsInstance(o, reinterpret_cast<PyObject*>(&PyBobLearnMiscKMeansTrainer_Type)); } - /******************************************************************/ /************ Variables Section ***********************************/ /******************************************************************/ @@ -80,7 +89,6 @@ static PyGetSetDef PyBobLearnMiscKMeansTrainer_getseters[] = { - static PyMethodDef PyBobLearnMiscKMeansTrainer_methods[] = { {0} /* Sentinel */ }; @@ -107,7 +115,7 @@ bool init_BobLearnMiscKMeansTrainer(PyObject* module) // set the functions PyBobLearnMiscKMeansTrainer_Type.tp_new = PyType_GenericNew; PyBobLearnMiscKMeansTrainer_Type.tp_init = reinterpret_cast<initproc>(PyBobLearnMiscKMeansTrainer_init); - PyBobLearnMiscKMeansTrainer_Type.tp_dealloc = reinterpret_cast<destructor>(PyBobLearnMiscKTrainer_delete); + PyBobLearnMiscKMeansTrainer_Type.tp_dealloc = reinterpret_cast<destructor>(PyBobLearnMiscKMeansTrainer_delete); //PyBobLearnMiscKMeansTrainer_Type.tp_richcompare = reinterpret_cast<richcmpfunc>(PyBobLearnMiscKMeansTrainer_RichCompare); PyBobLearnMiscKMeansTrainer_Type.tp_methods = PyBobLearnMiscKMeansTrainer_methods; PyBobLearnMiscKMeansTrainer_Type.tp_getset = PyBobLearnMiscKMeansTrainer_getseters; @@ -119,6 +127,6 @@ bool init_BobLearnMiscKMeansTrainer(PyObject* module) // add the type to the module Py_INCREF(&PyBobLearnMiscKMeansTrainer_Type); - return PyModule_AddObject(module, "_KMeansTrainer", (PyObject*)&PyBobLearnMiscKMeansTrainer_Type) >= 0; + return PyModule_AddObject(module, "KMeansTrainer", (PyObject*)&PyBobLearnMiscKMeansTrainer_Type) >= 0; } diff --git a/bob/learn/misc/main.h b/bob/learn/misc/main.h index 08d3919..a6efaa7 100644 --- a/bob/learn/misc/main.h +++ b/bob/learn/misc/main.h @@ -21,7 +21,7 @@ #include <bob.learn.misc/GMMStats.h> #include <bob.learn.misc/GMMMachine.h> #include <bob.learn.misc/KMeansMachine.h> - +#include <bob.learn.misc/KMeansTrainer.h> #if PY_VERSION_HEX >= 0x03000000 #define PyInt_Check PyLong_Check diff --git a/setup.py b/setup.py index 13a34e4..57f2a77 100644 --- a/setup.py +++ b/setup.py @@ -104,6 +104,7 @@ setup( "bob/learn/misc/gmm_stats.cpp", "bob/learn/misc/gmm_machine.cpp", "bob/learn/misc/kmeans_machine.cpp", + "bob/learn/misc/kmeans_trainer.cpp", "bob/learn/misc/main.cpp", ], -- GitLab