Commit 3aad63f1 authored by Tiago de Freitas Pereira

Reorganized the MAP and ML trainers

parent d23601d1
......@@ -13,6 +13,7 @@
/************ Constructor Section *********************************/
/******************************************************************/
static inline bool f(PyObject* o){return o != 0 && PyObject_IsTrue(o) > 0;} /* converts PyObject to bool and returns false if object is NULL */
static auto MAP_GMMTrainer_doc = bob::extension::ClassDoc(
BOB_EXT_MODULE_PREFIX ".MAP_GMMTrainer",
......@@ -24,18 +25,22 @@ static auto MAP_GMMTrainer_doc = bob::extension::ClassDoc(
"",
true
)
.add_prototype("gmm_base_trainer,prior_gmm,relevance_factor","")
.add_prototype("gmm_base_trainer,prior_gmm,alpha","")
.add_prototype("prior_gmm,relevance_factor, update_means, [update_variances], [update_weights], [mean_var_update_responsibilities_threshold]","")
.add_prototype("prior_gmm,alpha, update_means, [update_variances], [update_weights], [mean_var_update_responsibilities_threshold]","")
.add_prototype("other","")
.add_prototype("","")
.add_parameter("gmm_base_trainer", ":py:class:`bob.learn.misc.GMMBaseTrainer`", "A GMMBaseTrainer object.")
.add_parameter("prior_gmm", ":py:class:`bob.learn.misc.GMMMachine`", "The prior GMM to be adapted (Universal Backgroud Model UBM).")
.add_parameter("reynolds_adaptation", "bool", "Will use the Reynolds adaptation procedure? See Eq (14) from [Reynolds2000]_")
.add_parameter("relevance_factor", "double", "If set the reynolds_adaptation parameters, will apply the Reynolds Adaptation procedure. See Eq (14) from [Reynolds2000]_")
.add_parameter("alpha", "double", "Set directly the alpha parameter (Eq (14) from [Reynolds2000]_), ignoring zeroth order statistics as a weighting factor.")
.add_parameter("update_means", "bool", "Update means on each iteration")
.add_parameter("update_variances", "bool", "Update variances on each iteration")
.add_parameter("update_weights", "bool", "Update weights on each iteration")
.add_parameter("mean_var_update_responsibilities_threshold", "float", "Threshold over the responsibilities of the Gaussians Equations 9.24, 9.25 of Bishop, `Pattern recognition and machine learning`, 2006 require a division by the responsibilities, which might be equal to zero because of numerical issue. This threshold is used to avoid such divisions.")
.add_parameter("other", ":py:class:`bob.learn.misc.MAP_GMMTrainer`", "A MAP_GMMTrainer object to be copied.")
);
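
For orientation, here is a minimal construction sketch that matches the prototypes above. It assumes bob.learn.misc is importable and that the high-level names are exposed as in this commit; the GMMMachine dimensions are illustrative only.

    import bob.learn.misc

    # Hypothetical prior GMM (UBM): 2 Gaussians over 3-dimensional features
    prior_gmm = bob.learn.misc.GMMMachine(2, 3)

    # Reynolds adaptation: passing the relevance_factor keyword selects Eq (14) weighting
    map_trainer = bob.learn.misc.MAP_GMMTrainer(prior_gmm, relevance_factor=4.0, update_means=True)

    # Alternatively, fix alpha directly, ignoring the zeroth-order statistics weighting
    map_trainer_alpha = bob.learn.misc.MAP_GMMTrainer(prior_gmm, alpha=0.5, update_means=True)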
......@@ -59,43 +64,54 @@ static int PyBobLearnMiscMAPGMMTrainer_init_base_trainer(PyBobLearnMiscMAPGMMTra
char** kwlist1 = MAP_GMMTrainer_doc.kwlist(0);
char** kwlist2 = MAP_GMMTrainer_doc.kwlist(1);
PyBobLearnMiscGMMBaseTrainerObject* gmm_base_trainer;
PyBobLearnMiscGMMMachineObject* gmm_machine;
bool reynolds_adaptation = false;
double alpha = 0.5;
double relevance_factor = 4.0;
double aux = 0;
PyObject* keyword_relevance_factor = Py_BuildValue("s", kwlist1[2]);
PyObject* keyword_alpha = Py_BuildValue("s", kwlist2[2]);
PyObject* update_means = 0;
PyObject* update_variances = 0;
PyObject* update_weights = 0;
double mean_var_update_responsibilities_threshold = std::numeric_limits<double>::epsilon();
PyObject* keyword_relevance_factor = Py_BuildValue("s", kwlist1[1]);
PyObject* keyword_alpha = Py_BuildValue("s", kwlist2[1]);
//Here we have to select which keyword argument to read
if (kwargs && PyDict_Contains(kwargs, keyword_relevance_factor) && (PyArg_ParseTupleAndKeywords(args, kwargs, "O!O!|d", kwlist1,
&PyBobLearnMiscGMMBaseTrainer_Type, &gmm_base_trainer,
if (kwargs && PyDict_Contains(kwargs, keyword_relevance_factor) && (PyArg_ParseTupleAndKeywords(args, kwargs, "O!dO!|O!O!d", kwlist1,
&PyBobLearnMiscGMMMachine_Type, &gmm_machine,
&aux)))
&aux,
&PyBool_Type, &update_means,
&PyBool_Type, &update_variances,
&PyBool_Type, &update_weights,
&mean_var_update_responsibilities_threshold)))
reynolds_adaptation = true;
else if (kwargs && PyDict_Contains(kwargs, keyword_alpha) && (PyArg_ParseTupleAndKeywords(args, kwargs, "O!O!|d", kwlist2,
&PyBobLearnMiscGMMBaseTrainer_Type, &gmm_base_trainer,
else if (kwargs && PyDict_Contains(kwargs, keyword_alpha) && (PyArg_ParseTupleAndKeywords(args, kwargs, "O!dO!|O!O!d", kwlist2,
&PyBobLearnMiscGMMMachine_Type, &gmm_machine,
&aux)))
&aux,
&PyBool_Type, &update_means,
&PyBool_Type, &update_variances,
&PyBool_Type, &update_weights,
&mean_var_update_responsibilities_threshold)))
reynolds_adaptation = false;
else{
PyErr_Format(PyExc_RuntimeError, "%s. The third argument must be a keyword argument.", Py_TYPE(self)->tp_name);
PyErr_Format(PyExc_RuntimeError, "%s. The second argument must be a keyword argument.", Py_TYPE(self)->tp_name);
MAP_GMMTrainer_doc.print_usage();
return -1;
}
if (reynolds_adaptation)
relevance_factor = aux;
else
alpha = aux;
self->cxx.reset(new bob::learn::misc::MAP_GMMTrainer(gmm_base_trainer->cxx, gmm_machine->cxx, reynolds_adaptation,relevance_factor, alpha));
self->cxx.reset(new bob::learn::misc::MAP_GMMTrainer(f(update_means), f(update_variances), f(update_weights),
mean_var_update_responsibilities_threshold,
reynolds_adaptation,relevance_factor, alpha, gmm_machine->cxx));
return 0;
}
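
As a side note on what reynolds_adaptation toggles: with Reynolds adaptation the per-Gaussian adaptation coefficient is derived from the zeroth-order statistic and the relevance factor (Eq (14) from [Reynolds2000]_), whereas the alpha path fixes it to a constant. A rough sketch of that distinction, written outside of the bindings:

    def adaptation_coefficient(n_i, relevance_factor=None, alpha=0.5):
        """Illustrative MAP adaptation coefficient for one Gaussian.

        n_i is the zeroth-order statistic (soft count) of that Gaussian.
        With a relevance_factor, Eq (14) from [Reynolds2000]_ is applied;
        otherwise the fixed alpha is returned.
        """
        if relevance_factor is not None:
            return n_i / (n_i + relevance_factor)  # Reynolds adaptation
        return alpha  # data-independent weighting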
......@@ -151,47 +167,6 @@ static PyObject* PyBobLearnMiscMAPGMMTrainer_RichCompare(PyBobLearnMiscMAPGMMTra
/************ Variables Section ***********************************/
/******************************************************************/
/***** gmm_base_trainer *****/
static auto gmm_base_trainer = bob::extension::VariableDoc(
"gmm_base_trainer",
":py:class:`bob.learn.misc.GMMBaseTrainer`",
"This class that implements the E-step of the expectation-maximisation algorithm.",
""
);
PyObject* PyBobLearnMiscMAPGMMTrainer_getGMMBaseTrainer(PyBobLearnMiscMAPGMMTrainerObject* self, void*){
BOB_TRY
boost::shared_ptr<bob::learn::misc::GMMBaseTrainer> gmm_base_trainer = self->cxx->getGMMBaseTrainer();
//Allocating the correspondent python object
PyBobLearnMiscGMMBaseTrainerObject* retval =
(PyBobLearnMiscGMMBaseTrainerObject*)PyBobLearnMiscGMMBaseTrainer_Type.tp_alloc(&PyBobLearnMiscGMMBaseTrainer_Type, 0);
retval->cxx = gmm_base_trainer;
return Py_BuildValue("O",retval);
BOB_CATCH_MEMBER("GMMBaseTrainer could not be read", 0)
}
int PyBobLearnMiscMAPGMMTrainer_setGMMBaseTrainer(PyBobLearnMiscMAPGMMTrainerObject* self, PyObject* value, void*){
BOB_TRY
if (!PyBobLearnMiscGMMBaseTrainer_Check(value)){
PyErr_Format(PyExc_RuntimeError, "%s %s expects a :py:class:`bob.learn.misc.GMMBaseTrainer`", Py_TYPE(self)->tp_name, gmm_base_trainer.name());
return -1;
}
PyBobLearnMiscGMMBaseTrainerObject* gmm_base_trainer = 0;
PyArg_Parse(value, "O!", &PyBobLearnMiscGMMBaseTrainer_Type,&gmm_base_trainer);
self->cxx->setGMMBaseTrainer(gmm_base_trainer->cxx);
return 0;
BOB_CATCH_MEMBER("gmm_base_trainer could not be set", -1)
}
/***** relevance_factor *****/
static auto relevance_factor = bob::extension::VariableDoc(
"relevance_factor",
......@@ -246,13 +221,6 @@ int PyBobLearnMiscMAPGMMTrainer_setAlpha(PyBobLearnMiscMAPGMMTrainerObject* self
static PyGetSetDef PyBobLearnMiscMAPGMMTrainer_getseters[] = {
{
gmm_base_trainer.name(),
(getter)PyBobLearnMiscMAPGMMTrainer_getGMMBaseTrainer,
(setter)PyBobLearnMiscMAPGMMTrainer_setGMMBaseTrainer,
gmm_base_trainer.doc(),
0
},
{
alpha.name(),
(getter)PyBobLearnMiscMAPGMMTrainer_getAlpha,
......@@ -304,6 +272,40 @@ static PyObject* PyBobLearnMiscMAPGMMTrainer_initialize(PyBobLearnMiscMAPGMMTrai
}
/*** eStep ***/
static auto eStep = bob::extension::FunctionDoc(
"eStep",
"Calculates and saves statistics across the dataset,"
"and saves these as m_ss. ",
"Calculates the average log likelihood of the observations given the GMM,"
"and returns this in average_log_likelihood."
"The statistics, m_ss, will be used in the mStep() that follows.",
true
)
.add_prototype("gmm_machine,data")
.add_parameter("gmm_machine", ":py:class:`bob.learn.misc.GMMMachine`", "GMMMachine Object")
.add_parameter("data", "array_like <float, 2D>", "Input data");
static PyObject* PyBobLearnMiscMAPGMMTrainer_eStep(PyBobLearnMiscMAPGMMTrainerObject* self, PyObject* args, PyObject* kwargs) {
BOB_TRY
/* Parses input arguments in a single shot */
char** kwlist = eStep.kwlist(0);
PyBobLearnMiscGMMMachineObject* gmm_machine;
PyBlitzArrayObject* data = 0;
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!O&", kwlist, &PyBobLearnMiscGMMMachine_Type, &gmm_machine,
&PyBlitzArray_Converter, &data)) Py_RETURN_NONE;
auto data_ = make_safe(data);
self->cxx->eStep(*gmm_machine->cxx, *PyBlitzArrayCxx_AsBlitz<double,2>(data));
BOB_CATCH_MEMBER("cannot perform the eStep method", 0)
Py_RETURN_NONE;
}
/*** mStep ***/
static auto mStep = bob::extension::FunctionDoc(
......@@ -335,6 +337,31 @@ static PyObject* PyBobLearnMiscMAPGMMTrainer_mStep(PyBobLearnMiscMAPGMMTrainerOb
}
/*** computeLikelihood ***/
static auto compute_likelihood = bob::extension::FunctionDoc(
"compute_likelihood",
"This functions returns the average min (Square Euclidean) distance (average distance to the closest mean)",
0,
true
)
.add_prototype("gmm_machine")
.add_parameter("gmm_machine", ":py:class:`bob.learn.misc.GMMMachine`", "GMMMachine Object");
static PyObject* PyBobLearnMiscMAPGMMTrainer_compute_likelihood(PyBobLearnMiscMAPGMMTrainerObject* self, PyObject* args, PyObject* kwargs) {
BOB_TRY
/* Parses input arguments in a single shot */
char** kwlist = compute_likelihood.kwlist(0);
PyBobLearnMiscGMMMachineObject* gmm_machine;
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!", kwlist, &PyBobLearnMiscGMMMachine_Type, &gmm_machine)) Py_RETURN_NONE;
double value = self->cxx->computeLikelihood(*gmm_machine->cxx);
return Py_BuildValue("d", value);
BOB_CATCH_MEMBER("cannot perform the computeLikelihood method", 0)
}
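
Since eStep, mStep and compute_likelihood are now exposed directly on the trainer (see also the pure-Python wrapper further below), an EM loop can be driven from Python along these lines. The data, the machines and the stopping rule are illustrative, and the initialization step (whose exact signature is not shown in this diff) is only indicated by a comment:

    import numpy
    import bob.learn.misc

    data = numpy.random.randn(100, 3)        # hypothetical training data: 100 samples, 3 features
    prior = bob.learn.misc.GMMMachine(2, 3)  # hypothetical UBM
    gmm = bob.learn.misc.GMMMachine(prior)   # machine to adapt, assumed copy-constructed from the prior
    trainer = bob.learn.misc.MAP_GMMTrainer(prior, relevance_factor=4.0, update_means=True)

    # ... trainer initialization would normally happen here ...
    trainer.eStep(gmm, data)
    previous = trainer.compute_likelihood(gmm)
    for _ in range(10):                                      # max_iterations
        trainer.mStep(gmm)
        trainer.eStep(gmm, data)
        current = trainer.compute_likelihood(gmm)
        if abs((previous - current) / previous) <= 1e-3:     # convergence_threshold
            break
        previous = current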
static PyMethodDef PyBobLearnMiscMAPGMMTrainer_methods[] = {
{
......@@ -343,13 +370,25 @@ static PyMethodDef PyBobLearnMiscMAPGMMTrainer_methods[] = {
METH_VARARGS|METH_KEYWORDS,
initialize.doc()
},
{
eStep.name(),
(PyCFunction)PyBobLearnMiscMAPGMMTrainer_eStep,
METH_VARARGS|METH_KEYWORDS,
eStep.doc()
},
{
mStep.name(),
(PyCFunction)PyBobLearnMiscMAPGMMTrainer_mStep,
METH_VARARGS|METH_KEYWORDS,
mStep.doc()
},
{
compute_likelihood.name(),
(PyCFunction)PyBobLearnMiscMAPGMMTrainer_compute_likelihood,
METH_VARARGS|METH_KEYWORDS,
compute_likelihood.doc()
},
{0} /* Sentinel */
};
......@@ -379,7 +418,7 @@ bool init_BobLearnMiscMAPGMMTrainer(PyObject* module)
PyBobLearnMiscMAPGMMTrainer_Type.tp_richcompare = reinterpret_cast<richcmpfunc>(PyBobLearnMiscMAPGMMTrainer_RichCompare);
PyBobLearnMiscMAPGMMTrainer_Type.tp_methods = PyBobLearnMiscMAPGMMTrainer_methods;
PyBobLearnMiscMAPGMMTrainer_Type.tp_getset = PyBobLearnMiscMAPGMMTrainer_getseters;
//PyBobLearnMiscMAPGMMTrainer_Type.tp_call = reinterpret_cast<ternaryfunc>(PyBobLearnMiscMAPGMMTrainer_compute_likelihood);
PyBobLearnMiscMAPGMMTrainer_Type.tp_call = reinterpret_cast<ternaryfunc>(PyBobLearnMiscMAPGMMTrainer_compute_likelihood);
// check that everything is fine
......
......@@ -13,6 +13,8 @@
/************ Constructor Section *********************************/
/******************************************************************/
static inline bool f(PyObject* o){return o != 0 && PyObject_IsTrue(o) > 0;} /* converts PyObject to bool and returns false if object is NULL */
static auto ML_GMMTrainer_doc = bob::extension::ClassDoc(
BOB_EXT_MODULE_PREFIX ".ML_GMMTrainer",
"This class implements the maximum likelihood M-step of the expectation-maximisation algorithm for a GMM Machine."
......@@ -23,11 +25,16 @@ static auto ML_GMMTrainer_doc = bob::extension::ClassDoc(
"",
true
)
.add_prototype("gmm_base_trainer","")
.add_prototype("update_means, [update_variances], [update_weights], [mean_var_update_responsibilities_threshold]","")
.add_prototype("other","")
.add_prototype("","")
.add_parameter("gmm_base_trainer", ":py:class:`bob.learn.misc.GMMBaseTrainer`", "A set GMMBaseTrainer object.")
.add_parameter("update_means", "bool", "Update means on each iteration")
.add_parameter("update_variances", "bool", "Update variances on each iteration")
.add_parameter("update_weights", "bool", "Update weights on each iteration")
.add_parameter("mean_var_update_responsibilities_threshold", "float", "Threshold over the responsibilities of the Gaussians Equations 9.24, 9.25 of Bishop, `Pattern recognition and machine learning`, 2006 require a division by the responsibilities, which might be equal to zero because of numerical issue. This threshold is used to avoid such divisions.")
.add_parameter("other", ":py:class:`bob.learn.misc.ML_GMMTrainer`", "A ML_GMMTrainer object to be copied.")
);
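
A minimal construction sketch for the reorganized ML trainer, assuming the prototypes above are exposed as shown; keyword arguments are used so the sketch stays valid for both this binding and the pure-Python wrapper further below:

    import bob.learn.misc

    # Update only the means; variances and weights stay fixed
    ml_trainer = bob.learn.misc.ML_GMMTrainer(update_means=True)

    # Update means, variances and weights on every M-step
    ml_trainer_full = bob.learn.misc.ML_GMMTrainer(update_means=True, update_variances=True, update_weights=True)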
......@@ -48,15 +55,24 @@ static int PyBobLearnMiscMLGMMTrainer_init_copy(PyBobLearnMiscMLGMMTrainerObject
static int PyBobLearnMiscMLGMMTrainer_init_base_trainer(PyBobLearnMiscMLGMMTrainerObject* self, PyObject* args, PyObject* kwargs) {
char** kwlist = ML_GMMTrainer_doc.kwlist(1);
PyBobLearnMiscGMMBaseTrainerObject* o;
char** kwlist = ML_GMMTrainer_doc.kwlist(0);
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!", kwlist, &PyBobLearnMiscGMMBaseTrainer_Type, &o)){
PyObject* update_means = 0;
PyObject* update_variances = 0;
PyObject* update_weights = 0;
double mean_var_update_responsibilities_threshold = std::numeric_limits<double>::epsilon();
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!|O!O!d", kwlist,
&PyBool_Type, &update_means,
&PyBool_Type, &update_variances,
&PyBool_Type, &update_weights,
&mean_var_update_responsibilities_threshold)){
ML_GMMTrainer_doc.print_usage();
return -1;
}
self->cxx.reset(new bob::learn::misc::ML_GMMTrainer(o->cxx));
self->cxx.reset(new bob::learn::misc::ML_GMMTrainer(f(update_means), f(update_variances), f(update_weights),
mean_var_update_responsibilities_threshold));
return 0;
}
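
For context, a sketch of the guard that mean_var_update_responsibilities_threshold implements in the M-step mean update (Bishop, Eqs 9.24/9.25), written outside the bindings; the array shapes and names are assumptions:

    import numpy

    def update_means(data, responsibilities, old_means, threshold=numpy.finfo(float).eps):
        """Illustrative thresholded mean update.

        data:             (T, D) observations
        responsibilities: (T, K) posterior probabilities gamma(z_tk)
        Gaussians whose accumulated responsibility N_k falls below `threshold`
        keep their previous mean instead of dividing by (almost) zero.
        """
        N_k = responsibilities.sum(axis=0)        # (K,)
        weighted = responsibilities.T.dot(data)   # (K, D)
        new_means = old_means.copy()
        ok = N_k > threshold
        new_means[ok] = weighted[ok] / N_k[ok, None]
        return new_means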
......@@ -65,31 +81,24 @@ static int PyBobLearnMiscMLGMMTrainer_init_base_trainer(PyBobLearnMiscMLGMMTrain
static int PyBobLearnMiscMLGMMTrainer_init(PyBobLearnMiscMLGMMTrainerObject* self, PyObject* args, PyObject* kwargs) {
BOB_TRY
int nargs = (args?PyTuple_Size(args):0) + (kwargs?PyDict_Size(kwargs):0);
if (nargs==1){ //default initializer ()
//Reading the input argument
PyObject* arg = 0;
if (PyTuple_Size(args))
arg = PyTuple_GET_ITEM(args, 0);
else {
PyObject* tmp = PyDict_Values(kwargs);
auto tmp_ = make_safe(tmp);
arg = PyList_GET_ITEM(tmp, 0);
}
// If the constructor input is GMMBaseTrainer object
if (PyBobLearnMiscGMMBaseTrainer_Check(arg))
return PyBobLearnMiscMLGMMTrainer_init_base_trainer(self, args, kwargs);
else
return PyBobLearnMiscMLGMMTrainer_init_copy(self, args, kwargs);
}
else{
PyErr_Format(PyExc_RuntimeError, "number of arguments mismatch - %s requires 1, but you provided %d (see help)", Py_TYPE(self)->tp_name, nargs);
ML_GMMTrainer_doc.print_usage();
return -1;
//Reading the input argument
PyObject* arg = 0;
if (PyTuple_Size(args))
arg = PyTuple_GET_ITEM(args, 0);
else {
PyObject* tmp = PyDict_Values(kwargs);
auto tmp_ = make_safe(tmp);
arg = PyList_GET_ITEM(tmp, 0);
}
// If the constructor input is an ML_GMMTrainer object, call the copy constructor
if (PyBobLearnMiscMLGMMTrainer_Check(arg))
return PyBobLearnMiscMLGMMTrainer_init_copy(self, args, kwargs);
else
return PyBobLearnMiscMLGMMTrainer_init_base_trainer(self, args, kwargs);
BOB_CATCH_MEMBER("cannot create GMMBaseTrainer_init_bool", 0)
return 0;
}
......@@ -131,54 +140,7 @@ static PyObject* PyBobLearnMiscMLGMMTrainer_RichCompare(PyBobLearnMiscMLGMMTrain
/************ Variables Section ***********************************/
/******************************************************************/
/***** gmm_base_trainer *****/
static auto gmm_base_trainer = bob::extension::VariableDoc(
"gmm_base_trainer",
":py:class:`bob.learn.misc.GMMBaseTrainer`",
"This class that implements the E-step of the expectation-maximisation algorithm.",
""
);
PyObject* PyBobLearnMiscMLGMMTrainer_getGMMBaseTrainer(PyBobLearnMiscMLGMMTrainerObject* self, void*){
BOB_TRY
boost::shared_ptr<bob::learn::misc::GMMBaseTrainer> gmm_base_trainer = self->cxx->getGMMBaseTrainer();
//Allocating the correspondent python object
PyBobLearnMiscGMMBaseTrainerObject* retval =
(PyBobLearnMiscGMMBaseTrainerObject*)PyBobLearnMiscGMMBaseTrainer_Type.tp_alloc(&PyBobLearnMiscGMMBaseTrainer_Type, 0);
retval->cxx = gmm_base_trainer;
return Py_BuildValue("O",retval);
BOB_CATCH_MEMBER("GMMBaseTrainer could not be read", 0)
}
int PyBobLearnMiscMLGMMTrainer_setGMMBaseTrainer(PyBobLearnMiscMLGMMTrainerObject* self, PyObject* value, void*){
BOB_TRY
if (!PyBobLearnMiscGMMBaseTrainer_Check(value)){
PyErr_Format(PyExc_RuntimeError, "%s %s expects a :py:class:`bob.learn.misc.GMMBaseTrainer`", Py_TYPE(self)->tp_name, gmm_base_trainer.name());
return -1;
}
PyBobLearnMiscGMMBaseTrainerObject* gmm_base_trainer = 0;
PyArg_Parse(value, "O!", &PyBobLearnMiscGMMBaseTrainer_Type,&gmm_base_trainer);
self->cxx->setGMMBaseTrainer(gmm_base_trainer->cxx);
return 0;
BOB_CATCH_MEMBER("gmm_base_trainer could not be set", -1)
}
static PyGetSetDef PyBobLearnMiscMLGMMTrainer_getseters[] = {
{
gmm_base_trainer.name(),
(getter)PyBobLearnMiscMLGMMTrainer_getGMMBaseTrainer,
(setter)PyBobLearnMiscMLGMMTrainer_setGMMBaseTrainer,
gmm_base_trainer.doc(),
0
},
{0} // Sentinel
};
......@@ -215,6 +177,40 @@ static PyObject* PyBobLearnMiscMLGMMTrainer_initialize(PyBobLearnMiscMLGMMTraine
}
/*** eStep ***/
static auto eStep = bob::extension::FunctionDoc(
"eStep",
"Calculates and saves statistics across the dataset,"
"and saves these as m_ss. ",
"Calculates the average log likelihood of the observations given the GMM,"
"and returns this in average_log_likelihood."
"The statistics, m_ss, will be used in the mStep() that follows.",
true
)
.add_prototype("gmm_machine,data")
.add_parameter("gmm_machine", ":py:class:`bob.learn.misc.GMMMachine`", "GMMMachine Object")
.add_parameter("data", "array_like <float, 2D>", "Input data");
static PyObject* PyBobLearnMiscMLGMMTrainer_eStep(PyBobLearnMiscMLGMMTrainerObject* self, PyObject* args, PyObject* kwargs) {
BOB_TRY
/* Parses input arguments in a single shot */
char** kwlist = eStep.kwlist(0);
PyBobLearnMiscGMMMachineObject* gmm_machine;
PyBlitzArrayObject* data = 0;
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!O&", kwlist, &PyBobLearnMiscGMMMachine_Type, &gmm_machine,
&PyBlitzArray_Converter, &data)) Py_RETURN_NONE;
auto data_ = make_safe(data);
self->cxx->eStep(*gmm_machine->cxx, *PyBlitzArrayCxx_AsBlitz<double,2>(data));
BOB_CATCH_MEMBER("cannot perform the eStep method", 0)
Py_RETURN_NONE;
}
/*** mStep ***/
static auto mStep = bob::extension::FunctionDoc(
......@@ -246,6 +242,31 @@ static PyObject* PyBobLearnMiscMLGMMTrainer_mStep(PyBobLearnMiscMLGMMTrainerObje
}
/*** computeLikelihood ***/
static auto compute_likelihood = bob::extension::FunctionDoc(
"compute_likelihood",
"This functions returns the average min (Square Euclidean) distance (average distance to the closest mean)",
0,
true
)
.add_prototype("gmm_machine")
.add_parameter("gmm_machine", ":py:class:`bob.learn.misc.GMMMachine`", "GMMMachine Object");
static PyObject* PyBobLearnMiscMLGMMTrainer_compute_likelihood(PyBobLearnMiscMLGMMTrainerObject* self, PyObject* args, PyObject* kwargs) {
BOB_TRY
/* Parses input arguments in a single shot */
char** kwlist = compute_likelihood.kwlist(0);
PyBobLearnMiscGMMMachineObject* gmm_machine;
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!", kwlist, &PyBobLearnMiscGMMMachine_Type, &gmm_machine)) Py_RETURN_NONE;
double value = self->cxx->computeLikelihood(*gmm_machine->cxx);
return Py_BuildValue("d", value);
BOB_CATCH_MEMBER("cannot perform the computeLikelihood method", 0)
}
static PyMethodDef PyBobLearnMiscMLGMMTrainer_methods[] = {
{
......@@ -254,12 +275,24 @@ static PyMethodDef PyBobLearnMiscMLGMMTrainer_methods[] = {
METH_VARARGS|METH_KEYWORDS,
initialize.doc()
},
{
eStep.name(),
(PyCFunction)PyBobLearnMiscMLGMMTrainer_eStep,
METH_VARARGS|METH_KEYWORDS,
eStep.doc()
},
{
mStep.name(),
(PyCFunction)PyBobLearnMiscMLGMMTrainer_mStep,
METH_VARARGS|METH_KEYWORDS,
mStep.doc()
},
{
compute_likelihood.name(),
(PyCFunction)PyBobLearnMiscMLGMMTrainer_compute_likelihood,
METH_VARARGS|METH_KEYWORDS,
compute_likelihood.doc()
},
{0} /* Sentinel */
};
......@@ -289,7 +322,7 @@ bool init_BobLearnMiscMLGMMTrainer(PyObject* module)
PyBobLearnMiscMLGMMTrainer_Type.tp_richcompare = reinterpret_cast<richcmpfunc>(PyBobLearnMiscMLGMMTrainer_RichCompare);
PyBobLearnMiscMLGMMTrainer_Type.tp_methods = PyBobLearnMiscMLGMMTrainer_methods;
PyBobLearnMiscMLGMMTrainer_Type.tp_getset = PyBobLearnMiscMLGMMTrainer_getseters;
//PyBobLearnMiscMLGMMTrainer_Type.tp_call = reinterpret_cast<ternaryfunc>(PyBobLearnMiscMLGMMTrainer_compute_likelihood);
PyBobLearnMiscMLGMMTrainer_Type.tp_call = reinterpret_cast<ternaryfunc>(PyBobLearnMiscMLGMMTrainer_compute_likelihood);
// check that everything is fine
......
......@@ -11,13 +11,17 @@ import numpy
# define the class
class MAP_GMMTrainer(_MAP_GMMTrainer):
def __init__(self, gmm_base_trainer, prior_gmm, convergence_threshold=0.001, max_iterations=10, converge_by_likelihood=True, **kwargs):
def __init__(self, prior_gmm, update_means=True, update_variances=False, update_weights=False, convergence_threshold=0.001, max_iterations=10, converge_by_likelihood=True, **kwargs):
"""
:py:class:`bob.learn.misc.MAP_GMMTrainer` constructor
Keyword Parameters:
gmm_base_trainer
The base trainer (:py:class:`bob.learn.misc.GMMBaseTrainer`)
update_means
Update the means on each iteration (defaults to True)
update_variances
Update the variances on each iteration (defaults to False)
update_weights
Update the weights on each iteration (defaults to False)
prior_gmm
A :py:class:`bob.learn.misc.GMMMachine` to be adapted
convergence_threshold
......@@ -34,10 +38,10 @@ class MAP_GMMTrainer(_MAP_GMMTrainer):
if kwargs.get('alpha')!=None:
alpha = kwargs.get('alpha')
_MAP_GMMTrainer.__init__(self, gmm_base_trainer, prior_gmm, alpha=alpha)
_MAP_GMMTrainer.__init__(self, prior_gmm,alpha=alpha, update_means=update_means, update_variances=update_variances,update_weights=update_weights)
else:
relevance_factor = kwargs.get('relevance_factor')
_MAP_GMMTrainer.__init__(self, gmm_base_trainer, prior_gmm, relevance_factor=relevance_factor)
_MAP_GMMTrainer.__init__(self, prior_gmm, relevance_factor=relevance_factor, update_means=update_means, update_variances=update_variances,update_weights=update_weights)
self.convergence_threshold = convergence_threshold
self.max_iterations = max_iterations
......@@ -67,10 +71,10 @@ class MAP_GMMTrainer(_MAP_GMMTrainer):
#eStep
self.gmm_base_trainer.eStep(gmm_machine, data);
self.eStep(gmm_machine, data);
if(self.converge_by_likelihood):
average_output = self.gmm_base_trainer.compute_likelihood(gmm_machine);
average_output = self.compute_likelihood(gmm_machine);
for i in range(self.max_iterations):
#saves average output from last iteration
......@@ -80,11 +84,11 @@ class MAP_GMMTrainer(_MAP_GMMTrainer):
self.mStep(gmm_machine);
#eStep
self.gmm_base_trainer.eStep(gmm_machine, data);
self.eStep(gmm_machine, data);
#Computes log likelihood if required
if(self.converge_by_likelihood):
average_output = self.gmm_base_trainer.compute_likelihood(gmm_machine);
average_output = self.compute_likelihood(gmm_machine);
#Terminates if converged (and likelihood computation is set)
if abs((average_output_previous - average_output)/average_output_previous) <= self.convergence_threshold:
......
......@@ -11,13 +11,17 @@ import numpy
# define the class
class ML_GMMTrainer(_ML_GMMTrainer):
def __init__(self, gmm_base_trainer, convergence_threshold=0.001, max_iterations=10, converge_by_likelihood=True):
def __init__(self, update_means=True, update_variances=False, update_weights=False, convergence_threshold=0.001, max_iterations=10, converge_by_likelihood=True):
"""
:py:class:`bob.learn.misc.ML_GMMTrainer` constructor
Keyword Parameters:
gmm_base_trainer
The base trainer (:py:class:`bob.learn.misc.GMMBaseTrainer`