diff --git a/bob/learn/misc/main.cpp b/bob/learn/misc/main.cpp index 388a218c99b16c28b92ac46f90ac42607f3baec1..f79e852b684f86dab5b3ce10808c969270863eff 100644 --- a/bob/learn/misc/main.cpp +++ b/bob/learn/misc/main.cpp @@ -41,6 +41,7 @@ static PyObject* create_module (void) { if (PyModule_AddStringConstant(module, "__version__", BOB_EXT_MODULE_VERSION) < 0) return 0; if (!init_BobLearnMiscGaussian(module)) return 0; + if (!init_BobLearnMiscGMMStats(module)) return 0; static void* PyBobLearnMisc_API[PyBobLearnMisc_API_pointers]; diff --git a/bob/learn/misc/main.h b/bob/learn/misc/main.h index bb40bec0d83308713311ab3ae30e47f4b6a2dffa..1759cbdd865fdaee64a8d9232ea41058fcbceea8 100644 --- a/bob/learn/misc/main.h +++ b/bob/learn/misc/main.h @@ -18,6 +18,7 @@ #include <bob.learn.misc/api.h> #include <bob.learn.misc/Gaussian.h> +#include <bob.learn.misc/GMMStats.h> #if PY_VERSION_HEX >= 0x03000000 @@ -69,4 +70,16 @@ bool init_BobLearnMiscGaussian(PyObject* module); int PyBobLearnMiscGaussian_Check(PyObject* o); +// GMMStats +typedef struct { + PyObject_HEAD + boost::shared_ptr<bob::learn::misc::GMMStats> cxx; +} PyBobLearnMiscGMMStatsObject; + +extern PyTypeObject PyBobLearnMiscGMMStats_Type; +bool init_BobLearnMiscGMMStats(PyObject* module); +int PyBobLearnMiscGMMStats_Check(PyObject* o); + + + #endif // BOB_LEARN_EM_MAIN_H diff --git a/bob/learn/misc/test_gmm.py b/bob/learn/misc/test_gmm.py index cd981742dc6ba95a6e4057961438de1dc3389bc9..bb3855d781397a12c591952f5f9b6c7971766373 100644 --- a/bob/learn/misc/test_gmm.py +++ b/bob/learn/misc/test_gmm.py @@ -15,11 +15,11 @@ import tempfile import bob.io.base from bob.io.base.test_utils import datafile -from . import GMMStats, GMMMachine +from . import GMMStats +#, GMMMachine def test_GMMStats(): # Test a GMMStats - # Initializes a GMMStats gs = GMMStats(2,3) log_likelihood = -3. diff --git a/setup.py b/setup.py index feabf7210c6b71c12787548dbe7cc3737073be06..c59f475ab85a74e2cdd84db646290b0a8a0dcbfc 100644 --- a/setup.py +++ b/setup.py @@ -56,7 +56,7 @@ setup( [ "bob/learn/misc/cpp/Gaussian.cpp", #"bob/learn/misc/cpp/GMMMachine.cpp", - #"bob/learn/misc/cpp/GMMStats.cpp", + "bob/learn/misc/cpp/GMMStats.cpp", #"bob/learn/misc/cpp/IVectorMachine.cpp", #"bob/learn/misc/cpp/JFAMachine.cpp", #"bob/learn/misc/cpp/KMeansMachine.cpp", @@ -101,6 +101,7 @@ setup( Extension("bob.learn.misc._library", [ "bob/learn/misc/gaussian.cpp", + "bob/learn/misc/gmm_stats.cpp", "bob/learn/misc/main.cpp", ],