From 3aad63f1699d8f14da4a8b7ad45eb415a7c3bbb8 Mon Sep 17 00:00:00 2001 From: Tiago Freitas Pereira <tiagofrepereira@gmail.com> Date: Sat, 7 Feb 2015 22:14:52 +0100 Subject: [PATCH] Reorganized the MAP and ML trainers --- bob/learn/misc/MAP_gmm_trainer.cpp | 175 +++++++++------- bob/learn/misc/ML_gmm_trainer.cpp | 187 ++++++++++-------- bob/learn/misc/__MAP_gmm_trainer__.py | 22 ++- bob/learn/misc/__ML_gmm_trainer__.py | 20 +- bob/learn/misc/cpp/MAP_GMMTrainer.cpp | 34 ++-- bob/learn/misc/cpp/ML_GMMTrainer.cpp | 25 ++- .../include/bob.learn.misc/GMMBaseTrainer.h | 12 +- .../include/bob.learn.misc/MAP_GMMTrainer.h | 49 +++-- .../include/bob.learn.misc/ML_GMMTrainer.h | 41 ++-- bob/learn/misc/main.cpp | 2 +- bob/learn/misc/main.h | 5 +- bob/learn/misc/test_em.py | 14 +- setup.py | 2 +- 13 files changed, 357 insertions(+), 231 deletions(-) diff --git a/bob/learn/misc/MAP_gmm_trainer.cpp b/bob/learn/misc/MAP_gmm_trainer.cpp index 1ffa4d2..f3a3f4c 100644 --- a/bob/learn/misc/MAP_gmm_trainer.cpp +++ b/bob/learn/misc/MAP_gmm_trainer.cpp @@ -13,6 +13,7 @@ /************ Constructor Section *********************************/ /******************************************************************/ +static inline bool f(PyObject* o){return o != 0 && PyObject_IsTrue(o) > 0;} /* converts PyObject to bool and returns false if object is NULL */ static auto MAP_GMMTrainer_doc = bob::extension::ClassDoc( BOB_EXT_MODULE_PREFIX ".MAP_GMMTrainer", @@ -24,18 +25,22 @@ static auto MAP_GMMTrainer_doc = bob::extension::ClassDoc( "", true ) - - - .add_prototype("gmm_base_trainer,prior_gmm,relevance_factor","") - .add_prototype("gmm_base_trainer,prior_gmm,alpha","") + + .add_prototype("prior_gmm,relevance_factor, update_means, [update_variances], [update_weights], [mean_var_update_responsibilities_threshold]","") + .add_prototype("prior_gmm,alpha, update_means, [update_variances], [update_weights], [mean_var_update_responsibilities_threshold]","") .add_prototype("other","") .add_prototype("","") - .add_parameter("gmm_base_trainer", ":py:class:`bob.learn.misc.GMMBaseTrainer`", "A GMMBaseTrainer object.") .add_parameter("prior_gmm", ":py:class:`bob.learn.misc.GMMMachine`", "The prior GMM to be adapted (Universal Backgroud Model UBM).") .add_parameter("reynolds_adaptation", "bool", "Will use the Reynolds adaptation procedure? See Eq (14) from [Reynolds2000]_") .add_parameter("relevance_factor", "double", "If set the reynolds_adaptation parameters, will apply the Reynolds Adaptation procedure. See Eq (14) from [Reynolds2000]_") .add_parameter("alpha", "double", "Set directly the alpha parameter (Eq (14) from [Reynolds2000]_), ignoring zeroth order statistics as a weighting factor.") + + .add_parameter("update_means", "bool", "Update means on each iteration") + .add_parameter("update_variances", "bool", "Update variances on each iteration") + .add_parameter("update_weights", "bool", "Update weights on each iteration") + .add_parameter("mean_var_update_responsibilities_threshold", "float", "Threshold over the responsibilities of the Gaussians Equations 9.24, 9.25 of Bishop, `Pattern recognition and machine learning`, 2006 require a division by the responsibilities, which might be equal to zero because of numerical issue. This threshold is used to avoid such divisions.") + .add_parameter("other", ":py:class:`bob.learn.misc.MAP_GMMTrainer`", "A MAP_GMMTrainer object to be copied.") ); @@ -59,43 +64,54 @@ static int PyBobLearnMiscMAPGMMTrainer_init_base_trainer(PyBobLearnMiscMAPGMMTra char** kwlist1 = MAP_GMMTrainer_doc.kwlist(0); char** kwlist2 = MAP_GMMTrainer_doc.kwlist(1); - PyBobLearnMiscGMMBaseTrainerObject* gmm_base_trainer; PyBobLearnMiscGMMMachineObject* gmm_machine; bool reynolds_adaptation = false; double alpha = 0.5; double relevance_factor = 4.0; double aux = 0; - PyObject* keyword_relevance_factor = Py_BuildValue("s", kwlist1[2]); - PyObject* keyword_alpha = Py_BuildValue("s", kwlist2[2]); + PyObject* update_means = 0; + PyObject* update_variances = 0; + PyObject* update_weights = 0; + double mean_var_update_responsibilities_threshold = std::numeric_limits<double>::epsilon(); + + PyObject* keyword_relevance_factor = Py_BuildValue("s", kwlist1[1]); + PyObject* keyword_alpha = Py_BuildValue("s", kwlist2[1]); //Here we have to select which keyword argument to read - if (kwargs && PyDict_Contains(kwargs, keyword_relevance_factor) && (PyArg_ParseTupleAndKeywords(args, kwargs, "O!O!|d", kwlist1, - &PyBobLearnMiscGMMBaseTrainer_Type, &gmm_base_trainer, + if (kwargs && PyDict_Contains(kwargs, keyword_relevance_factor) && (PyArg_ParseTupleAndKeywords(args, kwargs, "O!dO!|O!O!d", kwlist1, &PyBobLearnMiscGMMMachine_Type, &gmm_machine, - &aux))) + &aux, + &PyBool_Type, &update_means, + &PyBool_Type, &update_variances, + &PyBool_Type, &update_weights, + &mean_var_update_responsibilities_threshold))) reynolds_adaptation = true; - else if (kwargs && PyDict_Contains(kwargs, keyword_alpha) && (PyArg_ParseTupleAndKeywords(args, kwargs, "O!O!|d", kwlist2, - &PyBobLearnMiscGMMBaseTrainer_Type, &gmm_base_trainer, + else if (kwargs && PyDict_Contains(kwargs, keyword_alpha) && (PyArg_ParseTupleAndKeywords(args, kwargs, "O!dO!|O!O!d", kwlist2, &PyBobLearnMiscGMMMachine_Type, &gmm_machine, - &aux))) + &aux, + &PyBool_Type, &update_means, + &PyBool_Type, &update_variances, + &PyBool_Type, &update_weights, + &mean_var_update_responsibilities_threshold))) reynolds_adaptation = false; else{ - PyErr_Format(PyExc_RuntimeError, "%s. The third argument must be a keyword argument.", Py_TYPE(self)->tp_name); + PyErr_Format(PyExc_RuntimeError, "%s. The second argument must be a keyword argument.", Py_TYPE(self)->tp_name); MAP_GMMTrainer_doc.print_usage(); return -1; } - - if (reynolds_adaptation) relevance_factor = aux; else alpha = aux; - self->cxx.reset(new bob::learn::misc::MAP_GMMTrainer(gmm_base_trainer->cxx, gmm_machine->cxx, reynolds_adaptation,relevance_factor, alpha)); + self->cxx.reset(new bob::learn::misc::MAP_GMMTrainer(f(update_means), f(update_variances), f(update_weights), + mean_var_update_responsibilities_threshold, + reynolds_adaptation,relevance_factor, alpha, gmm_machine->cxx)); return 0; + } @@ -151,47 +167,6 @@ static PyObject* PyBobLearnMiscMAPGMMTrainer_RichCompare(PyBobLearnMiscMAPGMMTra /************ Variables Section ***********************************/ /******************************************************************/ - -/***** gmm_base_trainer *****/ -static auto gmm_base_trainer = bob::extension::VariableDoc( - "gmm_base_trainer", - ":py:class:`bob.learn.misc.GMMBaseTrainer`", - "This class that implements the E-step of the expectation-maximisation algorithm.", - "" -); -PyObject* PyBobLearnMiscMAPGMMTrainer_getGMMBaseTrainer(PyBobLearnMiscMAPGMMTrainerObject* self, void*){ - BOB_TRY - - boost::shared_ptr<bob::learn::misc::GMMBaseTrainer> gmm_base_trainer = self->cxx->getGMMBaseTrainer(); - - //Allocating the correspondent python object - PyBobLearnMiscGMMBaseTrainerObject* retval = - (PyBobLearnMiscGMMBaseTrainerObject*)PyBobLearnMiscGMMBaseTrainer_Type.tp_alloc(&PyBobLearnMiscGMMBaseTrainer_Type, 0); - - retval->cxx = gmm_base_trainer; - - return Py_BuildValue("O",retval); - BOB_CATCH_MEMBER("GMMBaseTrainer could not be read", 0) -} -int PyBobLearnMiscMAPGMMTrainer_setGMMBaseTrainer(PyBobLearnMiscMAPGMMTrainerObject* self, PyObject* value, void*){ - BOB_TRY - - if (!PyBobLearnMiscGMMBaseTrainer_Check(value)){ - PyErr_Format(PyExc_RuntimeError, "%s %s expects a :py:class:`bob.learn.misc.GMMBaseTrainer`", Py_TYPE(self)->tp_name, gmm_base_trainer.name()); - return -1; - } - - PyBobLearnMiscGMMBaseTrainerObject* gmm_base_trainer = 0; - PyArg_Parse(value, "O!", &PyBobLearnMiscGMMBaseTrainer_Type,&gmm_base_trainer); - - self->cxx->setGMMBaseTrainer(gmm_base_trainer->cxx); - - return 0; - BOB_CATCH_MEMBER("gmm_base_trainer could not be set", -1) -} - - - /***** relevance_factor *****/ static auto relevance_factor = bob::extension::VariableDoc( "relevance_factor", @@ -246,13 +221,6 @@ int PyBobLearnMiscMAPGMMTrainer_setAlpha(PyBobLearnMiscMAPGMMTrainerObject* self static PyGetSetDef PyBobLearnMiscMAPGMMTrainer_getseters[] = { - { - gmm_base_trainer.name(), - (getter)PyBobLearnMiscMAPGMMTrainer_getGMMBaseTrainer, - (setter)PyBobLearnMiscMAPGMMTrainer_setGMMBaseTrainer, - gmm_base_trainer.doc(), - 0 - }, { alpha.name(), (getter)PyBobLearnMiscMAPGMMTrainer_getAlpha, @@ -304,6 +272,40 @@ static PyObject* PyBobLearnMiscMAPGMMTrainer_initialize(PyBobLearnMiscMAPGMMTrai } +/*** eStep ***/ +static auto eStep = bob::extension::FunctionDoc( + "eStep", + "Calculates and saves statistics across the dataset," + "and saves these as m_ss. ", + + "Calculates the average log likelihood of the observations given the GMM," + "and returns this in average_log_likelihood." + "The statistics, m_ss, will be used in the mStep() that follows.", + + true +) +.add_prototype("gmm_machine,data") +.add_parameter("gmm_machine", ":py:class:`bob.learn.misc.GMMMachine`", "GMMMachine Object") +.add_parameter("data", "array_like <float, 2D>", "Input data"); +static PyObject* PyBobLearnMiscMAPGMMTrainer_eStep(PyBobLearnMiscMAPGMMTrainerObject* self, PyObject* args, PyObject* kwargs) { + BOB_TRY + + /* Parses input arguments in a single shot */ + char** kwlist = eStep.kwlist(0); + + PyBobLearnMiscGMMMachineObject* gmm_machine; + PyBlitzArrayObject* data = 0; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!O&", kwlist, &PyBobLearnMiscGMMMachine_Type, &gmm_machine, + &PyBlitzArray_Converter, &data)) Py_RETURN_NONE; + auto data_ = make_safe(data); + + self->cxx->eStep(*gmm_machine->cxx, *PyBlitzArrayCxx_AsBlitz<double,2>(data)); + + BOB_CATCH_MEMBER("cannot perform the eStep method", 0) + + Py_RETURN_NONE; +} + /*** mStep ***/ static auto mStep = bob::extension::FunctionDoc( @@ -335,6 +337,31 @@ static PyObject* PyBobLearnMiscMAPGMMTrainer_mStep(PyBobLearnMiscMAPGMMTrainerOb } +/*** computeLikelihood ***/ +static auto compute_likelihood = bob::extension::FunctionDoc( + "compute_likelihood", + "This functions returns the average min (Square Euclidean) distance (average distance to the closest mean)", + 0, + true +) +.add_prototype("gmm_machine") +.add_parameter("gmm_machine", ":py:class:`bob.learn.misc.GMMMachine`", "GMMMachine Object"); +static PyObject* PyBobLearnMiscMAPGMMTrainer_compute_likelihood(PyBobLearnMiscMAPGMMTrainerObject* self, PyObject* args, PyObject* kwargs) { + BOB_TRY + + /* Parses input arguments in a single shot */ + char** kwlist = compute_likelihood.kwlist(0); + + PyBobLearnMiscGMMMachineObject* gmm_machine; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!", kwlist, &PyBobLearnMiscGMMMachine_Type, &gmm_machine)) Py_RETURN_NONE; + + double value = self->cxx->computeLikelihood(*gmm_machine->cxx); + return Py_BuildValue("d", value); + + BOB_CATCH_MEMBER("cannot perform the computeLikelihood method", 0) +} + + static PyMethodDef PyBobLearnMiscMAPGMMTrainer_methods[] = { { @@ -343,13 +370,25 @@ static PyMethodDef PyBobLearnMiscMAPGMMTrainer_methods[] = { METH_VARARGS|METH_KEYWORDS, initialize.doc() }, + { + eStep.name(), + (PyCFunction)PyBobLearnMiscMAPGMMTrainer_eStep, + METH_VARARGS|METH_KEYWORDS, + eStep.doc() + }, { mStep.name(), (PyCFunction)PyBobLearnMiscMAPGMMTrainer_mStep, METH_VARARGS|METH_KEYWORDS, mStep.doc() }, - + { + compute_likelihood.name(), + (PyCFunction)PyBobLearnMiscMAPGMMTrainer_compute_likelihood, + METH_VARARGS|METH_KEYWORDS, + compute_likelihood.doc() + }, + {0} /* Sentinel */ }; @@ -379,7 +418,7 @@ bool init_BobLearnMiscMAPGMMTrainer(PyObject* module) PyBobLearnMiscMAPGMMTrainer_Type.tp_richcompare = reinterpret_cast<richcmpfunc>(PyBobLearnMiscMAPGMMTrainer_RichCompare); PyBobLearnMiscMAPGMMTrainer_Type.tp_methods = PyBobLearnMiscMAPGMMTrainer_methods; PyBobLearnMiscMAPGMMTrainer_Type.tp_getset = PyBobLearnMiscMAPGMMTrainer_getseters; - //PyBobLearnMiscMAPGMMTrainer_Type.tp_call = reinterpret_cast<ternaryfunc>(PyBobLearnMiscMAPGMMTrainer_compute_likelihood); + PyBobLearnMiscMAPGMMTrainer_Type.tp_call = reinterpret_cast<ternaryfunc>(PyBobLearnMiscMAPGMMTrainer_compute_likelihood); // check that everything is fine diff --git a/bob/learn/misc/ML_gmm_trainer.cpp b/bob/learn/misc/ML_gmm_trainer.cpp index f9a4598..ff72609 100644 --- a/bob/learn/misc/ML_gmm_trainer.cpp +++ b/bob/learn/misc/ML_gmm_trainer.cpp @@ -13,6 +13,8 @@ /************ Constructor Section *********************************/ /******************************************************************/ +static inline bool f(PyObject* o){return o != 0 && PyObject_IsTrue(o) > 0;} /* converts PyObject to bool and returns false if object is NULL */ + static auto ML_GMMTrainer_doc = bob::extension::ClassDoc( BOB_EXT_MODULE_PREFIX ".ML_GMMTrainer", "This class implements the maximum likelihood M-step of the expectation-maximisation algorithm for a GMM Machine." @@ -23,11 +25,16 @@ static auto ML_GMMTrainer_doc = bob::extension::ClassDoc( "", true ) - .add_prototype("gmm_base_trainer","") + .add_prototype("update_means, [update_variances], [update_weights], [mean_var_update_responsibilities_threshold]","") .add_prototype("other","") .add_prototype("","") - .add_parameter("gmm_base_trainer", ":py:class:`bob.learn.misc.GMMBaseTrainer`", "A set GMMBaseTrainer object.") + .add_parameter("update_means", "bool", "Update means on each iteration") + .add_parameter("update_variances", "bool", "Update variances on each iteration") + .add_parameter("update_weights", "bool", "Update weights on each iteration") + .add_parameter("mean_var_update_responsibilities_threshold", "float", "Threshold over the responsibilities of the Gaussians Equations 9.24, 9.25 of Bishop, `Pattern recognition and machine learning`, 2006 require a division by the responsibilities, which might be equal to zero because of numerical issue. This threshold is used to avoid such divisions.") + + .add_parameter("other", ":py:class:`bob.learn.misc.ML_GMMTrainer`", "A ML_GMMTrainer object to be copied.") ); @@ -48,15 +55,24 @@ static int PyBobLearnMiscMLGMMTrainer_init_copy(PyBobLearnMiscMLGMMTrainerObject static int PyBobLearnMiscMLGMMTrainer_init_base_trainer(PyBobLearnMiscMLGMMTrainerObject* self, PyObject* args, PyObject* kwargs) { - char** kwlist = ML_GMMTrainer_doc.kwlist(1); - PyBobLearnMiscGMMBaseTrainerObject* o; + char** kwlist = ML_GMMTrainer_doc.kwlist(0); - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!", kwlist, &PyBobLearnMiscGMMBaseTrainer_Type, &o)){ + PyObject* update_means = 0; + PyObject* update_variances = 0; + PyObject* update_weights = 0; + double mean_var_update_responsibilities_threshold = std::numeric_limits<double>::epsilon(); + + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!|O!O!d", kwlist, + &PyBool_Type, &update_means, + &PyBool_Type, &update_variances, + &PyBool_Type, &update_weights, + &mean_var_update_responsibilities_threshold)){ ML_GMMTrainer_doc.print_usage(); return -1; } - self->cxx.reset(new bob::learn::misc::ML_GMMTrainer(o->cxx)); + self->cxx.reset(new bob::learn::misc::ML_GMMTrainer(f(update_means), f(update_variances), f(update_weights), + mean_var_update_responsibilities_threshold)); return 0; } @@ -65,31 +81,24 @@ static int PyBobLearnMiscMLGMMTrainer_init_base_trainer(PyBobLearnMiscMLGMMTrain static int PyBobLearnMiscMLGMMTrainer_init(PyBobLearnMiscMLGMMTrainerObject* self, PyObject* args, PyObject* kwargs) { BOB_TRY - int nargs = (args?PyTuple_Size(args):0) + (kwargs?PyDict_Size(kwargs):0); - - if (nargs==1){ //default initializer () - //Reading the input argument - PyObject* arg = 0; - if (PyTuple_Size(args)) - arg = PyTuple_GET_ITEM(args, 0); - else { - PyObject* tmp = PyDict_Values(kwargs); - auto tmp_ = make_safe(tmp); - arg = PyList_GET_ITEM(tmp, 0); - } - - // If the constructor input is GMMBaseTrainer object - if (PyBobLearnMiscGMMBaseTrainer_Check(arg)) - return PyBobLearnMiscMLGMMTrainer_init_base_trainer(self, args, kwargs); - else - return PyBobLearnMiscMLGMMTrainer_init_copy(self, args, kwargs); - } - else{ - PyErr_Format(PyExc_RuntimeError, "number of arguments mismatch - %s requires 1, but you provided %d (see help)", Py_TYPE(self)->tp_name, nargs); - ML_GMMTrainer_doc.print_usage(); - return -1; + //Reading the input argument + PyObject* arg = 0; + if (PyTuple_Size(args)) + arg = PyTuple_GET_ITEM(args, 0); + else { + PyObject* tmp = PyDict_Values(kwargs); + auto tmp_ = make_safe(tmp); + arg = PyList_GET_ITEM(tmp, 0); } + // If the constructor input is GMMBaseTrainer object + if (PyBobLearnMiscMLGMMTrainer_Check(arg)) + return PyBobLearnMiscMLGMMTrainer_init_copy(self, args, kwargs); + else + return PyBobLearnMiscMLGMMTrainer_init_base_trainer(self, args, kwargs); + + + BOB_CATCH_MEMBER("cannot create GMMBaseTrainer_init_bool", 0) return 0; } @@ -131,54 +140,7 @@ static PyObject* PyBobLearnMiscMLGMMTrainer_RichCompare(PyBobLearnMiscMLGMMTrain /************ Variables Section ***********************************/ /******************************************************************/ - -/***** gmm_base_trainer *****/ -static auto gmm_base_trainer = bob::extension::VariableDoc( - "gmm_base_trainer", - ":py:class:`bob.learn.misc.GMMBaseTrainer`", - "This class that implements the E-step of the expectation-maximisation algorithm.", - "" -); -PyObject* PyBobLearnMiscMLGMMTrainer_getGMMBaseTrainer(PyBobLearnMiscMLGMMTrainerObject* self, void*){ - BOB_TRY - - boost::shared_ptr<bob::learn::misc::GMMBaseTrainer> gmm_base_trainer = self->cxx->getGMMBaseTrainer(); - - //Allocating the correspondent python object - PyBobLearnMiscGMMBaseTrainerObject* retval = - (PyBobLearnMiscGMMBaseTrainerObject*)PyBobLearnMiscGMMBaseTrainer_Type.tp_alloc(&PyBobLearnMiscGMMBaseTrainer_Type, 0); - - retval->cxx = gmm_base_trainer; - - return Py_BuildValue("O",retval); - BOB_CATCH_MEMBER("GMMBaseTrainer could not be read", 0) -} -int PyBobLearnMiscMLGMMTrainer_setGMMBaseTrainer(PyBobLearnMiscMLGMMTrainerObject* self, PyObject* value, void*){ - BOB_TRY - - if (!PyBobLearnMiscGMMBaseTrainer_Check(value)){ - PyErr_Format(PyExc_RuntimeError, "%s %s expects a :py:class:`bob.learn.misc.GMMBaseTrainer`", Py_TYPE(self)->tp_name, gmm_base_trainer.name()); - return -1; - } - - PyBobLearnMiscGMMBaseTrainerObject* gmm_base_trainer = 0; - PyArg_Parse(value, "O!", &PyBobLearnMiscGMMBaseTrainer_Type,&gmm_base_trainer); - - self->cxx->setGMMBaseTrainer(gmm_base_trainer->cxx); - - return 0; - BOB_CATCH_MEMBER("gmm_base_trainer could not be set", -1) -} - - static PyGetSetDef PyBobLearnMiscMLGMMTrainer_getseters[] = { - { - gmm_base_trainer.name(), - (getter)PyBobLearnMiscMLGMMTrainer_getGMMBaseTrainer, - (setter)PyBobLearnMiscMLGMMTrainer_setGMMBaseTrainer, - gmm_base_trainer.doc(), - 0 - }, {0} // Sentinel }; @@ -215,6 +177,40 @@ static PyObject* PyBobLearnMiscMLGMMTrainer_initialize(PyBobLearnMiscMLGMMTraine } +/*** eStep ***/ +static auto eStep = bob::extension::FunctionDoc( + "eStep", + "Calculates and saves statistics across the dataset," + "and saves these as m_ss. ", + + "Calculates the average log likelihood of the observations given the GMM," + "and returns this in average_log_likelihood." + "The statistics, m_ss, will be used in the mStep() that follows.", + + true +) +.add_prototype("gmm_machine,data") +.add_parameter("gmm_machine", ":py:class:`bob.learn.misc.GMMMachine`", "GMMMachine Object") +.add_parameter("data", "array_like <float, 2D>", "Input data"); +static PyObject* PyBobLearnMiscMLGMMTrainer_eStep(PyBobLearnMiscMLGMMTrainerObject* self, PyObject* args, PyObject* kwargs) { + BOB_TRY + + /* Parses input arguments in a single shot */ + char** kwlist = eStep.kwlist(0); + + PyBobLearnMiscGMMMachineObject* gmm_machine; + PyBlitzArrayObject* data = 0; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!O&", kwlist, &PyBobLearnMiscGMMMachine_Type, &gmm_machine, + &PyBlitzArray_Converter, &data)) Py_RETURN_NONE; + auto data_ = make_safe(data); + + self->cxx->eStep(*gmm_machine->cxx, *PyBlitzArrayCxx_AsBlitz<double,2>(data)); + + BOB_CATCH_MEMBER("cannot perform the eStep method", 0) + + Py_RETURN_NONE; +} + /*** mStep ***/ static auto mStep = bob::extension::FunctionDoc( @@ -246,6 +242,31 @@ static PyObject* PyBobLearnMiscMLGMMTrainer_mStep(PyBobLearnMiscMLGMMTrainerObje } +/*** computeLikelihood ***/ +static auto compute_likelihood = bob::extension::FunctionDoc( + "compute_likelihood", + "This functions returns the average min (Square Euclidean) distance (average distance to the closest mean)", + 0, + true +) +.add_prototype("gmm_machine") +.add_parameter("gmm_machine", ":py:class:`bob.learn.misc.GMMMachine`", "GMMMachine Object"); +static PyObject* PyBobLearnMiscMLGMMTrainer_compute_likelihood(PyBobLearnMiscMLGMMTrainerObject* self, PyObject* args, PyObject* kwargs) { + BOB_TRY + + /* Parses input arguments in a single shot */ + char** kwlist = compute_likelihood.kwlist(0); + + PyBobLearnMiscGMMMachineObject* gmm_machine; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!", kwlist, &PyBobLearnMiscGMMMachine_Type, &gmm_machine)) Py_RETURN_NONE; + + double value = self->cxx->computeLikelihood(*gmm_machine->cxx); + return Py_BuildValue("d", value); + + BOB_CATCH_MEMBER("cannot perform the computeLikelihood method", 0) +} + + static PyMethodDef PyBobLearnMiscMLGMMTrainer_methods[] = { { @@ -254,12 +275,24 @@ static PyMethodDef PyBobLearnMiscMLGMMTrainer_methods[] = { METH_VARARGS|METH_KEYWORDS, initialize.doc() }, + { + eStep.name(), + (PyCFunction)PyBobLearnMiscMLGMMTrainer_eStep, + METH_VARARGS|METH_KEYWORDS, + eStep.doc() + }, { mStep.name(), (PyCFunction)PyBobLearnMiscMLGMMTrainer_mStep, METH_VARARGS|METH_KEYWORDS, mStep.doc() }, + { + compute_likelihood.name(), + (PyCFunction)PyBobLearnMiscMLGMMTrainer_compute_likelihood, + METH_VARARGS|METH_KEYWORDS, + compute_likelihood.doc() + }, {0} /* Sentinel */ }; @@ -289,7 +322,7 @@ bool init_BobLearnMiscMLGMMTrainer(PyObject* module) PyBobLearnMiscMLGMMTrainer_Type.tp_richcompare = reinterpret_cast<richcmpfunc>(PyBobLearnMiscMLGMMTrainer_RichCompare); PyBobLearnMiscMLGMMTrainer_Type.tp_methods = PyBobLearnMiscMLGMMTrainer_methods; PyBobLearnMiscMLGMMTrainer_Type.tp_getset = PyBobLearnMiscMLGMMTrainer_getseters; - //PyBobLearnMiscMLGMMTrainer_Type.tp_call = reinterpret_cast<ternaryfunc>(PyBobLearnMiscMLGMMTrainer_compute_likelihood); + PyBobLearnMiscMLGMMTrainer_Type.tp_call = reinterpret_cast<ternaryfunc>(PyBobLearnMiscMLGMMTrainer_compute_likelihood); // check that everything is fine diff --git a/bob/learn/misc/__MAP_gmm_trainer__.py b/bob/learn/misc/__MAP_gmm_trainer__.py index 18f215d..4258b08 100644 --- a/bob/learn/misc/__MAP_gmm_trainer__.py +++ b/bob/learn/misc/__MAP_gmm_trainer__.py @@ -11,13 +11,17 @@ import numpy # define the class class MAP_GMMTrainer(_MAP_GMMTrainer): - def __init__(self, gmm_base_trainer, prior_gmm, convergence_threshold=0.001, max_iterations=10, converge_by_likelihood=True, **kwargs): + def __init__(self, prior_gmm, update_means=True, update_variances=False, update_weights=False, convergence_threshold=0.001, max_iterations=10, converge_by_likelihood=True, **kwargs): """ :py:class:`bob.learn.misc.MAP_GMMTrainer` constructor Keyword Parameters: - gmm_base_trainer - The base trainer (:py:class:`bob.learn.misc.GMMBaseTrainer`) + update_means + + update_variances + + update_weights + prior_gmm A :py:class:`bob.learn.misc.GMMMachine` to be adapted convergence_threshold @@ -34,10 +38,10 @@ class MAP_GMMTrainer(_MAP_GMMTrainer): if kwargs.get('alpha')!=None: alpha = kwargs.get('alpha') - _MAP_GMMTrainer.__init__(self, gmm_base_trainer, prior_gmm, alpha=alpha) + _MAP_GMMTrainer.__init__(self, prior_gmm,alpha=alpha, update_means=update_means, update_variances=update_variances,update_weights=update_weights) else: relevance_factor = kwargs.get('relevance_factor') - _MAP_GMMTrainer.__init__(self, gmm_base_trainer, prior_gmm, relevance_factor=relevance_factor) + _MAP_GMMTrainer.__init__(self, prior_gmm, relevance_factor=relevance_factor, update_means=update_means, update_variances=update_variances,update_weights=update_weights) self.convergence_threshold = convergence_threshold self.max_iterations = max_iterations @@ -67,10 +71,10 @@ class MAP_GMMTrainer(_MAP_GMMTrainer): #eStep - self.gmm_base_trainer.eStep(gmm_machine, data); + self.eStep(gmm_machine, data); if(self.converge_by_likelihood): - average_output = self.gmm_base_trainer.compute_likelihood(gmm_machine); + average_output = self.compute_likelihood(gmm_machine); for i in range(self.max_iterations): #saves average output from last iteration @@ -80,11 +84,11 @@ class MAP_GMMTrainer(_MAP_GMMTrainer): self.mStep(gmm_machine); #eStep - self.gmm_base_trainer.eStep(gmm_machine, data); + self.eStep(gmm_machine, data); #Computes log likelihood if required if(self.converge_by_likelihood): - average_output = self.gmm_base_trainer.compute_likelihood(gmm_machine); + average_output = self.compute_likelihood(gmm_machine); #Terminates if converged (and likelihood computation is set) if abs((average_output_previous - average_output)/average_output_previous) <= self.convergence_threshold: diff --git a/bob/learn/misc/__ML_gmm_trainer__.py b/bob/learn/misc/__ML_gmm_trainer__.py index 53c5a75..93a3c6c 100644 --- a/bob/learn/misc/__ML_gmm_trainer__.py +++ b/bob/learn/misc/__ML_gmm_trainer__.py @@ -11,13 +11,17 @@ import numpy # define the class class ML_GMMTrainer(_ML_GMMTrainer): - def __init__(self, gmm_base_trainer, convergence_threshold=0.001, max_iterations=10, converge_by_likelihood=True): + def __init__(self, update_means=True, update_variances=False, update_weights=False, convergence_threshold=0.001, max_iterations=10, converge_by_likelihood=True): """ :py:class:bob.learn.misc.ML_GMMTrainer constructor Keyword Parameters: - gmm_base_trainer - The base trainer (:py:class:`bob.learn.misc.GMMBaseTrainer` + update_means + + update_variances + + update_weights + convergence_threshold Convergence threshold max_iterations @@ -27,7 +31,7 @@ class ML_GMMTrainer(_ML_GMMTrainer): """ - _ML_GMMTrainer.__init__(self, gmm_base_trainer) + _ML_GMMTrainer.__init__(self, update_means=update_means, update_variances=update_variances, update_weights=update_weights) self.convergence_threshold = convergence_threshold self.max_iterations = max_iterations self.converge_by_likelihood = converge_by_likelihood @@ -53,10 +57,10 @@ class ML_GMMTrainer(_ML_GMMTrainer): #eStep - self.gmm_base_trainer.eStep(gmm_machine, data); + self.eStep(gmm_machine, data); if(self.converge_by_likelihood): - average_output = self.gmm_base_trainer.compute_likelihood(gmm_machine); + average_output = self.compute_likelihood(gmm_machine); for i in range(self.max_iterations): #saves average output from last iteration @@ -66,11 +70,11 @@ class ML_GMMTrainer(_ML_GMMTrainer): self.mStep(gmm_machine); #eStep - self.gmm_base_trainer.eStep(gmm_machine, data); + self.eStep(gmm_machine, data); #Computes log likelihood if required if(self.converge_by_likelihood): - average_output = self.gmm_base_trainer.compute_likelihood(gmm_machine); + average_output = self.compute_likelihood(gmm_machine); #Terminates if converged (and likelihood computation is set) if abs((average_output_previous - average_output)/average_output_previous) <= self.convergence_threshold: diff --git a/bob/learn/misc/cpp/MAP_GMMTrainer.cpp b/bob/learn/misc/cpp/MAP_GMMTrainer.cpp index 0645c2c..d20b150 100644 --- a/bob/learn/misc/cpp/MAP_GMMTrainer.cpp +++ b/bob/learn/misc/cpp/MAP_GMMTrainer.cpp @@ -8,8 +8,18 @@ #include <bob.learn.misc/MAP_GMMTrainer.h> #include <bob.core/check.h> -bob::learn::misc::MAP_GMMTrainer::MAP_GMMTrainer(boost::shared_ptr<bob::learn::misc::GMMBaseTrainer> gmm_base_trainer, boost::shared_ptr<bob::learn::misc::GMMMachine> prior_gmm, const bool reynolds_adaptation, const double relevance_factor, const double alpha): - m_gmm_base_trainer(gmm_base_trainer), +bob::learn::misc::MAP_GMMTrainer::MAP_GMMTrainer( + const bool update_means, + const bool update_variances, + const bool update_weights, + const double mean_var_update_responsibilities_threshold, + + const bool reynolds_adaptation, + const double relevance_factor, + const double alpha, + boost::shared_ptr<bob::learn::misc::GMMMachine> prior_gmm): + + m_gmm_base_trainer(update_means, update_variances, update_weights, mean_var_update_responsibilities_threshold), m_prior_gmm(prior_gmm) { m_reynolds_adaptation = reynolds_adaptation; @@ -37,7 +47,7 @@ void bob::learn::misc::MAP_GMMTrainer::initialize(bob::learn::misc::GMMMachine& throw std::runtime_error("MAP_GMMTrainer: Prior GMM distribution has not been set"); // Allocate memory for the sufficient statistics and initialise - m_gmm_base_trainer->initialize(gmm); + m_gmm_base_trainer.initialize(gmm); const size_t n_gaussians = gmm.getNGaussians(); // TODO: check size? @@ -78,13 +88,13 @@ void bob::learn::misc::MAP_GMMTrainer::mStep(bob::learn::misc::GMMMachine& gmm) if (!m_reynolds_adaptation) m_cache_alpha = m_alpha; else - m_cache_alpha = m_gmm_base_trainer->getGMMStats().n(i) / (m_gmm_base_trainer->getGMMStats().n(i) + m_relevance_factor); + m_cache_alpha = m_gmm_base_trainer.getGMMStats().n(i) / (m_gmm_base_trainer.getGMMStats().n(i) + m_relevance_factor); // - Update weights if requested // Equation 11 of Reynolds et al., "Speaker Verification Using Adapted Gaussian Mixture Models", Digital Signal Processing, 2000 - if (m_gmm_base_trainer->getUpdateWeights()) { + if (m_gmm_base_trainer.getUpdateWeights()) { // Calculate the maximum likelihood weights - m_cache_ml_weights = m_gmm_base_trainer->getGMMStats().n / static_cast<double>(m_gmm_base_trainer->getGMMStats().T); //cast req. for linux/32-bits & osx + m_cache_ml_weights = m_gmm_base_trainer.getGMMStats().n / static_cast<double>(m_gmm_base_trainer.getGMMStats().T); //cast req. for linux/32-bits & osx // Get the prior weights const blitz::Array<double,1>& prior_weights = m_prior_gmm->getWeights(); @@ -104,35 +114,35 @@ void bob::learn::misc::MAP_GMMTrainer::mStep(bob::learn::misc::GMMMachine& gmm) // Update GMM parameters // - Update means if requested // Equation 12 of Reynolds et al., "Speaker Verification Using Adapted Gaussian Mixture Models", Digital Signal Processing, 2000 - if (m_gmm_base_trainer->getUpdateMeans()) { + if (m_gmm_base_trainer.getUpdateMeans()) { // Calculate new means for (size_t i=0; i<n_gaussians; ++i) { const blitz::Array<double,1>& prior_means = m_prior_gmm->getGaussian(i)->getMean(); blitz::Array<double,1>& means = gmm.getGaussian(i)->updateMean(); - if (m_gmm_base_trainer->getGMMStats().n(i) < m_gmm_base_trainer->getMeanVarUpdateResponsibilitiesThreshold()) { + if (m_gmm_base_trainer.getGMMStats().n(i) < m_gmm_base_trainer.getMeanVarUpdateResponsibilitiesThreshold()) { means = prior_means; } else { // Use the maximum likelihood means - means = m_cache_alpha(i) * (m_gmm_base_trainer->getGMMStats().sumPx(i,blitz::Range::all()) / m_gmm_base_trainer->getGMMStats().n(i)) + (1-m_cache_alpha(i)) * prior_means; + means = m_cache_alpha(i) * (m_gmm_base_trainer.getGMMStats().sumPx(i,blitz::Range::all()) / m_gmm_base_trainer.getGMMStats().n(i)) + (1-m_cache_alpha(i)) * prior_means; } } } // - Update variance if requested // Equation 13 of Reynolds et al., "Speaker Verification Using Adapted Gaussian Mixture Models", Digital Signal Processing, 2000 - if (m_gmm_base_trainer->getUpdateVariances()) { + if (m_gmm_base_trainer.getUpdateVariances()) { // Calculate new variances (equation 13) for (size_t i=0; i<n_gaussians; ++i) { const blitz::Array<double,1>& prior_means = m_prior_gmm->getGaussian(i)->getMean(); blitz::Array<double,1>& means = gmm.getGaussian(i)->updateMean(); const blitz::Array<double,1>& prior_variances = m_prior_gmm->getGaussian(i)->getVariance(); blitz::Array<double,1>& variances = gmm.getGaussian(i)->updateVariance(); - if (m_gmm_base_trainer->getGMMStats().n(i) < m_gmm_base_trainer->getMeanVarUpdateResponsibilitiesThreshold()) { + if (m_gmm_base_trainer.getGMMStats().n(i) < m_gmm_base_trainer.getMeanVarUpdateResponsibilitiesThreshold()) { variances = (prior_variances + prior_means) - blitz::pow2(means); } else { - variances = m_cache_alpha(i) * m_gmm_base_trainer->getGMMStats().sumPxx(i,blitz::Range::all()) / m_gmm_base_trainer->getGMMStats().n(i) + (1-m_cache_alpha(i)) * (prior_variances + prior_means) - blitz::pow2(means); + variances = m_cache_alpha(i) * m_gmm_base_trainer.getGMMStats().sumPxx(i,blitz::Range::all()) / m_gmm_base_trainer.getGMMStats().n(i) + (1-m_cache_alpha(i)) * (prior_variances + prior_means) - blitz::pow2(means); } gmm.getGaussian(i)->applyVarianceThresholds(); } diff --git a/bob/learn/misc/cpp/ML_GMMTrainer.cpp b/bob/learn/misc/cpp/ML_GMMTrainer.cpp index 84ecc5c..f08fb2f 100644 --- a/bob/learn/misc/cpp/ML_GMMTrainer.cpp +++ b/bob/learn/misc/cpp/ML_GMMTrainer.cpp @@ -8,8 +8,13 @@ #include <bob.learn.misc/ML_GMMTrainer.h> #include <algorithm> -bob::learn::misc::ML_GMMTrainer::ML_GMMTrainer(boost::shared_ptr<bob::learn::misc::GMMBaseTrainer> gmm_base_trainer): - m_gmm_base_trainer(gmm_base_trainer) +bob::learn::misc::ML_GMMTrainer::ML_GMMTrainer( + const bool update_means, + const bool update_variances, + const bool update_weights, + const double mean_var_update_responsibilities_threshold +): + m_gmm_base_trainer(update_means, update_variances, update_weights, mean_var_update_responsibilities_threshold) {} @@ -23,7 +28,7 @@ bob::learn::misc::ML_GMMTrainer::~ML_GMMTrainer() void bob::learn::misc::ML_GMMTrainer::initialize(bob::learn::misc::GMMMachine& gmm) { - m_gmm_base_trainer->initialize(gmm); + m_gmm_base_trainer.initialize(gmm); // Allocate cache size_t n_gaussians = gmm.getNGaussians(); @@ -38,24 +43,24 @@ void bob::learn::misc::ML_GMMTrainer::mStep(bob::learn::misc::GMMMachine& gmm) // - Update weights if requested // Equation 9.26 of Bishop, "Pattern recognition and machine learning", 2006 - if (m_gmm_base_trainer->getUpdateWeights()) { + if (m_gmm_base_trainer.getUpdateWeights()) { blitz::Array<double,1>& weights = gmm.updateWeights(); - weights = m_gmm_base_trainer->getGMMStats().n / static_cast<double>(m_gmm_base_trainer->getGMMStats().T); //cast req. for linux/32-bits & osx + weights = m_gmm_base_trainer.getGMMStats().n / static_cast<double>(m_gmm_base_trainer.getGMMStats().T); //cast req. for linux/32-bits & osx // Recompute the log weights in the cache of the GMMMachine gmm.recomputeLogWeights(); } // Generate a thresholded version of m_ss.n for(size_t i=0; i<n_gaussians; ++i) - m_cache_ss_n_thresholded(i) = std::max(m_gmm_base_trainer->getGMMStats().n(i), m_gmm_base_trainer->getMeanVarUpdateResponsibilitiesThreshold()); + m_cache_ss_n_thresholded(i) = std::max(m_gmm_base_trainer.getGMMStats().n(i), m_gmm_base_trainer.getMeanVarUpdateResponsibilitiesThreshold()); // Update GMM parameters using the sufficient statistics (m_ss) // - Update means if requested // Equation 9.24 of Bishop, "Pattern recognition and machine learning", 2006 - if (m_gmm_base_trainer->getUpdateMeans()) { + if (m_gmm_base_trainer.getUpdateMeans()) { for(size_t i=0; i<n_gaussians; ++i) { blitz::Array<double,1>& means = gmm.getGaussian(i)->updateMean(); - means = m_gmm_base_trainer->getGMMStats().sumPx(i, blitz::Range::all()) / m_cache_ss_n_thresholded(i); + means = m_gmm_base_trainer.getGMMStats().sumPx(i, blitz::Range::all()) / m_cache_ss_n_thresholded(i); } } @@ -64,11 +69,11 @@ void bob::learn::misc::ML_GMMTrainer::mStep(bob::learn::misc::GMMMachine& gmm) // ...but we use the "computational formula for the variance", i.e. // var = 1/n * sum (P(x-mean)(x-mean)) // = 1/n * sum (Pxx) - mean^2 - if (m_gmm_base_trainer->getUpdateVariances()) { + if (m_gmm_base_trainer.getUpdateVariances()) { for(size_t i=0; i<n_gaussians; ++i) { const blitz::Array<double,1>& means = gmm.getGaussian(i)->getMean(); blitz::Array<double,1>& variances = gmm.getGaussian(i)->updateVariance(); - variances = m_gmm_base_trainer->getGMMStats().sumPxx(i, blitz::Range::all()) / m_cache_ss_n_thresholded(i) - blitz::pow2(means); + variances = m_gmm_base_trainer.getGMMStats().sumPxx(i, blitz::Range::all()) / m_cache_ss_n_thresholded(i) - blitz::pow2(means); gmm.getGaussian(i)->applyVarianceThresholds(); } } diff --git a/bob/learn/misc/include/bob.learn.misc/GMMBaseTrainer.h b/bob/learn/misc/include/bob.learn.misc/GMMBaseTrainer.h index 086fdc2..4015055 100644 --- a/bob/learn/misc/include/bob.learn.misc/GMMBaseTrainer.h +++ b/bob/learn/misc/include/bob.learn.misc/GMMBaseTrainer.h @@ -30,9 +30,9 @@ class GMMBaseTrainer * @brief Default constructor */ GMMBaseTrainer(const bool update_means=true, - const bool update_variances=false, const bool update_weights=false, - const double mean_var_update_responsibilities_threshold = - std::numeric_limits<double>::epsilon()); + const bool update_variances=false, + const bool update_weights=false, + const double mean_var_update_responsibilities_threshold = std::numeric_limits<double>::epsilon()); /** * @brief Copy constructor @@ -47,7 +47,7 @@ class GMMBaseTrainer /** * @brief Initialization before the EM steps */ - virtual void initialize(bob::learn::misc::GMMMachine& gmm); + void initialize(bob::learn::misc::GMMMachine& gmm); /** * @brief Calculates and saves statistics across the dataset, @@ -58,14 +58,14 @@ class GMMBaseTrainer * The statistics, m_ss, will be used in the mStep() that follows. * Implements EMTrainer::eStep(double &) */ - virtual void eStep(bob::learn::misc::GMMMachine& gmm, + void eStep(bob::learn::misc::GMMMachine& gmm, const blitz::Array<double,2>& data); /** * @brief Computes the likelihood using current estimates of the latent * variables */ - virtual double computeLikelihood(bob::learn::misc::GMMMachine& gmm); + double computeLikelihood(bob::learn::misc::GMMMachine& gmm); /** diff --git a/bob/learn/misc/include/bob.learn.misc/MAP_GMMTrainer.h b/bob/learn/misc/include/bob.learn.misc/MAP_GMMTrainer.h index 42e4552..c6c7cf7 100644 --- a/bob/learn/misc/include/bob.learn.misc/MAP_GMMTrainer.h +++ b/bob/learn/misc/include/bob.learn.misc/MAP_GMMTrainer.h @@ -26,7 +26,15 @@ class MAP_GMMTrainer /** * @brief Default constructor */ - MAP_GMMTrainer(boost::shared_ptr<bob::learn::misc::GMMBaseTrainer> gmm_base_trainer, boost::shared_ptr<bob::learn::misc::GMMMachine> prior_gmm, const bool reynolds_adaptation=false, const double relevance_factor=4, const double alpha=0.5); + MAP_GMMTrainer( + const bool update_means=true, + const bool update_variances=false, + const bool update_weights=false, + const double mean_var_update_responsibilities_threshold = std::numeric_limits<double>::epsilon(), + const bool reynolds_adaptation=false, + const double relevance_factor=4, + const double alpha=0.5, + boost::shared_ptr<bob::learn::misc::GMMMachine> prior_gmm = 0); /** * @brief Copy constructor @@ -41,7 +49,7 @@ class MAP_GMMTrainer /** * @brief Initialization */ - virtual void initialize(bob::learn::misc::GMMMachine& gmm); + void initialize(bob::learn::misc::GMMMachine& gmm); /** * @brief Assigns from a different MAP_GMMTrainer @@ -71,6 +79,21 @@ class MAP_GMMTrainer */ bool setPriorGMM(boost::shared_ptr<bob::learn::misc::GMMMachine> prior_gmm); + /** + * @brief Calculates and saves statistics across the dataset, + * and saves these as m_ss. Calculates the average + * log likelihood of the observations given the GMM, + * and returns this in average_log_likelihood. + * + * The statistics, m_ss, will be used in the mStep() that follows. + * Implements EMTrainer::eStep(double &) + */ + void eStep(bob::learn::misc::GMMMachine& gmm, + const blitz::Array<double,2>& data){ + m_gmm_base_trainer.eStep(gmm,data); + } + + /** * @brief Performs a maximum a posteriori (MAP) update of the GMM * parameters using the accumulated statistics in m_ss and the @@ -80,13 +103,12 @@ class MAP_GMMTrainer void mStep(bob::learn::misc::GMMMachine& gmm); /** - * @brief Use a Torch3-like adaptation rule rather than Reynolds'one - * In this case, alpha is a configuration variable rather than a function of the zeroth - * order statistics and a relevance factor (should be in range [0,1]) + * @brief Computes the likelihood using current estimates of the latent + * variables */ - //void setT3MAP(const double alpha) { m_T3_adaptation = true; m_T3_alpha = alpha; } - //void unsetT3MAP() { m_T3_adaptation = false; } - + double computeLikelihood(bob::learn::misc::GMMMachine& gmm){ + return m_gmm_base_trainer.computeLikelihood(gmm); + } bool getReynoldsAdaptation() {return m_reynolds_adaptation;} @@ -109,13 +131,6 @@ class MAP_GMMTrainer {m_alpha = alpha;} - boost::shared_ptr<bob::learn::misc::GMMBaseTrainer> getGMMBaseTrainer() - {return m_gmm_base_trainer;} - - void setGMMBaseTrainer(boost::shared_ptr<bob::learn::misc::GMMBaseTrainer> gmm_base_trainer) - {m_gmm_base_trainer = gmm_base_trainer;} - - protected: /** @@ -123,12 +138,10 @@ class MAP_GMMTrainer */ double m_relevance_factor; - /** Base Trainer for the MAP algorithm. Basically implements the e-step */ - boost::shared_ptr<bob::learn::misc::GMMBaseTrainer> m_gmm_base_trainer; - + bob::learn::misc::GMMBaseTrainer m_gmm_base_trainer; /** * The GMM to use as a prior for MAP adaptation. diff --git a/bob/learn/misc/include/bob.learn.misc/ML_GMMTrainer.h b/bob/learn/misc/include/bob.learn.misc/ML_GMMTrainer.h index 09b3db1..13cda74 100644 --- a/bob/learn/misc/include/bob.learn.misc/ML_GMMTrainer.h +++ b/bob/learn/misc/include/bob.learn.misc/ML_GMMTrainer.h @@ -27,7 +27,10 @@ class ML_GMMTrainer{ /** * @brief Default constructor */ - ML_GMMTrainer(boost::shared_ptr<bob::learn::misc::GMMBaseTrainer> gmm_base_trainer); + ML_GMMTrainer(const bool update_means=true, + const bool update_variances=false, + const bool update_weights=false, + const double mean_var_update_responsibilities_threshold = std::numeric_limits<double>::epsilon()); /** * @brief Copy constructor @@ -42,14 +45,37 @@ class ML_GMMTrainer{ /** * @brief Initialisation before the EM steps */ - virtual void initialize(bob::learn::misc::GMMMachine& gmm); + void initialize(bob::learn::misc::GMMMachine& gmm); + + /** + * @brief Calculates and saves statistics across the dataset, + * and saves these as m_ss. Calculates the average + * log likelihood of the observations given the GMM, + * and returns this in average_log_likelihood. + * + * The statistics, m_ss, will be used in the mStep() that follows. + * Implements EMTrainer::eStep(double &) + */ + void eStep(bob::learn::misc::GMMMachine& gmm, + const blitz::Array<double,2>& data){ + m_gmm_base_trainer.eStep(gmm,data); + } /** * @brief Performs a maximum likelihood (ML) update of the GMM parameters * using the accumulated statistics in m_ss * Implements EMTrainer::mStep() */ - virtual void mStep(bob::learn::misc::GMMMachine& gmm); + void mStep(bob::learn::misc::GMMMachine& gmm); + + /** + * @brief Computes the likelihood using current estimates of the latent + * variables + */ + double computeLikelihood(bob::learn::misc::GMMMachine& gmm){ + return m_gmm_base_trainer.computeLikelihood(gmm); + } + /** * @brief Assigns from a different ML_GMMTrainer @@ -73,19 +99,12 @@ class ML_GMMTrainer{ const double a_epsilon=1e-8) const; - boost::shared_ptr<bob::learn::misc::GMMBaseTrainer> getGMMBaseTrainer() - {return m_gmm_base_trainer;} - - void setGMMBaseTrainer(boost::shared_ptr<bob::learn::misc::GMMBaseTrainer> gmm_base_trainer) - {m_gmm_base_trainer = gmm_base_trainer;} - - protected: /** Base Trainer for the MAP algorithm. Basically implements the e-step */ - boost::shared_ptr<bob::learn::misc::GMMBaseTrainer> m_gmm_base_trainer; + bob::learn::misc::GMMBaseTrainer m_gmm_base_trainer; private: diff --git a/bob/learn/misc/main.cpp b/bob/learn/misc/main.cpp index 9e6dc2c..10e1e8a 100644 --- a/bob/learn/misc/main.cpp +++ b/bob/learn/misc/main.cpp @@ -75,7 +75,7 @@ static PyObject* create_module (void) { if (!init_BobLearnMiscGMMMachine(module)) return 0; if (!init_BobLearnMiscKMeansMachine(module)) return 0; if (!init_BobLearnMiscKMeansTrainer(module)) return 0; - if (!init_BobLearnMiscGMMBaseTrainer(module)) return 0; + //if (!init_BobLearnMiscGMMBaseTrainer(module)) return 0; if (!init_BobLearnMiscMLGMMTrainer(module)) return 0; if (!init_BobLearnMiscMAPGMMTrainer(module)) return 0; diff --git a/bob/learn/misc/main.h b/bob/learn/misc/main.h index 42f5f12..5be119c 100644 --- a/bob/learn/misc/main.h +++ b/bob/learn/misc/main.h @@ -26,7 +26,7 @@ #include <bob.learn.misc/KMeansMachine.h> #include <bob.learn.misc/KMeansTrainer.h> -#include <bob.learn.misc/GMMBaseTrainer.h> +//#include <bob.learn.misc/GMMBaseTrainer.h> #include <bob.learn.misc/ML_GMMTrainer.h> #include <bob.learn.misc/MAP_GMMTrainer.h> @@ -145,6 +145,7 @@ int PyBobLearnMiscKMeansTrainer_Check(PyObject* o); // GMMBaseTrainer +/* typedef struct { PyObject_HEAD boost::shared_ptr<bob::learn::misc::GMMBaseTrainer> cxx; @@ -153,7 +154,7 @@ typedef struct { extern PyTypeObject PyBobLearnMiscGMMBaseTrainer_Type; bool init_BobLearnMiscGMMBaseTrainer(PyObject* module); int PyBobLearnMiscGMMBaseTrainer_Check(PyObject* o); - +*/ // ML_GMMTrainer typedef struct { diff --git a/bob/learn/misc/test_em.py b/bob/learn/misc/test_em.py index dffb86a..88070de 100644 --- a/bob/learn/misc/test_em.py +++ b/bob/learn/misc/test_em.py @@ -14,7 +14,7 @@ import bob.io.base from bob.io.base.test_utils import datafile from . import KMeansMachine, GMMMachine, KMeansTrainer, \ - GMMBaseTrainer, ML_GMMTrainer, MAP_GMMTrainer + ML_GMMTrainer, MAP_GMMTrainer #, MAP_GMMTrainer @@ -50,7 +50,7 @@ def test_gmm_ML_1(): ar = bob.io.base.load(datafile("faithful.torch3_f64.hdf5", __name__)) gmm = loadGMM() - ml_gmmtrainer = ML_GMMTrainer(GMMBaseTrainer(True, True, True)) + ml_gmmtrainer = ML_GMMTrainer(True, True, True) ml_gmmtrainer.train(gmm, ar) #config = bob.io.base.HDF5File(datafile('gmm_ML.hdf5", __name__), 'w') @@ -82,7 +82,7 @@ def test_gmm_ML_2(): prior = 0.001 max_iter_gmm = 25 accuracy = 0.00001 - ml_gmmtrainer = ML_GMMTrainer(GMMBaseTrainer(True, True, True, prior), converge_by_likelihood=True) + ml_gmmtrainer = ML_GMMTrainer(True, True, True, prior, converge_by_likelihood=True) ml_gmmtrainer.max_iterations = max_iter_gmm ml_gmmtrainer.convergence_threshold = accuracy @@ -97,8 +97,6 @@ def test_gmm_ML_2(): weightsML_ref = bob.io.base.load(datafile('weightsAfterML.hdf5', __name__)) - print sum(sum(gmm.means - meansML_ref)) - # Compare to current results assert equals(gmm.means, meansML_ref, 3e-3) assert equals(gmm.variances, variancesML_ref, 3e-3) @@ -115,7 +113,7 @@ def test_gmm_MAP_1(): gmm = GMMMachine(bob.io.base.HDF5File(datafile("gmm_ML.hdf5", __name__))) gmmprior = GMMMachine(bob.io.base.HDF5File(datafile("gmm_ML.hdf5", __name__))) - map_gmmtrainer = MAP_GMMTrainer(GMMBaseTrainer(True, False, False),gmmprior, relevance_factor=4.) + map_gmmtrainer = MAP_GMMTrainer(update_means=True, update_variances=False, update_weights=False, prior_gmm=gmmprior, relevance_factor=4.) #map_gmmtrainer.set_prior_gmm(gmmprior) map_gmmtrainer.train(gmm, ar) @@ -142,7 +140,7 @@ def test_gmm_MAP_2(): gmm.variances = variances gmm.weights = weights - map_adapt = MAP_GMMTrainer(GMMBaseTrainer(True, False, False, mean_var_update_responsibilities_threshold=0.),gmm, relevance_factor=4.) + map_adapt = MAP_GMMTrainer(update_means=True, update_variances=False, update_weights=False, mean_var_update_responsibilities_threshold=0.,prior_gmm=gmm, relevance_factor=4.) #map_adapt.set_prior_gmm(gmm) gmm_adapted = GMMMachine(2,50) @@ -186,7 +184,7 @@ def test_gmm_MAP_3(): max_iter_gmm = 1 accuracy = 0.00001 map_factor = 0.5 - map_gmmtrainer = MAP_GMMTrainer(GMMBaseTrainer(True, False, False, prior), prior_gmm, alpha=map_factor) + map_gmmtrainer = MAP_GMMTrainer(update_means=True, update_variances=False, update_weights=False, convergence_threshold=prior, prior_gmm=prior_gmm, alpha=map_factor) map_gmmtrainer.max_iterations = max_iter_gmm map_gmmtrainer.convergence_threshold = accuracy diff --git a/setup.py b/setup.py index aa58b9c..875085d 100644 --- a/setup.py +++ b/setup.py @@ -113,7 +113,7 @@ setup( "bob/learn/misc/gmm_machine.cpp", "bob/learn/misc/kmeans_machine.cpp", "bob/learn/misc/kmeans_trainer.cpp", - "bob/learn/misc/gmm_base_trainer.cpp", + "bob/learn/misc/ML_gmm_trainer.cpp", "bob/learn/misc/MAP_gmm_trainer.cpp", -- GitLab