Skip to content
Snippets Groups Projects
Commit b235305b authored by Manuel Günther's avatar Manuel Günther
Browse files

Updated MAP_GMMTrainer and removed the intermediate Python class.

parent ca845150
No related branches found
No related tags found
No related merge requests found
...@@ -17,28 +17,28 @@ static inline bool f(PyObject* o){return o != 0 && PyObject_IsTrue(o) > 0;} /* ...@@ -17,28 +17,28 @@ static inline bool f(PyObject* o){return o != 0 && PyObject_IsTrue(o) > 0;} /*
static auto MAP_GMMTrainer_doc = bob::extension::ClassDoc( static auto MAP_GMMTrainer_doc = bob::extension::ClassDoc(
BOB_EXT_MODULE_PREFIX ".MAP_GMMTrainer", BOB_EXT_MODULE_PREFIX ".MAP_GMMTrainer",
"This class implements the maximum a posteriori M-step of the expectation-maximisation algorithm for a GMM Machine. The prior parameters are encoded in the form of a GMM (e.g. a universal background model). The EM algorithm thus performs GMM adaptation." "This class implements the maximum a posteriori M-step of the expectation-maximization algorithm for a GMM Machine. The prior parameters are encoded in the form of a GMM (e.g. a universal background model). The EM algorithm thus performs GMM adaptation."
).add_constructor( ).add_constructor(
bob::extension::FunctionDoc( bob::extension::FunctionDoc(
"__init__", "__init__",
"Creates a MAP_GMMTrainer", "Creates a MAP_GMMTrainer",
"", "Additionally to the copy constructor, there are two different ways to call this constructor, one using the ``relevance_factor`` and one using the ``alpha``, both which have the same signature. "
"Hence, the only way to differentiate the two functions is by using keyword arguments.",
true true
) )
.add_prototype("prior_gmm,relevance_factor, update_means, [update_variances], [update_weights], [mean_var_update_responsibilities_threshold]","") .add_prototype("prior_gmm, relevance_factor, [update_means], [update_variances], [update_weights], [mean_var_update_responsibilities_threshold]","")
.add_prototype("prior_gmm,alpha, update_means, [update_variances], [update_weights], [mean_var_update_responsibilities_threshold]","") .add_prototype("prior_gmm, alpha, [update_means], [update_variances], [update_weights], [mean_var_update_responsibilities_threshold]","")
.add_prototype("other","") .add_prototype("other","")
.add_parameter("prior_gmm", ":py:class:`bob.learn.em.GMMMachine`", "The prior GMM to be adapted (Universal Backgroud Model UBM).") .add_parameter("prior_gmm", ":py:class:`bob.learn.em.GMMMachine`", "The prior GMM to be adapted (Universal Background Model UBM).")
.add_parameter("reynolds_adaptation", "bool", "Will use the Reynolds adaptation procedure? See Eq (14) from [Reynolds2000]_") .add_parameter("relevance_factor", "float", "If set the Reynolds Adaptation procedure will be applied. See Eq (14) from [Reynolds2000]_")
.add_parameter("relevance_factor", "double", "If set the reynolds_adaptation parameters, will apply the Reynolds Adaptation procedure. See Eq (14) from [Reynolds2000]_") .add_parameter("alpha", "float", "Set directly the alpha parameter (Eq (14) from [Reynolds2000]_), ignoring zeroth order statistics as a weighting factor.")
.add_parameter("alpha", "double", "Set directly the alpha parameter (Eq (14) from [Reynolds2000]_), ignoring zeroth order statistics as a weighting factor.")
.add_parameter("update_means", "bool", "Update means on each iteration") .add_parameter("update_means", "bool", "[Default: ``True``] Update means on each iteration")
.add_parameter("update_variances", "bool", "Update variances on each iteration") .add_parameter("update_variances", "bool", "[Default: ``True``] Update variances on each iteration")
.add_parameter("update_weights", "bool", "Update weights on each iteration") .add_parameter("update_weights", "bool", "[Default: ``True``] Update weights on each iteration")
.add_parameter("mean_var_update_responsibilities_threshold", "float", "Threshold over the responsibilities of the Gaussians Equations 9.24, 9.25 of Bishop, `Pattern recognition and machine learning`, 2006 require a division by the responsibilities, which might be equal to zero because of numerical issue. This threshold is used to avoid such divisions.") .add_parameter("mean_var_update_responsibilities_threshold", "float", "[Default: min_float] Threshold over the responsibilities of the Gaussians Equations 9.24, 9.25 of Bishop, `Pattern recognition and machine learning`, 2006 require a division by the responsibilities, which might be equal to zero because of numerical issue. This threshold is used to avoid such divisions.")
.add_parameter("other", ":py:class:`bob.learn.em.MAP_GMMTrainer`", "A MAP_GMMTrainer object to be copied.") .add_parameter("other", ":py:class:`bob.learn.em.MAP_GMMTrainer`", "A MAP_GMMTrainer object to be copied.")
); );
...@@ -81,24 +81,28 @@ static int PyBobLearnEMMAPGMMTrainer_init_base_trainer(PyBobLearnEMMAPGMMTrainer ...@@ -81,24 +81,28 @@ static int PyBobLearnEMMAPGMMTrainer_init_base_trainer(PyBobLearnEMMAPGMMTrainer
auto keyword_alpha_ = make_safe(keyword_alpha); auto keyword_alpha_ = make_safe(keyword_alpha);
//Here we have to select which keyword argument to read //Here we have to select which keyword argument to read
if (kwargs && PyDict_Contains(kwargs, keyword_relevance_factor) && (PyArg_ParseTupleAndKeywords(args, kwargs, "O!dO!|O!O!d", kwlist1, if (kwargs && PyDict_Contains(kwargs, keyword_relevance_factor)){
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!d|O!O!O!d", kwlist1,
&PyBobLearnEMGMMMachine_Type, &gmm_machine, &PyBobLearnEMGMMMachine_Type, &gmm_machine,
&aux, &aux,
&PyBool_Type, &update_means, &PyBool_Type, &update_means,
&PyBool_Type, &update_variances, &PyBool_Type, &update_variances,
&PyBool_Type, &update_weights, &PyBool_Type, &update_weights,
&mean_var_update_responsibilities_threshold))) &mean_var_update_responsibilities_threshold))
return -1;
reynolds_adaptation = true; reynolds_adaptation = true;
else if (kwargs && PyDict_Contains(kwargs, keyword_alpha) && (PyArg_ParseTupleAndKeywords(args, kwargs, "O!dO!|O!O!d", kwlist2, } else if (kwargs && PyDict_Contains(kwargs, keyword_alpha)){
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!d|O!O!O!d", kwlist2,
&PyBobLearnEMGMMMachine_Type, &gmm_machine, &PyBobLearnEMGMMMachine_Type, &gmm_machine,
&aux, &aux,
&PyBool_Type, &update_means, &PyBool_Type, &update_means,
&PyBool_Type, &update_variances, &PyBool_Type, &update_variances,
&PyBool_Type, &update_weights, &PyBool_Type, &update_weights,
&mean_var_update_responsibilities_threshold))) &mean_var_update_responsibilities_threshold))
return -1;
reynolds_adaptation = false; reynolds_adaptation = false;
else{ } else {
PyErr_Format(PyExc_RuntimeError, "%s. The second argument must be a keyword argument.", Py_TYPE(self)->tp_name); PyErr_Format(PyExc_RuntimeError, "%s. One of the two keyword arguments '%s' or '%s' must be present.", Py_TYPE(self)->tp_name, kwlist1[1], kwlist2[1]);
MAP_GMMTrainer_doc.print_usage(); MAP_GMMTrainer_doc.print_usage();
return -1; return -1;
} }
...@@ -473,6 +477,5 @@ bool init_BobLearnEMMAPGMMTrainer(PyObject* module) ...@@ -473,6 +477,5 @@ bool init_BobLearnEMMAPGMMTrainer(PyObject* module)
// add the type to the module // add the type to the module
Py_INCREF(&PyBobLearnEMMAPGMMTrainer_Type); Py_INCREF(&PyBobLearnEMMAPGMMTrainer_Type);
return PyModule_AddObject(module, "_MAP_GMMTrainer", (PyObject*)&PyBobLearnEMMAPGMMTrainer_Type) >= 0; return PyModule_AddObject(module, "MAP_GMMTrainer", (PyObject*)&PyBobLearnEMMAPGMMTrainer_Type) >= 0;
} }
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# Mon Jan 23 18:31:10 2015
#
# Copyright (C) 2011-2015 Idiap Research Institute, Martigny, Switzerland
from ._library import _MAP_GMMTrainer
import numpy
# define the class
class MAP_GMMTrainer(_MAP_GMMTrainer):
def __init__(self, prior_gmm, update_means=True, update_variances=False, update_weights=False, **kwargs):
"""
:py:class:`bob.learn.em.MAP_GMMTrainer` constructor
Keyword Parameters:
update_means
update_variances
update_weights
prior_gmm
A :py:class:`bob.learn.em.GMMMachine` to be adapted
convergence_threshold
Convergence threshold
max_iterations
Number of maximum iterations
converge_by_likelihood
Tells whether we compute log_likelihood as a convergence criteria, or not
alpha
Set directly the alpha parameter (Eq (14) from [Reynolds2000]_), ignoring zeroth order statistics as a weighting factor.
relevance_factor
If set the :py:class:`bob.learn.em.MAP_GMMTrainer.reynolds_adaptation` parameters, will apply the Reynolds Adaptation procedure. See Eq (14) from [Reynolds2000]_
"""
if kwargs.get('alpha')!=None:
alpha = kwargs.get('alpha')
_MAP_GMMTrainer.__init__(self, prior_gmm,alpha=alpha, update_means=update_means, update_variances=update_variances,update_weights=update_weights)
else:
relevance_factor = kwargs.get('relevance_factor')
_MAP_GMMTrainer.__init__(self, prior_gmm, relevance_factor=relevance_factor, update_means=update_means, update_variances=update_variances,update_weights=update_weights)
# copy the documentation from the base class
__doc__ = _MAP_GMMTrainer.__doc__
...@@ -7,11 +7,9 @@ import bob.learn.linear ...@@ -7,11 +7,9 @@ import bob.learn.linear
import bob.extension import bob.extension
bob.extension.load_bob_library('bob.learn.em', __file__) bob.extension.load_bob_library('bob.learn.em', __file__)
#from ._old_library import *
from ._library import * from ._library import *
from . import version from . import version
from .version import module as __version__ from .version import module as __version__
from .__MAP_gmm_trainer__ import *
from .train import * from .train import *
def ztnorm_same_value(vect_a, vect_b): def ztnorm_same_value(vect_a, vect_b):
......
...@@ -183,7 +183,7 @@ def test_gmm_MAP_3(): ...@@ -183,7 +183,7 @@ def test_gmm_MAP_3():
max_iter_gmm = 1 max_iter_gmm = 1
accuracy = 0.00001 accuracy = 0.00001
map_factor = 0.5 map_factor = 0.5
map_gmmtrainer = MAP_GMMTrainer(prior_gmm, alpha=map_factor, update_means=True, update_variances=False, update_weights=False, convergence_threshold=prior) map_gmmtrainer = MAP_GMMTrainer(prior_gmm, alpha=map_factor, update_means=True, update_variances=False, update_weights=False, mean_var_update_responsibilities_threshold=accuracy)
#map_gmmtrainer.max_iterations = max_iter_gmm #map_gmmtrainer.max_iterations = max_iter_gmm
#map_gmmtrainer.convergence_threshold = accuracy #map_gmmtrainer.convergence_threshold = accuracy
...@@ -192,7 +192,7 @@ def test_gmm_MAP_3(): ...@@ -192,7 +192,7 @@ def test_gmm_MAP_3():
# Train # Train
#map_gmmtrainer.train(gmm, ar) #map_gmmtrainer.train(gmm, ar)
bob.learn.em.train(map_gmmtrainer, gmm, ar, max_iterations = max_iter_gmm, convergence_threshold=accuracy) bob.learn.em.train(map_gmmtrainer, gmm, ar, max_iterations = max_iter_gmm, convergence_threshold=prior)
# Test results # Test results
# Load torch3vision reference # Load torch3vision reference
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment