From 1927b88be168d737ad5d94cc21be6005193ad5b0 Mon Sep 17 00:00:00 2001 From: Tiago Freitas Pereira <tiagofrepereira@gmail.com> Date: Fri, 23 Jan 2015 16:40:21 +0100 Subject: [PATCH] Mapped ML_GMMTrainer --- bob/learn/misc/ML_gmm_trainer.cpp | 304 ++++++++++++++++++ bob/learn/misc/__ML_gmm_trainer__.py | 81 +++++ bob/learn/misc/cpp/ML_GMMTrainer.cpp | 58 ++-- bob/learn/misc/gmm_base_trainer.cpp | 103 +++++- .../include/bob.learn.misc/GMMBaseTrainer.h | 23 ++ .../include/bob.learn.misc/ML_GMMTrainer.h | 30 +- bob/learn/misc/main.cpp | 4 +- bob/learn/misc/main.h | 14 + setup.py | 5 +- 9 files changed, 571 insertions(+), 51 deletions(-) create mode 100644 bob/learn/misc/ML_gmm_trainer.cpp create mode 100644 bob/learn/misc/__ML_gmm_trainer__.py diff --git a/bob/learn/misc/ML_gmm_trainer.cpp b/bob/learn/misc/ML_gmm_trainer.cpp new file mode 100644 index 0000000..bddc9de --- /dev/null +++ b/bob/learn/misc/ML_gmm_trainer.cpp @@ -0,0 +1,304 @@ +/** + * @author Tiago de Freitas Pereira <tiago.pereira@idiap.ch> + * @date Web 22 Jan 16:45:00 2015 + * + * @brief Python API for bob::learn::em + * + * Copyright (C) 2011-2014 Idiap Research Institute, Martigny, Switzerland + */ + +#include "main.h" + +/******************************************************************/ +/************ Constructor Section *********************************/ +/******************************************************************/ + +static inline bool f(PyObject* o){return o != 0 && PyObject_IsTrue(o) > 0;} /* converts PyObject to bool and returns false if object is NULL */ + +static auto ML_GMMTrainer_doc = bob::extension::ClassDoc( + BOB_EXT_MODULE_PREFIX ".ML_GMMTrainer", + "This class implements the maximum likelihood M-step of the expectation-maximisation algorithm for a GMM Machine." +).add_constructor( + bob::extension::FunctionDoc( + "__init__", + "Creates a ML_GMMTrainer", + "", + true + ) + .add_prototype("gmm_base_trainer","") + .add_prototype("other","") + .add_prototype("","") + + .add_parameter("gmm_base_trainer", ":py:class:`bob.learn.misc.GMMBaseTrainer`", "A set GMMBaseTrainer object.") + .add_parameter("other", ":py:class:`bob.learn.misc.ML_GMMTrainer`", "A ML_GMMTrainer object to be copied.") +); + + +static int PyBobLearnMiscMLGMMTrainer_init_copy(PyBobLearnMiscMLGMMTrainerObject* self, PyObject* args, PyObject* kwargs) { + + char** kwlist = ML_GMMTrainer_doc.kwlist(1); + PyBobLearnMiscMLGMMTrainerObject* o; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!", kwlist, &PyBobLearnMiscMLGMMTrainer_Type, &o)){ + ML_GMMTrainer_doc.print_usage(); + return -1; + } + + self->cxx.reset(new bob::learn::misc::ML_GMMTrainer(*o->cxx)); + return 0; +} + + +static int PyBobLearnMiscMLGMMTrainer_init_base_trainer(PyBobLearnMiscMLGMMTrainerObject* self, PyObject* args, PyObject* kwargs) { + + char** kwlist = ML_GMMTrainer_doc.kwlist(1); + PyBobLearnMiscGMMBaseTrainerObject* o; + + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!", kwlist, &PyBobLearnMiscGMMBaseTrainer_Type, &o)){ + ML_GMMTrainer_doc.print_usage(); + return -1; + } + + self->cxx.reset(new bob::learn::misc::ML_GMMTrainer(o->cxx)); + return 0; +} + + + +static int PyBobLearnMiscMLGMMTrainer_init(PyBobLearnMiscMLGMMTrainerObject* self, PyObject* args, PyObject* kwargs) { + BOB_TRY + + int nargs = (args?PyTuple_Size(args):0) + (kwargs?PyDict_Size(kwargs):0); + + if (nargs==1){ //default initializer () + //Reading the input argument + PyObject* arg = 0; + if (PyTuple_Size(args)) + arg = PyTuple_GET_ITEM(args, 0); + else { + PyObject* tmp = PyDict_Values(kwargs); + auto tmp_ = make_safe(tmp); + arg = PyList_GET_ITEM(tmp, 0); + } + + // If the constructor input is GMMBaseTrainer object + if (PyBobLearnMiscGMMBaseTrainer_Check(arg)) + return PyBobLearnMiscMLGMMTrainer_init_base_trainer(self, args, kwargs); + else + return PyBobLearnMiscMLGMMTrainer_init_copy(self, args, kwargs); + } + else{ + PyErr_Format(PyExc_RuntimeError, "number of arguments mismatch - %s requires 1, but you provided %d (see help)", Py_TYPE(self)->tp_name, nargs); + ML_GMMTrainer_doc.print_usage(); + return -1; + } + + BOB_CATCH_MEMBER("cannot create GMMBaseTrainer_init_bool", 0) + return 0; +} + + +static void PyBobLearnMiscMLGMMTrainer_delete(PyBobLearnMiscMLGMMTrainerObject* self) { + self->cxx.reset(); + Py_TYPE(self)->tp_free((PyObject*)self); +} + + +int PyBobLearnMiscMLGMMTrainer_Check(PyObject* o) { + return PyObject_IsInstance(o, reinterpret_cast<PyObject*>(&PyBobLearnMiscMLGMMTrainer_Type)); +} + + +static PyObject* PyBobLearnMiscMLGMMTrainer_RichCompare(PyBobLearnMiscMLGMMTrainerObject* self, PyObject* other, int op) { + BOB_TRY + + if (!PyBobLearnMiscMLGMMTrainer_Check(other)) { + PyErr_Format(PyExc_TypeError, "cannot compare `%s' with `%s'", Py_TYPE(self)->tp_name, Py_TYPE(other)->tp_name); + return 0; + } + auto other_ = reinterpret_cast<PyBobLearnMiscMLGMMTrainerObject*>(other); + switch (op) { + case Py_EQ: + if (*self->cxx==*other_->cxx) Py_RETURN_TRUE; else Py_RETURN_FALSE; + case Py_NE: + if (*self->cxx==*other_->cxx) Py_RETURN_FALSE; else Py_RETURN_TRUE; + default: + Py_INCREF(Py_NotImplemented); + return Py_NotImplemented; + } + BOB_CATCH_MEMBER("cannot compare ML_GMMTrainer objects", 0) +} + + +/******************************************************************/ +/************ Variables Section ***********************************/ +/******************************************************************/ + + +/***** gmm_base_trainer *****/ +static auto gmm_base_trainer = bob::extension::VariableDoc( + "gmm_base_trainer", + ":py:class:`bob.learn.misc.GMMBaseTrainer`", + "This class that implements the E-step of the expectation-maximisation algorithm.", + "" +); +PyObject* PyBobLearnMiscMLGMMTrainer_getGMMBaseTrainer(PyBobLearnMiscMLGMMTrainerObject* self, void*){ + BOB_TRY + + boost::shared_ptr<bob::learn::misc::GMMBaseTrainer> gmm_base_trainer = self->cxx->getGMMBaseTrainer(); + + //Allocating the correspondent python object + PyBobLearnMiscGMMBaseTrainerObject* retval = + (PyBobLearnMiscGMMBaseTrainerObject*)PyBobLearnMiscGMMBaseTrainer_Type.tp_alloc(&PyBobLearnMiscGMMBaseTrainer_Type, 0); + + retval->cxx = gmm_base_trainer; + + return Py_BuildValue("O",retval); + BOB_CATCH_MEMBER("GMMBaseTrainer could not be read", 0) +} +int PyBobLearnMiscMLGMMTrainer_setGMMBaseTrainer(PyBobLearnMiscMLGMMTrainerObject* self, PyObject* value, void*){ + BOB_TRY + + if (!PyBobLearnMiscGMMBaseTrainer_Check(value)){ + PyErr_Format(PyExc_RuntimeError, "%s %s expects a :py:class:`bob.learn.misc.GMMBaseTrainer`", Py_TYPE(self)->tp_name, gmm_base_trainer.name()); + return -1; + } + + PyBobLearnMiscGMMBaseTrainerObject* gmm_base_trainer = 0; + PyArg_Parse(value, "O!", &PyBobLearnMiscGMMBaseTrainer_Type,&gmm_base_trainer); + + self->cxx->setGMMBaseTrainer(gmm_base_trainer->cxx); + + return 0; + BOB_CATCH_MEMBER("gmm_base_trainer could not be set", -1) +} + + +static PyGetSetDef PyBobLearnMiscMLGMMTrainer_getseters[] = { + { + gmm_base_trainer.name(), + (getter)PyBobLearnMiscMLGMMTrainer_getGMMBaseTrainer, + (setter)PyBobLearnMiscMLGMMTrainer_setGMMBaseTrainer, + gmm_base_trainer.doc(), + 0 + }, + {0} // Sentinel +}; + + +/******************************************************************/ +/************ Functions Section ***********************************/ +/******************************************************************/ + +/*** initialize ***/ +static auto initialize = bob::extension::FunctionDoc( + "initialize", + "Initialization before the EM steps", + "", + true +) +.add_prototype("gmm_machine") +.add_parameter("gmm_machine", ":py:class:`bob.learn.misc.GMMMachine`", "GMMMachine Object"); +static PyObject* PyBobLearnMiscMLGMMTrainer_initialize(PyBobLearnMiscMLGMMTrainerObject* self, PyObject* args, PyObject* kwargs) { + BOB_TRY + + /* Parses input arguments in a single shot */ + char** kwlist = initialize.kwlist(0); + + PyBobLearnMiscGMMMachineObject* gmm_machine = 0; + + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!", kwlist, &PyBobLearnMiscGMMMachine_Type, &gmm_machine)){ + PyErr_Format(PyExc_RuntimeError, "%s.%s. Was not possible to read :py:class:`bob.learn.misc.GMMMachine`", Py_TYPE(self)->tp_name, initialize.name()); + Py_RETURN_NONE; + } + self->cxx->initialize(*gmm_machine->cxx); + BOB_CATCH_MEMBER("cannot perform the initialize method", 0) + + Py_RETURN_NONE; +} + + + +/*** mStep ***/ +static auto mStep = bob::extension::FunctionDoc( + "mStep", + "Performs a maximum likelihood (ML) update of the GMM parameters" + "using the accumulated statistics in :py:class:`bob.learn.misc.GMMBaseTrainer.m_ss`", + + "See Section 9.2.2 of Bishop, \"Pattern recognition and machine learning\", 2006", + + true +) +.add_prototype("gmm_machine") +.add_parameter("gmm_machine", ":py:class:`bob.learn.misc.GMMMachine`", "GMMMachine Object"); +static PyObject* PyBobLearnMiscMLGMMTrainer_mStep(PyBobLearnMiscMLGMMTrainerObject* self, PyObject* args, PyObject* kwargs) { + BOB_TRY + + /* Parses input arguments in a single shot */ + char** kwlist = mStep.kwlist(0); + + PyBobLearnMiscGMMMachineObject* gmm_machine; + + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!", kwlist, &PyBobLearnMiscGMMMachine_Type, &gmm_machine)) Py_RETURN_NONE; + + self->cxx->mStep(*gmm_machine->cxx); + + BOB_CATCH_MEMBER("cannot perform the mStep method", 0) + + Py_RETURN_NONE; +} + + + +static PyMethodDef PyBobLearnMiscMLGMMTrainer_methods[] = { + { + initialize.name(), + (PyCFunction)PyBobLearnMiscMLGMMTrainer_initialize, + METH_VARARGS|METH_KEYWORDS, + initialize.doc() + }, + { + mStep.name(), + (PyCFunction)PyBobLearnMiscMLGMMTrainer_mStep, + METH_VARARGS|METH_KEYWORDS, + mStep.doc() + }, + {0} /* Sentinel */ +}; + + +/******************************************************************/ +/************ Module Section **************************************/ +/******************************************************************/ + +// Define the Gaussian type struct; will be initialized later +PyTypeObject PyBobLearnMiscMLGMMTrainer_Type = { + PyVarObject_HEAD_INIT(0,0) + 0 +}; + +bool init_BobLearnMiscMLGMMTrainer(PyObject* module) +{ + // initialize the type struct + PyBobLearnMiscMLGMMTrainer_Type.tp_name = ML_GMMTrainer_doc.name(); + PyBobLearnMiscMLGMMTrainer_Type.tp_basicsize = sizeof(PyBobLearnMiscMLGMMTrainerObject); + PyBobLearnMiscMLGMMTrainer_Type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;//Enable the class inheritance + PyBobLearnMiscMLGMMTrainer_Type.tp_doc = ML_GMMTrainer_doc.doc(); + + // set the functions + PyBobLearnMiscMLGMMTrainer_Type.tp_new = PyType_GenericNew; + PyBobLearnMiscMLGMMTrainer_Type.tp_init = reinterpret_cast<initproc>(PyBobLearnMiscMLGMMTrainer_init); + PyBobLearnMiscMLGMMTrainer_Type.tp_dealloc = reinterpret_cast<destructor>(PyBobLearnMiscMLGMMTrainer_delete); + PyBobLearnMiscMLGMMTrainer_Type.tp_richcompare = reinterpret_cast<richcmpfunc>(PyBobLearnMiscMLGMMTrainer_RichCompare); + PyBobLearnMiscMLGMMTrainer_Type.tp_methods = PyBobLearnMiscMLGMMTrainer_methods; + PyBobLearnMiscMLGMMTrainer_Type.tp_getset = PyBobLearnMiscMLGMMTrainer_getseters; + //PyBobLearnMiscMLGMMTrainer_Type.tp_call = reinterpret_cast<ternaryfunc>(PyBobLearnMiscMLGMMTrainer_compute_likelihood); + + + // check that everything is fine + if (PyType_Ready(&PyBobLearnMiscMLGMMTrainer_Type) < 0) return false; + + // add the type to the module + Py_INCREF(&PyBobLearnMiscMLGMMTrainer_Type); + return PyModule_AddObject(module, "_ML_GMMTrainer", (PyObject*)&PyBobLearnMiscMLGMMTrainer_Type) >= 0; +} + diff --git a/bob/learn/misc/__ML_gmm_trainer__.py b/bob/learn/misc/__ML_gmm_trainer__.py new file mode 100644 index 0000000..53c5a75 --- /dev/null +++ b/bob/learn/misc/__ML_gmm_trainer__.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python +# vim: set fileencoding=utf-8 : +# Tiago de Freitas Pereira <tiago.pereira@idiap.ch> +# Mon Jan 22 18:29:10 2015 +# +# Copyright (C) 2011-2015 Idiap Research Institute, Martigny, Switzerland + +from ._library import _ML_GMMTrainer +import numpy + +# define the class +class ML_GMMTrainer(_ML_GMMTrainer): + + def __init__(self, gmm_base_trainer, convergence_threshold=0.001, max_iterations=10, converge_by_likelihood=True): + """ + :py:class:bob.learn.misc.ML_GMMTrainer constructor + + Keyword Parameters: + gmm_base_trainer + The base trainer (:py:class:`bob.learn.misc.GMMBaseTrainer` + convergence_threshold + Convergence threshold + max_iterations + Number of maximum iterations + converge_by_likelihood + Tells whether we compute log_likelihood as a convergence criteria, or not + + """ + + _ML_GMMTrainer.__init__(self, gmm_base_trainer) + self.convergence_threshold = convergence_threshold + self.max_iterations = max_iterations + self.converge_by_likelihood = converge_by_likelihood + + + def train(self, gmm_machine, data): + """ + Train the :py:class:bob.learn.misc.GMMMachine using data + + Keyword Parameters: + gmm_machine + The :py:class:bob.learn.misc.GMMMachine class + data + The data to be trained + """ + + #Initialization + self.initialize(gmm_machine); + + #Do the Expectation-Maximization algorithm + average_output_previous = 0 + average_output = -numpy.inf; + + + #eStep + self.gmm_base_trainer.eStep(gmm_machine, data); + + if(self.converge_by_likelihood): + average_output = self.gmm_base_trainer.compute_likelihood(gmm_machine); + + for i in range(self.max_iterations): + #saves average output from last iteration + average_output_previous = average_output; + + #mStep + self.mStep(gmm_machine); + + #eStep + self.gmm_base_trainer.eStep(gmm_machine, data); + + #Computes log likelihood if required + if(self.converge_by_likelihood): + average_output = self.gmm_base_trainer.compute_likelihood(gmm_machine); + + #Terminates if converged (and likelihood computation is set) + if abs((average_output_previous - average_output)/average_output_previous) <= self.convergence_threshold: + break + + +# copy the documentation from the base class +__doc__ = _ML_GMMTrainer.__doc__ diff --git a/bob/learn/misc/cpp/ML_GMMTrainer.cpp b/bob/learn/misc/cpp/ML_GMMTrainer.cpp index ec949e0..84ecc5c 100644 --- a/bob/learn/misc/cpp/ML_GMMTrainer.cpp +++ b/bob/learn/misc/cpp/ML_GMMTrainer.cpp @@ -8,59 +8,54 @@ #include <bob.learn.misc/ML_GMMTrainer.h> #include <algorithm> -bob::learn::misc::ML_GMMTrainer::ML_GMMTrainer(const bool update_means, - const bool update_variances, const bool update_weights, - const double mean_var_update_responsibilities_threshold): - bob::learn::misc::GMMTrainer(update_means, update_variances, update_weights, - mean_var_update_responsibilities_threshold) -{ -} +bob::learn::misc::ML_GMMTrainer::ML_GMMTrainer(boost::shared_ptr<bob::learn::misc::GMMBaseTrainer> gmm_base_trainer): + m_gmm_base_trainer(gmm_base_trainer) +{} + + bob::learn::misc::ML_GMMTrainer::ML_GMMTrainer(const bob::learn::misc::ML_GMMTrainer& b): - bob::learn::misc::GMMTrainer(b) -{ -} + m_gmm_base_trainer(b.m_gmm_base_trainer) +{} bob::learn::misc::ML_GMMTrainer::~ML_GMMTrainer() -{ -} +{} -void bob::learn::misc::ML_GMMTrainer::initialize(bob::learn::misc::GMMMachine& gmm, - const blitz::Array<double,2>& data) +void bob::learn::misc::ML_GMMTrainer::initialize(bob::learn::misc::GMMMachine& gmm) { - bob::learn::misc::GMMTrainer::initialize(gmm, data); + m_gmm_base_trainer->initialize(gmm); + // Allocate cache size_t n_gaussians = gmm.getNGaussians(); m_cache_ss_n_thresholded.resize(n_gaussians); } -void bob::learn::misc::ML_GMMTrainer::mStep(bob::learn::misc::GMMMachine& gmm, - const blitz::Array<double,2>& data) +void bob::learn::misc::ML_GMMTrainer::mStep(bob::learn::misc::GMMMachine& gmm) { // Read options and variables const size_t n_gaussians = gmm.getNGaussians(); // - Update weights if requested // Equation 9.26 of Bishop, "Pattern recognition and machine learning", 2006 - if (m_update_weights) { + if (m_gmm_base_trainer->getUpdateWeights()) { blitz::Array<double,1>& weights = gmm.updateWeights(); - weights = m_ss.n / static_cast<double>(m_ss.T); //cast req. for linux/32-bits & osx + weights = m_gmm_base_trainer->getGMMStats().n / static_cast<double>(m_gmm_base_trainer->getGMMStats().T); //cast req. for linux/32-bits & osx // Recompute the log weights in the cache of the GMMMachine gmm.recomputeLogWeights(); } // Generate a thresholded version of m_ss.n for(size_t i=0; i<n_gaussians; ++i) - m_cache_ss_n_thresholded(i) = std::max(m_ss.n(i), m_mean_var_update_responsibilities_threshold); + m_cache_ss_n_thresholded(i) = std::max(m_gmm_base_trainer->getGMMStats().n(i), m_gmm_base_trainer->getMeanVarUpdateResponsibilitiesThreshold()); // Update GMM parameters using the sufficient statistics (m_ss) // - Update means if requested // Equation 9.24 of Bishop, "Pattern recognition and machine learning", 2006 - if (m_update_means) { + if (m_gmm_base_trainer->getUpdateMeans()) { for(size_t i=0; i<n_gaussians; ++i) { - blitz::Array<double,1>& means = gmm.updateGaussian(i)->updateMean(); - means = m_ss.sumPx(i, blitz::Range::all()) / m_cache_ss_n_thresholded(i); + blitz::Array<double,1>& means = gmm.getGaussian(i)->updateMean(); + means = m_gmm_base_trainer->getGMMStats().sumPx(i, blitz::Range::all()) / m_cache_ss_n_thresholded(i); } } @@ -69,12 +64,12 @@ void bob::learn::misc::ML_GMMTrainer::mStep(bob::learn::misc::GMMMachine& gmm, // ...but we use the "computational formula for the variance", i.e. // var = 1/n * sum (P(x-mean)(x-mean)) // = 1/n * sum (Pxx) - mean^2 - if (m_update_variances) { + if (m_gmm_base_trainer->getUpdateVariances()) { for(size_t i=0; i<n_gaussians; ++i) { const blitz::Array<double,1>& means = gmm.getGaussian(i)->getMean(); - blitz::Array<double,1>& variances = gmm.updateGaussian(i)->updateVariance(); - variances = m_ss.sumPxx(i, blitz::Range::all()) / m_cache_ss_n_thresholded(i) - blitz::pow2(means); - gmm.updateGaussian(i)->applyVarianceThresholds(); + blitz::Array<double,1>& variances = gmm.getGaussian(i)->updateVariance(); + variances = m_gmm_base_trainer->getGMMStats().sumPxx(i, blitz::Range::all()) / m_cache_ss_n_thresholded(i) - blitz::pow2(means); + gmm.getGaussian(i)->applyVarianceThresholds(); } } } @@ -84,7 +79,7 @@ bob::learn::misc::ML_GMMTrainer& bob::learn::misc::ML_GMMTrainer::operator= { if (this != &other) { - bob::learn::misc::GMMTrainer::operator=(other); + m_gmm_base_trainer = other.m_gmm_base_trainer; m_cache_ss_n_thresholded.resize(other.m_cache_ss_n_thresholded.extent(0)); } return *this; @@ -93,7 +88,7 @@ bob::learn::misc::ML_GMMTrainer& bob::learn::misc::ML_GMMTrainer::operator= bool bob::learn::misc::ML_GMMTrainer::operator== (const bob::learn::misc::ML_GMMTrainer &other) const { - return bob::learn::misc::GMMTrainer::operator==(other); + return m_gmm_base_trainer == other.m_gmm_base_trainer; } bool bob::learn::misc::ML_GMMTrainer::operator!= @@ -102,10 +97,11 @@ bool bob::learn::misc::ML_GMMTrainer::operator!= return !(this->operator==(other)); } +/* bool bob::learn::misc::ML_GMMTrainer::is_similar_to (const bob::learn::misc::ML_GMMTrainer &other, const double r_epsilon, const double a_epsilon) const { - return bob::learn::misc::GMMTrainer::is_similar_to(other, r_epsilon, a_epsilon); + return m_gmm_base_trainer.is_similar_to(other, r_epsilon, a_epsilon); } - +*/ diff --git a/bob/learn/misc/gmm_base_trainer.cpp b/bob/learn/misc/gmm_base_trainer.cpp index 5f133af..0308f16 100644 --- a/bob/learn/misc/gmm_base_trainer.cpp +++ b/bob/learn/misc/gmm_base_trainer.cpp @@ -8,6 +8,7 @@ */ #include "main.h" +#include <boost/make_shared.hpp> /******************************************************************/ /************ Constructor Section *********************************/ @@ -59,7 +60,7 @@ static int PyBobLearnMiscGMMBaseTrainer_init_bool(PyBobLearnMiscGMMBaseTrainerOb PyObject* update_means = 0; PyObject* update_variances = 0; PyObject* update_weights = 0; - double mean_var_update_responsibilities_threshold = 0; //std::numeric_limits<double>::epsilon(); + double mean_var_update_responsibilities_threshold = std::numeric_limits<double>::epsilon(); if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!|O!O!d", kwlist, &PyBool_Type, &update_means, &PyBool_Type, &update_variances, &PyBool_Type, &update_weights, &mean_var_update_responsibilities_threshold)){ @@ -150,10 +151,8 @@ static auto gmm_stats = bob::extension::VariableDoc( PyObject* PyBobLearnMiscGMMBaseTrainer_getGMMStats(PyBobLearnMiscGMMBaseTrainerObject* self, void*){ BOB_TRY - //bob::learn::misc::GMMStats stats = self->cxx->getGMMStats(); - //boost::shared_ptr<bob::learn::misc::GMMStats> stats_shared = boost::make_shared(stats); - boost::shared_ptr<bob::learn::misc::GMMStats> stats_shared = 0; - + bob::learn::misc::GMMStats stats = self->cxx->getGMMStats(); + boost::shared_ptr<bob::learn::misc::GMMStats> stats_shared = boost::make_shared<bob::learn::misc::GMMStats>(stats); //Allocating the correspondent python object PyBobLearnMiscGMMStatsObject* retval = @@ -164,6 +163,7 @@ PyObject* PyBobLearnMiscGMMBaseTrainer_getGMMStats(PyBobLearnMiscGMMBaseTrainerO return Py_BuildValue("O",retval); BOB_CATCH_MEMBER("GMMStats could not be read", 0) } +/* int PyBobLearnMiscGMMBaseTrainer_setGMMStats(PyBobLearnMiscGMMBaseTrainerObject* self, PyObject* value, void*){ BOB_TRY @@ -180,17 +180,106 @@ int PyBobLearnMiscGMMBaseTrainer_setGMMStats(PyBobLearnMiscGMMBaseTrainerObject* return 0; BOB_CATCH_MEMBER("gmm_stats could not be set", -1) } +*/ + + +/***** update_means *****/ +static auto update_means = bob::extension::VariableDoc( + "update_means", + "bool", + "Update means on each iteration", + "" +); +PyObject* PyBobLearnMiscGMMBaseTrainer_getUpdateMeans(PyBobLearnMiscGMMBaseTrainerObject* self, void*){ + BOB_TRY + return Py_BuildValue("O",self->cxx->getUpdateMeans()?Py_True:Py_False); + BOB_CATCH_MEMBER("update_means could not be read", 0) +} + +/***** update_variances *****/ +static auto update_variances = bob::extension::VariableDoc( + "update_variances", + "bool", + "Update variances on each iteration", + "" +); +PyObject* PyBobLearnMiscGMMBaseTrainer_getUpdateVariances(PyBobLearnMiscGMMBaseTrainerObject* self, void*){ + BOB_TRY + return Py_BuildValue("O",self->cxx->getUpdateVariances()?Py_True:Py_False); + BOB_CATCH_MEMBER("update_variances could not be read", 0) +} + + +/***** update_weights *****/ +static auto update_weights = bob::extension::VariableDoc( + "update_weights", + "bool", + "Update weights on each iteration", + "" +); +PyObject* PyBobLearnMiscGMMBaseTrainer_getUpdateWeights(PyBobLearnMiscGMMBaseTrainerObject* self, void*){ + BOB_TRY + return Py_BuildValue("O",self->cxx->getUpdateWeights()?Py_True:Py_False); + BOB_CATCH_MEMBER("update_weights could not be read", 0) +} + + + +/***** mean_var_update_responsibilities_threshold *****/ +static auto mean_var_update_responsibilities_threshold = bob::extension::VariableDoc( + "mean_var_update_responsibilities_threshold", + "bool", + "Threshold over the responsibilities of the Gaussians" + "Equations 9.24, 9.25 of Bishop, \"Pattern recognition and machine learning\", 2006" + "require a division by the responsibilities, which might be equal to zero" + "because of numerical issue. This threshold is used to avoid such divisions.", + "" +); +PyObject* PyBobLearnMiscGMMBaseTrainer_getMeanVarUpdateResponsibilitiesThreshold(PyBobLearnMiscGMMBaseTrainerObject* self, void*){ + BOB_TRY + return Py_BuildValue("d",self->cxx->getMeanVarUpdateResponsibilitiesThreshold()); + BOB_CATCH_MEMBER("update_weights could not be read", 0) +} + static PyGetSetDef PyBobLearnMiscGMMBaseTrainer_getseters[] = { + { + update_means.name(), + (getter)PyBobLearnMiscGMMBaseTrainer_getUpdateMeans, + 0, + update_means.doc(), + 0 + }, + { + update_variances.name(), + (getter)PyBobLearnMiscGMMBaseTrainer_getUpdateVariances, + 0, + update_variances.doc(), + 0 + }, + { + update_weights.name(), + (getter)PyBobLearnMiscGMMBaseTrainer_getUpdateWeights, + 0, + update_weights.doc(), + 0 + }, + { + mean_var_update_responsibilities_threshold.name(), + (getter)PyBobLearnMiscGMMBaseTrainer_getMeanVarUpdateResponsibilitiesThreshold, + 0, + mean_var_update_responsibilities_threshold.doc(), + 0 + }, { gmm_stats.name(), (getter)PyBobLearnMiscGMMBaseTrainer_getGMMStats, - (setter)PyBobLearnMiscGMMBaseTrainer_setGMMStats, + 0, //(setter)PyBobLearnMiscGMMBaseTrainer_setGMMStats, gmm_stats.doc(), 0 - }, + }, {0} // Sentinel }; diff --git a/bob/learn/misc/include/bob.learn.misc/GMMBaseTrainer.h b/bob/learn/misc/include/bob.learn.misc/GMMBaseTrainer.h index 1cb0788..086fdc2 100644 --- a/bob/learn/misc/include/bob.learn.misc/GMMBaseTrainer.h +++ b/bob/learn/misc/include/bob.learn.misc/GMMBaseTrainer.h @@ -101,14 +101,37 @@ class GMMBaseTrainer * E-step */ void setGMMStats(const bob::learn::misc::GMMStats& stats); + + /** + * update means on each iteration + */ + bool getUpdateMeans() + {return m_update_means;} + + /** + * update variances on each iteration + */ + bool getUpdateVariances() + {return m_update_variances;} + + + bool getUpdateWeights() + {return m_update_weights;} + + + double getMeanVarUpdateResponsibilitiesThreshold() + {return m_mean_var_update_responsibilities_threshold;} + private: + /** * These are the sufficient statistics, calculated during the * E-step and used during the M-step */ bob::learn::misc::GMMStats m_ss; + /** * update means on each iteration */ diff --git a/bob/learn/misc/include/bob.learn.misc/ML_GMMTrainer.h b/bob/learn/misc/include/bob.learn.misc/ML_GMMTrainer.h index 0522f27..09b3db1 100644 --- a/bob/learn/misc/include/bob.learn.misc/ML_GMMTrainer.h +++ b/bob/learn/misc/include/bob.learn.misc/ML_GMMTrainer.h @@ -11,7 +11,7 @@ #ifndef BOB_LEARN_MISC_ML_GMMTRAINER_H #define BOB_LEARN_MISC_ML_GMMTRAINER_H -#include <bob.learn.misc/GMMTrainer.h> +#include <bob.learn.misc/GMMBaseTrainer.h> #include <limits> namespace bob { namespace learn { namespace misc { @@ -22,15 +22,12 @@ namespace bob { namespace learn { namespace misc { * @details See Section 9.2.2 of Bishop, * "Pattern recognition and machine learning", 2006 */ -class ML_GMMTrainer: public GMMTrainer { +class ML_GMMTrainer{ public: /** * @brief Default constructor */ - ML_GMMTrainer(const bool update_means=true, - const bool update_variances=false, const bool update_weights=false, - const double mean_var_update_responsibilities_threshold = - std::numeric_limits<double>::epsilon()); + ML_GMMTrainer(boost::shared_ptr<bob::learn::misc::GMMBaseTrainer> gmm_base_trainer); /** * @brief Copy constructor @@ -45,16 +42,14 @@ class ML_GMMTrainer: public GMMTrainer { /** * @brief Initialisation before the EM steps */ - virtual void initialize(bob::learn::misc::GMMMachine& gmm, - const blitz::Array<double,2>& data); + virtual void initialize(bob::learn::misc::GMMMachine& gmm); /** * @brief Performs a maximum likelihood (ML) update of the GMM parameters * using the accumulated statistics in m_ss * Implements EMTrainer::mStep() */ - virtual void mStep(bob::learn::misc::GMMMachine& gmm, - const blitz::Array<double,2>& data); + virtual void mStep(bob::learn::misc::GMMMachine& gmm); /** * @brief Assigns from a different ML_GMMTrainer @@ -76,6 +71,21 @@ class ML_GMMTrainer: public GMMTrainer { */ bool is_similar_to(const ML_GMMTrainer& b, const double r_epsilon=1e-5, const double a_epsilon=1e-8) const; + + + boost::shared_ptr<bob::learn::misc::GMMBaseTrainer> getGMMBaseTrainer() + {return m_gmm_base_trainer;} + + void setGMMBaseTrainer(boost::shared_ptr<bob::learn::misc::GMMBaseTrainer> gmm_base_trainer) + {m_gmm_base_trainer = gmm_base_trainer;} + + + protected: + + /** + Base Trainer for the MAP algorithm. Basically implements the e-step + */ + boost::shared_ptr<bob::learn::misc::GMMBaseTrainer> m_gmm_base_trainer; private: diff --git a/bob/learn/misc/main.cpp b/bob/learn/misc/main.cpp index 0431832..404ff64 100644 --- a/bob/learn/misc/main.cpp +++ b/bob/learn/misc/main.cpp @@ -45,7 +45,9 @@ static PyObject* create_module (void) { if (!init_BobLearnMiscGMMMachine(module)) return 0; if (!init_BobLearnMiscKMeansMachine(module)) return 0; if (!init_BobLearnMiscKMeansTrainer(module)) return 0; - if (!init_BobLearnMiscGMMBaseTrainer(module)) return 0; + if (!init_BobLearnMiscGMMBaseTrainer(module)) return 0; + if (!init_BobLearnMiscMLGMMTrainer(module)) return 0; + static void* PyBobLearnMisc_API[PyBobLearnMisc_API_pointers]; diff --git a/bob/learn/misc/main.h b/bob/learn/misc/main.h index cffea0e..6dd9bca 100644 --- a/bob/learn/misc/main.h +++ b/bob/learn/misc/main.h @@ -24,6 +24,7 @@ #include <bob.learn.misc/KMeansTrainer.h> #include <bob.learn.misc/GMMBaseTrainer.h> +#include <bob.learn.misc/ML_GMMTrainer.h> #if PY_VERSION_HEX >= 0x03000000 @@ -131,4 +132,17 @@ int PyBobLearnMiscGMMBaseTrainer_Check(PyObject* o); +// ML_GMMTrainer +typedef struct { + PyObject_HEAD + boost::shared_ptr<bob::learn::misc::ML_GMMTrainer> cxx; +} PyBobLearnMiscMLGMMTrainerObject; + +extern PyTypeObject PyBobLearnMiscMLGMMTrainer_Type; +bool init_BobLearnMiscMLGMMTrainer(PyObject* module); +int PyBobLearnMiscMLGMMTrainer_Check(PyObject* o); + + + + #endif // BOB_LEARN_EM_MAIN_H diff --git a/setup.py b/setup.py index 9bc5e23..aab0731 100644 --- a/setup.py +++ b/setup.py @@ -69,8 +69,8 @@ setup( #"bob/learn/misc/cpp/IVectorTrainer.cpp", #"bob/learn/misc/cpp/JFATrainer.cpp", "bob/learn/misc/cpp/KMeansTrainer.cpp", - #"bob/learn/misc/cpp/MAP_GMMTrainer.cpp", - #"bob/learn/misc/cpp/ML_GMMTrainer.cpp", + "bob/learn/misc/cpp/MAP_GMMTrainer.cpp", + "bob/learn/misc/cpp/ML_GMMTrainer.cpp", #"bob/learn/misc/cpp/PLDATrainer.cpp", ], bob_packages = bob_packages, @@ -106,6 +106,7 @@ setup( "bob/learn/misc/kmeans_machine.cpp", "bob/learn/misc/kmeans_trainer.cpp", "bob/learn/misc/gmm_base_trainer.cpp", + "bob/learn/misc/ML_gmm_trainer.cpp", "bob/learn/misc/main.cpp", ], -- GitLab