From 90f09f63432b1b9c1d48bd79a76a1a901cd41c36 Mon Sep 17 00:00:00 2001
From: Tiago Freitas Pereira <tiagofrepereira@gmail.com>
Date: Thu, 22 Jan 2015 12:54:03 +0100
Subject: [PATCH] Binding GMMBaseTrainer class

---
 .../{GMMTrainer.cpp => GMMBaseTrainer.cpp}    |  61 ++-
 bob/learn/misc/gmm_base_trainer.cpp           | 356 ++++++++++++++++++
 .../{GMMTrainer.h => GMMBaseTrainer.h}        |  37 +-
 bob/learn/misc/kmeans_trainer.cpp             |   2 +-
 bob/learn/misc/main.cpp                       |   5 +-
 bob/learn/misc/main.h                         |  14 +
 setup.py                                      |   3 +-
 7 files changed, 413 insertions(+), 65 deletions(-)
 rename bob/learn/misc/cpp/{GMMTrainer.cpp => GMMBaseTrainer.cpp} (53%)
 create mode 100644 bob/learn/misc/gmm_base_trainer.cpp
 rename bob/learn/misc/include/bob.learn.misc/{GMMTrainer.h => GMMBaseTrainer.h} (76%)

diff --git a/bob/learn/misc/cpp/GMMTrainer.cpp b/bob/learn/misc/cpp/GMMBaseTrainer.cpp
similarity index 53%
rename from bob/learn/misc/cpp/GMMTrainer.cpp
rename to bob/learn/misc/cpp/GMMBaseTrainer.cpp
index e1eb8db..8210f2b 100644
--- a/bob/learn/misc/cpp/GMMTrainer.cpp
+++ b/bob/learn/misc/cpp/GMMBaseTrainer.cpp
@@ -5,39 +5,33 @@
  * Copyright (C) Idiap Research Institute, Martigny, Switzerland
  */
 
-#include <bob.learn.misc/GMMTrainer.h>
+#include <bob.learn.misc/GMMBaseTrainer.h>
 #include <bob.core/assert.h>
 #include <bob.core/check.h>
 
-bob::learn::misc::GMMTrainer::GMMTrainer(const bool update_means,
+bob::learn::misc::GMMBaseTrainer::GMMBaseTrainer(const bool update_means,
     const bool update_variances, const bool update_weights,
     const double mean_var_update_responsibilities_threshold):
-  bob::learn::misc::EMTrainer<bob::learn::misc::GMMMachine, blitz::Array<double,2> >(),
   m_update_means(update_means), m_update_variances(update_variances),
   m_update_weights(update_weights),
   m_mean_var_update_responsibilities_threshold(mean_var_update_responsibilities_threshold)
-{
-}
+{}
 
-bob::learn::misc::GMMTrainer::GMMTrainer(const bob::learn::misc::GMMTrainer& b):
-  bob::learn::misc::EMTrainer<bob::learn::misc::GMMMachine, blitz::Array<double,2> >(b),
+bob::learn::misc::GMMBaseTrainer::GMMBaseTrainer(const bob::learn::misc::GMMBaseTrainer& b):
-  m_update_means(b.m_update_means), m_update_variances(b.m_update_variances),
+  m_ss(b.m_ss), m_update_means(b.m_update_means), m_update_variances(b.m_update_variances), m_update_weights(b.m_update_weights),
   m_mean_var_update_responsibilities_threshold(b.m_mean_var_update_responsibilities_threshold)
-{
-}
+{}
 
-bob::learn::misc::GMMTrainer::~GMMTrainer()
-{
-}
+bob::learn::misc::GMMBaseTrainer::~GMMBaseTrainer()
+{}
 
-void bob::learn::misc::GMMTrainer::initialize(bob::learn::misc::GMMMachine& gmm,
-  const blitz::Array<double,2>& data)
+void bob::learn::misc::GMMBaseTrainer::initialize(bob::learn::misc::GMMMachine& gmm)
 {
   // Allocate memory for the sufficient statistics and initialise
   m_ss.resize(gmm.getNGaussians(),gmm.getNInputs());
 }
 
-void bob::learn::misc::GMMTrainer::eStep(bob::learn::misc::GMMMachine& gmm,
+void bob::learn::misc::GMMBaseTrainer::eStep(bob::learn::misc::GMMMachine& gmm,
   const blitz::Array<double,2>& data)
 {
   m_ss.init();
@@ -45,23 +39,17 @@ void bob::learn::misc::GMMTrainer::eStep(bob::learn::misc::GMMMachine& gmm,
   gmm.accStatistics(data, m_ss);
 }
 
-double bob::learn::misc::GMMTrainer::computeLikelihood(bob::learn::misc::GMMMachine& gmm)
+double bob::learn::misc::GMMBaseTrainer::computeLikelihood(bob::learn::misc::GMMMachine& gmm)
 {
   return m_ss.log_likelihood / m_ss.T;
 }
 
-void bob::learn::misc::GMMTrainer::finalize(bob::learn::misc::GMMMachine& gmm,
-  const blitz::Array<double,2>& data)
-{
-}
 
-bob::learn::misc::GMMTrainer& bob::learn::misc::GMMTrainer::operator=
-  (const bob::learn::misc::GMMTrainer &other)
+bob::learn::misc::GMMBaseTrainer& bob::learn::misc::GMMBaseTrainer::operator=
+  (const bob::learn::misc::GMMBaseTrainer &other)
 {
   if (this != &other)
   {
-    bob::learn::misc::EMTrainer<bob::learn::misc::GMMMachine,
-      blitz::Array<double,2> >::operator=(other);
     m_ss = other.m_ss;
     m_update_means = other.m_update_means;
     m_update_variances = other.m_update_variances;
@@ -71,32 +59,27 @@ bob::learn::misc::GMMTrainer& bob::learn::misc::GMMTrainer::operator=
   return *this;
 }
 
-bool bob::learn::misc::GMMTrainer::operator==
-  (const bob::learn::misc::GMMTrainer &other) const
+bool bob::learn::misc::GMMBaseTrainer::operator==
+  (const bob::learn::misc::GMMBaseTrainer &other) const
 {
-  return bob::learn::misc::EMTrainer<bob::learn::misc::GMMMachine,
-           blitz::Array<double,2> >::operator==(other) &&
-         m_ss == other.m_ss &&
+  return m_ss == other.m_ss &&
          m_update_means == other.m_update_means &&
          m_update_variances == other.m_update_variances &&
          m_update_weights == other.m_update_weights &&
          m_mean_var_update_responsibilities_threshold == other.m_mean_var_update_responsibilities_threshold;
 }
 
-bool bob::learn::misc::GMMTrainer::operator!=
-  (const bob::learn::misc::GMMTrainer &other) const
+bool bob::learn::misc::GMMBaseTrainer::operator!=
+  (const bob::learn::misc::GMMBaseTrainer &other) const
 {
   return !(this->operator==(other));
 }
 
-bool bob::learn::misc::GMMTrainer::is_similar_to
-  (const bob::learn::misc::GMMTrainer &other, const double r_epsilon,
+bool bob::learn::misc::GMMBaseTrainer::is_similar_to
+  (const bob::learn::misc::GMMBaseTrainer &other, const double r_epsilon,
    const double a_epsilon) const
 {
-  return bob::learn::misc::EMTrainer<bob::learn::misc::GMMMachine,
-           blitz::Array<double,2> >::operator==(other) &&
-         // TODO: use is similar to method for the accumulator m_ss
-         m_ss == other.m_ss &&
+  return m_ss == other.m_ss &&
          m_update_means == other.m_update_means &&
          m_update_variances == other.m_update_variances &&
          m_update_weights == other.m_update_weights &&
@@ -104,7 +87,7 @@ bool bob::learn::misc::GMMTrainer::is_similar_to
           other.m_mean_var_update_responsibilities_threshold, r_epsilon, a_epsilon);
 }
 
-void bob::learn::misc::GMMTrainer::setGMMStats(const bob::learn::misc::GMMStats& stats)
+void bob::learn::misc::GMMBaseTrainer::setGMMStats(const bob::learn::misc::GMMStats& stats)
 {
   bob::core::array::assertSameShape(m_ss.sumPx, stats.sumPx);
   m_ss = stats;
diff --git a/bob/learn/misc/gmm_base_trainer.cpp b/bob/learn/misc/gmm_base_trainer.cpp
new file mode 100644
index 0000000..5f133af
--- /dev/null
+++ b/bob/learn/misc/gmm_base_trainer.cpp
@@ -0,0 +1,356 @@
+/**
+ * @author Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
+ * @date Wed 21 Jan 12:30:00 2015
+ *
+ * @brief Python API for bob::learn::em
+ *
+ * Copyright (C) 2011-2014 Idiap Research Institute, Martigny, Switzerland
+ */
+
+#include "main.h"
+
+/******************************************************************/
+/************ Constructor Section *********************************/
+/******************************************************************/
+
+static inline bool f(PyObject* o){return o != 0 && PyObject_IsTrue(o) > 0;}  /* converts PyObject to bool and returns false if object is NULL */
+
+static auto GMMBaseTrainer_doc = bob::extension::ClassDoc(
+  BOB_EXT_MODULE_PREFIX ".GMMBaseTrainer",
+  "This class implements the E-step of the expectation-maximisation "
+  "algorithm for a :py:class:`bob.learn.misc.GMMMachine`"
+).add_constructor(
+  bob::extension::FunctionDoc(
+    "__init__",
+    "Creates a GMMBaseTrainer",
+    "",
+    true
+  )
+  .add_prototype("update_means, [update_variances], [update_weights], [mean_var_update_responsibilities_threshold]","")
+  .add_prototype("other","")
+  .add_prototype("","")
+
+  .add_parameter("update_means", "bool", "Update means on each iteration")
+  .add_parameter("update_variances", "bool", "Update variances on each iteration")
+  .add_parameter("update_weights", "bool", "Update weights on each iteration")
+  .add_parameter("mean_var_update_responsibilities_threshold", "float", "Threshold over the responsibilities of the Gaussians. Equations 9.24, 9.25 of Bishop, `Pattern recognition and machine learning`, 2006 require a division by the responsibilities, which might be equal to zero because of numerical issue. This threshold is used to avoid such divisions.")
+  .add_parameter("other", ":py:class:`bob.learn.misc.GMMBaseTrainer`", "A GMMBaseTrainer object to be copied.")
+);
+
+
+
+/* Copy path of __init__: wraps a C++ copy of the trainer passed as `other`; returns 0 on success, -1 on a parsing error. */
+static int PyBobLearnMiscGMMBaseTrainer_init_copy(PyBobLearnMiscGMMBaseTrainerObject* self, PyObject* args, PyObject* kwargs) {
+
+  char** kwlist = GMMBaseTrainer_doc.kwlist(1);
+  PyBobLearnMiscGMMBaseTrainerObject* tt;
+  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!", kwlist, &PyBobLearnMiscGMMBaseTrainer_Type, &tt)){
+    GMMBaseTrainer_doc.print_usage();
+    return -1;
+  }
+
+  self->cxx.reset(new bob::learn::misc::GMMBaseTrainer(*tt->cxx));
+  return 0;
+}
+
+
+/* Flag path of __init__: builds the trainer from the update flags and the threshold; flags that are not given count as false. */
+static int PyBobLearnMiscGMMBaseTrainer_init_bool(PyBobLearnMiscGMMBaseTrainerObject* self, PyObject* args, PyObject* kwargs) {
+
+  char** kwlist = GMMBaseTrainer_doc.kwlist(0);
+  PyObject* update_means     = 0;
+  PyObject* update_variances = 0;
+  PyObject* update_weights   = 0;
+  double mean_var_update_responsibilities_threshold = std::numeric_limits<double>::epsilon(); // same default as the C++ constructor
+
+  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!|O!O!d", kwlist, &PyBool_Type, &update_means, &PyBool_Type,
+                                   &update_variances, &PyBool_Type, &update_weights, &mean_var_update_responsibilities_threshold)){
+    GMMBaseTrainer_doc.print_usage();
+    return -1;
+  }
+  self->cxx.reset(new bob::learn::misc::GMMBaseTrainer(f(update_means), f(update_variances), f(update_weights), mean_var_update_responsibilities_threshold));
+  return 0;
+}
+
+
+/* tp_init: picks default, copy or flag-based construction from the number and type of the arguments. */
+static int PyBobLearnMiscGMMBaseTrainer_init(PyBobLearnMiscGMMBaseTrainerObject* self, PyObject* args, PyObject* kwargs) {
+  BOB_TRY
+
+  int nargs = (args?PyTuple_Size(args):0) + (kwargs?PyDict_Size(kwargs):0);
+
+  if (nargs==0){ //default initializer ()
+    self->cxx.reset(new bob::learn::misc::GMMBaseTrainer());
+    return 0;
+  }
+  else{
+    //Reading the input argument
+    PyObject* arg = 0;
+    if (PyTuple_Size(args))
+      arg = PyTuple_GET_ITEM(args, 0);
+    else {
+      PyObject* tmp = PyDict_Values(kwargs);
+      auto tmp_ = make_safe(tmp);
+      arg = PyList_GET_ITEM(tmp, 0);
+    }
+
+    // If the constructor input is GMMBaseTrainer object
+    if (PyBobLearnMiscGMMBaseTrainer_Check(arg))
+      return PyBobLearnMiscGMMBaseTrainer_init_copy(self, args, kwargs);
+    else
+      return PyBobLearnMiscGMMBaseTrainer_init_bool(self, args, kwargs);
+  }
+
+  BOB_CATCH_MEMBER("cannot create GMMBaseTrainer_init_bool", -1)
+  return 0;
+}
+
+
+/* tp_dealloc: releases the C++ trainer, then frees the Python object. */
+static void PyBobLearnMiscGMMBaseTrainer_delete(PyBobLearnMiscGMMBaseTrainerObject* self) {
+  self->cxx.reset();
+  Py_TYPE(self)->tp_free((PyObject*)self);
+}
+
+
+/* Returns the result of PyObject_IsInstance(o, GMMBaseTrainer type). */
+int PyBobLearnMiscGMMBaseTrainer_Check(PyObject* o) {
+  return PyObject_IsInstance(o, reinterpret_cast<PyObject*>(&PyBobLearnMiscGMMBaseTrainer_Type));
+}
+
+
+/* tp_richcompare: only == and != are implemented, both through the C++ operator==. */
+static PyObject* PyBobLearnMiscGMMBaseTrainer_RichCompare(PyBobLearnMiscGMMBaseTrainerObject* self, PyObject* other, int op) {
+  BOB_TRY
+
+  if (!PyBobLearnMiscGMMBaseTrainer_Check(other)) {
+    PyErr_Format(PyExc_TypeError, "cannot compare `%s' with `%s'", Py_TYPE(self)->tp_name, Py_TYPE(other)->tp_name);
+    return 0;
+  }
+  auto other_ = reinterpret_cast<PyBobLearnMiscGMMBaseTrainerObject*>(other);
+  switch (op) {
+    case Py_EQ:
+      if (*self->cxx==*other_->cxx) Py_RETURN_TRUE; else Py_RETURN_FALSE;
+    case Py_NE:
+      if (*self->cxx==*other_->cxx) Py_RETURN_FALSE; else Py_RETURN_TRUE;
+    default:
+      Py_INCREF(Py_NotImplemented);
+      return Py_NotImplemented;
+  }
+  BOB_CATCH_MEMBER("cannot compare GMMBaseTrainer objects", 0)
+}
+
+
+/******************************************************************/
+/************ Variables Section ***********************************/
+/******************************************************************/
+
+
+/***** gmm_stats *****/
+static auto gmm_stats = bob::extension::VariableDoc(
+  "gmm_stats",
+  ":py:class:`bob.learn.misc.GMMStats`",
+  "Get/Set GMMStats",
+  ""
+);
+/* Getter: returns a new GMMStats Python object holding a copy of the trainer's statistics. */
+PyObject* PyBobLearnMiscGMMBaseTrainer_getGMMStats(PyBobLearnMiscGMMBaseTrainerObject* self, void*){
+  BOB_TRY
+
+  // getGMMStats() returns the statistics by value, so copy them into a
+  // heap-allocated object that the Python wrapper can own.
+  boost::shared_ptr<bob::learn::misc::GMMStats> stats_shared(new bob::learn::misc::GMMStats(self->cxx->getGMMStats()));
+
+
+  //Allocating the correspondent python object
+  PyBobLearnMiscGMMStatsObject* retval =
+    (PyBobLearnMiscGMMStatsObject*)PyBobLearnMiscGMMStats_Type.tp_alloc(&PyBobLearnMiscGMMStats_Type, 0);
+  if (!retval) return 0; // tp_alloc failed
+  retval->cxx = stats_shared;
+  // "N" hands over the reference obtained from tp_alloc ("O" would add a second one that nobody releases)
+  return Py_BuildValue("N",retval);
+  BOB_CATCH_MEMBER("GMMStats could not be read", 0)
+}
+/* Setter: copies the given GMMStats into the trainer; returns 0 on success, -1 with an error set otherwise. */
+int PyBobLearnMiscGMMBaseTrainer_setGMMStats(PyBobLearnMiscGMMBaseTrainerObject* self, PyObject* value, void*){
+  BOB_TRY
+
+  if (!PyBobLearnMiscGMMStats_Check(value)){
+    PyErr_Format(PyExc_RuntimeError, "%s %s expects a :py:class:`bob.learn.misc.GMMStats`", Py_TYPE(self)->tp_name, gmm_stats.name());
+    return -1;
+  }
+
+  // The type was verified just above, so a plain cast is enough here
+  auto stats = reinterpret_cast<PyBobLearnMiscGMMStatsObject*>(value);
+
+  self->cxx->setGMMStats(*stats->cxx);
+
+  return 0;
+  BOB_CATCH_MEMBER("gmm_stats could not be set", -1)
+}
+
+
+
+static PyGetSetDef PyBobLearnMiscGMMBaseTrainer_getseters[] = {
+  {
+    gmm_stats.name(),
+    (getter)PyBobLearnMiscGMMBaseTrainer_getGMMStats,
+    (setter)PyBobLearnMiscGMMBaseTrainer_setGMMStats,
+    gmm_stats.doc(),
+    0
+  },
+  {0}  // Sentinel
+};
+
+
+/******************************************************************/
+/************ Functions Section ***********************************/
+/******************************************************************/
+
+/*** initialize ***/
+static auto initialize = bob::extension::FunctionDoc(
+  "initialize",
+  "Initialization before the EM steps",
+  "Instantiate :py:class:`bob.learn.misc.GMMStats`",
+  true
+)
+.add_prototype("gmm_machine")
+.add_parameter("gmm_machine", ":py:class:`bob.learn.misc.GMMMachine`", "GMMMachine Object");
+static PyObject* PyBobLearnMiscGMMBaseTrainer_initialize(PyBobLearnMiscGMMBaseTrainerObject* self, PyObject* args, PyObject* kwargs) {
+  BOB_TRY
+
+  /* Parses input arguments in a single shot */
+  char** kwlist = initialize.kwlist(0);
+
+  PyBobLearnMiscGMMMachineObject* gmm_machine = 0;
+
+  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!", kwlist, &PyBobLearnMiscGMMMachine_Type, &gmm_machine)) return 0;
+
+  self->cxx->initialize(*gmm_machine->cxx);
+
+  BOB_CATCH_MEMBER("cannot perform the initialize method", 0)
+
+  Py_RETURN_NONE;
+}
+
+
+
+/*** eStep ***/
+static auto eStep = bob::extension::FunctionDoc(
+  "eStep",
+  "Calculates and saves statistics across the dataset, "
+  "and saves these as m_ss. ",
+
+  "Calculates the average log likelihood of the observations given the GMM; "
+  "it can be read afterwards with compute_likelihood(). "
+  "The statistics, m_ss, will be used in the mStep() that follows.",
+
+  true
+)
+.add_prototype("gmm_machine,data")
+.add_parameter("gmm_machine", ":py:class:`bob.learn.misc.GMMMachine`", "GMMMachine Object")
+.add_parameter("data", "array_like <float, 2D>", "Input data");
+static PyObject* PyBobLearnMiscGMMBaseTrainer_eStep(PyBobLearnMiscGMMBaseTrainerObject* self, PyObject* args, PyObject* kwargs) {
+  BOB_TRY
+
+  /* Parses input arguments in a single shot */
+  char** kwlist = eStep.kwlist(0);
+
+  PyBobLearnMiscGMMMachineObject* gmm_machine;
+  PyBlitzArrayObject* data = 0;
+  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!O&", kwlist, &PyBobLearnMiscGMMMachine_Type, &gmm_machine,
+                                   &PyBlitzArray_Converter, &data)) return 0;
+  auto data_ = make_safe(data);
+
+  self->cxx->eStep(*gmm_machine->cxx, *PyBlitzArrayCxx_AsBlitz<double,2>(data));
+
+  BOB_CATCH_MEMBER("cannot perform the eStep method", 0)
+
+  Py_RETURN_NONE;
+}
+
+
+/*** computeLikelihood ***/
+static auto compute_likelihood = bob::extension::FunctionDoc(
+  "compute_likelihood",
+  "This function returns the average log-likelihood stored in the statistics (accumulated log-likelihood divided by T)",
+  0,
+  true
+)
+.add_prototype("gmm_machine")
+.add_parameter("gmm_machine", ":py:class:`bob.learn.misc.GMMMachine`", "GMMMachine Object");
+static PyObject* PyBobLearnMiscGMMBaseTrainer_compute_likelihood(PyBobLearnMiscGMMBaseTrainerObject* self, PyObject* args, PyObject* kwargs) {
+  BOB_TRY
+
+  /* Parses input arguments in a single shot */
+  char** kwlist = compute_likelihood.kwlist(0);
+
+  PyBobLearnMiscGMMMachineObject* gmm_machine;
+  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!", kwlist, &PyBobLearnMiscGMMMachine_Type, &gmm_machine)) return 0;
+
+  double value = self->cxx->computeLikelihood(*gmm_machine->cxx);
+  return Py_BuildValue("d", value);
+
+  BOB_CATCH_MEMBER("cannot perform the computeLikelihood method", 0)
+}
+
+
+static PyMethodDef PyBobLearnMiscGMMBaseTrainer_methods[] = {
+  {
+    initialize.name(),
+    (PyCFunction)PyBobLearnMiscGMMBaseTrainer_initialize,
+    METH_VARARGS|METH_KEYWORDS,
+    initialize.doc()
+  },
+  {
+    eStep.name(),
+    (PyCFunction)PyBobLearnMiscGMMBaseTrainer_eStep,
+    METH_VARARGS|METH_KEYWORDS,
+    eStep.doc()
+  },
+  {
+    compute_likelihood.name(),
+    (PyCFunction)PyBobLearnMiscGMMBaseTrainer_compute_likelihood,
+    METH_VARARGS|METH_KEYWORDS,
+    compute_likelihood.doc()
+  },
+  {0} /* Sentinel */
+};
+
+
+/******************************************************************/
+/************ Module Section **************************************/
+/******************************************************************/
+
+// Define the GMMBaseTrainer type struct; will be initialized later
+PyTypeObject PyBobLearnMiscGMMBaseTrainer_Type = {
+  PyVarObject_HEAD_INIT(0,0)
+  0
+};
+
+bool init_BobLearnMiscGMMBaseTrainer(PyObject* module)
+{
+  // initialize the type struct
+  PyBobLearnMiscGMMBaseTrainer_Type.tp_name = GMMBaseTrainer_doc.name();
+  PyBobLearnMiscGMMBaseTrainer_Type.tp_basicsize = sizeof(PyBobLearnMiscGMMBaseTrainerObject);
+  PyBobLearnMiscGMMBaseTrainer_Type.tp_flags = Py_TPFLAGS_DEFAULT;
+  PyBobLearnMiscGMMBaseTrainer_Type.tp_doc = GMMBaseTrainer_doc.doc();
+
+  // set the functions
+  PyBobLearnMiscGMMBaseTrainer_Type.tp_new = PyType_GenericNew;
+  PyBobLearnMiscGMMBaseTrainer_Type.tp_init = reinterpret_cast<initproc>(PyBobLearnMiscGMMBaseTrainer_init);
+  PyBobLearnMiscGMMBaseTrainer_Type.tp_dealloc = reinterpret_cast<destructor>(PyBobLearnMiscGMMBaseTrainer_delete);
+  PyBobLearnMiscGMMBaseTrainer_Type.tp_richcompare = reinterpret_cast<richcmpfunc>(PyBobLearnMiscGMMBaseTrainer_RichCompare);
+  PyBobLearnMiscGMMBaseTrainer_Type.tp_methods = PyBobLearnMiscGMMBaseTrainer_methods;
+  PyBobLearnMiscGMMBaseTrainer_Type.tp_getset = PyBobLearnMiscGMMBaseTrainer_getseters;
+  PyBobLearnMiscGMMBaseTrainer_Type.tp_call = reinterpret_cast<ternaryfunc>(PyBobLearnMiscGMMBaseTrainer_compute_likelihood);
+
+
+  // check that everything is fine
+  if (PyType_Ready(&PyBobLearnMiscGMMBaseTrainer_Type) < 0) return false;
+
+  // add the type to the module
+  Py_INCREF(&PyBobLearnMiscGMMBaseTrainer_Type);
+  return PyModule_AddObject(module, "GMMBaseTrainer", (PyObject*)&PyBobLearnMiscGMMBaseTrainer_Type) >= 0;
+}
+
diff --git a/bob/learn/misc/include/bob.learn.misc/GMMTrainer.h b/bob/learn/misc/include/bob.learn.misc/GMMBaseTrainer.h
similarity index 76%
rename from bob/learn/misc/include/bob.learn.misc/GMMTrainer.h
rename to bob/learn/misc/include/bob.learn.misc/GMMBaseTrainer.h
index fbbef1b..1cb0788 100644
--- a/bob/learn/misc/include/bob.learn.misc/GMMTrainer.h
+++ b/bob/learn/misc/include/bob.learn.misc/GMMBaseTrainer.h
@@ -8,10 +8,9 @@
  * Copyright (C) Idiap Research Institute, Martigny, Switzerland
  */
 
-#ifndef BOB_LEARN_MISC_GMMTRAINER_H
-#define BOB_LEARN_MISC_GMMTRAINER_H
+#ifndef BOB_LEARN_MISC_GMMBASETRAINER_H
+#define BOB_LEARN_MISC_GMMBASETRAINER_H
 
-#include <bob.learn.misc/EMTrainer.h>
 #include <bob.learn.misc/GMMMachine.h>
 #include <bob.learn.misc/GMMStats.h>
 #include <limits>
@@ -24,13 +23,13 @@ namespace bob { namespace learn { namespace misc {
  * @details See Section 9.2.2 of Bishop,
  *   "Pattern recognition and machine learning", 2006
  */
-class GMMTrainer: public EMTrainer<bob::learn::misc::GMMMachine, blitz::Array<double,2> >
+class GMMBaseTrainer
 {
   public:
     /**
      * @brief Default constructor
      */
-    GMMTrainer(const bool update_means=true,
+    GMMBaseTrainer(const bool update_means=true,
                const bool update_variances=false,
                const bool update_weights=false,
                const double mean_var_update_responsibilities_threshold =
                 std::numeric_limits<double>::epsilon());
@@ -38,18 +37,17 @@ class GMMTrainer: public EMTrainer<bob::learn::misc::GMMMachine, blitz::Array<do
     /**
      * @brief Copy constructor
      */
-    GMMTrainer(const GMMTrainer& other);
+    GMMBaseTrainer(const GMMBaseTrainer& other);
 
     /**
      * @brief Destructor
      */
-    virtual ~GMMTrainer();
+    virtual ~GMMBaseTrainer();
 
     /**
      * @brief Initialization before the EM steps
      */
-    virtual void initialize(bob::learn::misc::GMMMachine& gmm,
-      const blitz::Array<double,2>& data);
+    virtual void initialize(bob::learn::misc::GMMMachine& gmm);
 
     /**
      * @brief Calculates and saves statistics across the dataset,
@@ -69,38 +67,33 @@ class GMMTrainer: public EMTrainer<bob::learn::misc::GMMMachine, blitz::Array<do
      */
     virtual double computeLikelihood(bob::learn::misc::GMMMachine& gmm);
 
-    /**
-     * @brief Finalization after the EM steps
-     */
-    virtual void finalize(bob::learn::misc::GMMMachine& gmm,
-      const blitz::Array<double,2>& data);
 
     /**
-     * @brief Assigns from a different GMMTrainer
+     * @brief Assigns from a different GMMBaseTrainer
      */
-    GMMTrainer& operator=(const GMMTrainer &other);
+    GMMBaseTrainer& operator=(const GMMBaseTrainer &other);
 
     /**
      * @brief Equal to
      */
-    bool operator==(const GMMTrainer& b) const;
+    bool operator==(const GMMBaseTrainer& b) const;
 
     /**
      * @brief Not equal to
      */
-    bool operator!=(const GMMTrainer& b) const;
+    bool operator!=(const GMMBaseTrainer& b) const;
 
     /**
      * @brief Similar to
      */
-    bool is_similar_to(const GMMTrainer& b, const double r_epsilon=1e-5,
+    bool is_similar_to(const GMMBaseTrainer& b, const double r_epsilon=1e-5,
       const double a_epsilon=1e-8) const;
 
     /**
      * @brief Returns the internal GMM statistics. Useful to parallelize the
      * E-step
      */
-    const bob::learn::misc::GMMStats& getGMMStats() const
+    const bob::learn::misc::GMMStats getGMMStats() const
     { return m_ss; }
 
     /**
@@ -109,7 +102,7 @@ class GMMTrainer: public EMTrainer<bob::learn::misc::GMMMachine, blitz::Array<do
      */
     void setGMMStats(const bob::learn::misc::GMMStats& stats);
 
-  protected:
+  private:
     /**
      * These are the sufficient statistics, calculated during the
      * E-step and used during the M-step
@@ -142,4 +135,4 @@ class GMMTrainer: public EMTrainer<bob::learn::misc::GMMMachine, blitz::Array<do
 
 } } } // namespaces
 
-#endif // BOB_LEARN_MISC_GMMTRAINER_H
+#endif // BOB_LEARN_MISC_GMMBASETRAINER_H
diff --git a/bob/learn/misc/kmeans_trainer.cpp b/bob/learn/misc/kmeans_trainer.cpp
index 2c1111b..cbaac80 100644
--- a/bob/learn/misc/kmeans_trainer.cpp
+++ b/bob/learn/misc/kmeans_trainer.cpp
@@ -540,7 +540,7 @@ bool init_BobLearnMiscKMeansTrainer(PyObject* module)
   PyBobLearnMiscKMeansTrainer_Type.tp_richcompare = reinterpret_cast<richcmpfunc>(PyBobLearnMiscKMeansTrainer_RichCompare);
   PyBobLearnMiscKMeansTrainer_Type.tp_methods = PyBobLearnMiscKMeansTrainer_methods;
   PyBobLearnMiscKMeansTrainer_Type.tp_getset = PyBobLearnMiscKMeansTrainer_getseters;
-  PyBobLearnMiscGMMMachine_Type.tp_call = reinterpret_cast<ternaryfunc>(PyBobLearnMiscKMeansTrainer_compute_likelihood);
+  PyBobLearnMiscKMeansTrainer_Type.tp_call = reinterpret_cast<ternaryfunc>(PyBobLearnMiscKMeansTrainer_compute_likelihood);
 
 
   // check that everything is fine
diff --git a/bob/learn/misc/main.cpp b/bob/learn/misc/main.cpp
index fb8b2a5..0431832 100644
--- a/bob/learn/misc/main.cpp
+++ b/bob/learn/misc/main.cpp
@@ -43,8 +43,9 @@ static PyObject* create_module (void) {
   if (!init_BobLearnMiscGaussian(module)) return 0;
   if (!init_BobLearnMiscGMMStats(module)) return 0;
   if (!init_BobLearnMiscGMMMachine(module)) return 0;
-  if (!init_BobLearnMiscKMeansMachine(module)) return 0;
-  if (!init_BobLearnMiscKMeansTrainer(module)) return 0;
+  if (!init_BobLearnMiscKMeansMachine(module)) return 0;
+  if (!init_BobLearnMiscKMeansTrainer(module)) return 0;
+  if (!init_BobLearnMiscGMMBaseTrainer(module)) return 0;
 
 
   static void* PyBobLearnMisc_API[PyBobLearnMisc_API_pointers];
diff --git a/bob/learn/misc/main.h b/bob/learn/misc/main.h
index a6efaa7..cffea0e 100644
--- a/bob/learn/misc/main.h
+++ b/bob/learn/misc/main.h
@@ -23,6 +23,9 @@
 #include <bob.learn.misc/KMeansMachine.h>
 #include <bob.learn.misc/KMeansTrainer.h>
 
+#include <bob.learn.misc/GMMBaseTrainer.h>
+
+
 #if PY_VERSION_HEX >= 0x03000000
 #define PyInt_Check PyLong_Check
 #define PyInt_AS_LONG PyLong_AS_LONG
@@ -116,5 +119,16 @@ bool init_BobLearnMiscKMeansTrainer(PyObject* module);
 int PyBobLearnMiscKMeansTrainer_Check(PyObject* o);
 
 
+// GMMBaseTrainer
+typedef struct {
+  PyObject_HEAD
+  boost::shared_ptr<bob::learn::misc::GMMBaseTrainer> cxx;
+} PyBobLearnMiscGMMBaseTrainerObject;
+
+extern PyTypeObject PyBobLearnMiscGMMBaseTrainer_Type;
+bool init_BobLearnMiscGMMBaseTrainer(PyObject* module);
+int PyBobLearnMiscGMMBaseTrainer_Check(PyObject* o);
+
+
 
 #endif // BOB_LEARN_EM_MAIN_H
diff --git a/setup.py b/setup.py
index 57f2a77..9bc5e23 100644
--- a/setup.py
+++ b/setup.py
@@ -65,7 +65,7 @@ setup(
           #"bob/learn/misc/cpp/ZTNorm.cpp",
 
           #"bob/learn/misc/cpp/EMPCATrainer.cpp",
-          #"bob/learn/misc/cpp/GMMTrainer.cpp",
+          "bob/learn/misc/cpp/GMMBaseTrainer.cpp",
           #"bob/learn/misc/cpp/IVectorTrainer.cpp",
           #"bob/learn/misc/cpp/JFATrainer.cpp",
           "bob/learn/misc/cpp/KMeansTrainer.cpp",
@@ -105,6 +105,7 @@ setup(
           "bob/learn/misc/gmm_machine.cpp",
           "bob/learn/misc/kmeans_machine.cpp",
           "bob/learn/misc/kmeans_trainer.cpp",
+          "bob/learn/misc/gmm_base_trainer.cpp",
           "bob/learn/misc/main.cpp",
         ],
 
-- 
GitLab