Skip to content
Snippets Groups Projects
Commit 8b45dc70 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Fixed the documentation and added the random number generator as a parameter in the general trainer

parent 09ac6e32
No related branches found
No related tags found
No related merge requests found
......@@ -14,7 +14,7 @@
/******************************************************************/
static auto EMPCATrainer_doc = bob::extension::ClassDoc(
BOB_EXT_MODULE_PREFIX "._EMPCATrainer",
BOB_EXT_MODULE_PREFIX ".EMPCATrainer",
""
).add_constructor(
......@@ -338,6 +338,6 @@ bool init_BobLearnEMEMPCATrainer(PyObject* module)
// add the type to the module
Py_INCREF(&PyBobLearnEMEMPCATrainer_Type);
return PyModule_AddObject(module, "_EMPCATrainer", (PyObject*)&PyBobLearnEMEMPCATrainer_Type) >= 0;
return PyModule_AddObject(module, "EMPCATrainer", (PyObject*)&PyBobLearnEMEMPCATrainer_Type) >= 0;
}
......@@ -287,7 +287,7 @@ static auto init_g_method = bob::extension::VariableDoc(
"init_g_method",
"str",
"The method used for the initialization of :math:`$G$`.",
"Possible values are: ('RANDOM_G', 'BETWEEN_SCATTER')"
"Possible values are: ('RANDOM_G', 'WITHIN_SCATTER')"
);
PyObject* PyBobLearnEMPLDATrainer_getGMethod(PyBobLearnEMPLDATrainerObject* self, void*) {
BOB_TRY
......
......@@ -7,27 +7,32 @@
import numpy
import bob.learn.em
def train(trainer, machine, data, max_iterations = 50, convergence_threshold=None, initialize=True):
def train(trainer, machine, data, max_iterations = 50, convergence_threshold=None, initialize=True, rng=None):
"""
Trains a machine given a trainer and the proper data
**Parameters**:
trainer
trainer : one of :py:class:`KMeansTrainer`, :py:class:`MAP_GMMTrainer`, :py:class:`ML_GMMTrainer`, :py:class:`ISVTrainer`, :py:class:`IVectorTrainer`, :py:class:`PLDATrainer`, :py:class:`EMPCATrainer`
A trainer mechanism
machine
machine : one of :py:class:`KMeansMachine`, :py:class:`GMMMachine`, :py:class:`ISVBase`, :py:class:`IVectorMachine`, :py:class:`PLDAMachine`, :py:class:`bob.learn.linear.Machine`
A container machine
data
data : array_like <float, 2D>
The data to be trained
max_iterations
max_iterations : int
The maximum number of iterations to train a machine
convergence_threshold
convergence_threshold : float
The convergence threshold to train a machine. If None, the training procedure will stop with the iterations criteria
initialize
initialize : bool
If True, runs the initialization procedure
rng : :py:class:`bob.core.random.mt19937`
The Mersenne Twister mt19937 random generator used for the initialization of subspaces/arrays before the EM loop
"""
#Initialization
if initialize:
trainer.initialize(machine, data)
if rng is not None:
trainer.initialize(machine, data, rng)
else:
trainer.initialize(machine, data)
trainer.eStep(machine, data)
average_output = 0
......
......@@ -12,9 +12,11 @@ This section includes information for using the pure Python API of
Classes
-------
Trainers
........
.. autosummary::
Trainers
--------
bob.learn.em.KMeansTrainer
bob.learn.em.ML_GMMTrainer
......@@ -23,9 +25,12 @@ Classes
bob.learn.em.JFATrainer
bob.learn.em.IVectorTrainer
bob.learn.em.PLDATrainer
bob.learn.em.EMPCATrainer
Machines
--------
Machines
........
.. autosummary::
bob.learn.em.KMeansMachine
bob.learn.em.Gaussian
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment