Commit 8b45dc70 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Fixed the documentation and added the random number generator as a parameter in the general trainer

parent 09ac6e32
......@@ -14,7 +14,7 @@
/******************************************************************/
static auto EMPCATrainer_doc = bob::extension::ClassDoc(
BOB_EXT_MODULE_PREFIX "._EMPCATrainer",
BOB_EXT_MODULE_PREFIX ".EMPCATrainer",
""
).add_constructor(
......@@ -338,6 +338,6 @@ bool init_BobLearnEMEMPCATrainer(PyObject* module)
// add the type to the module
Py_INCREF(&PyBobLearnEMEMPCATrainer_Type);
return PyModule_AddObject(module, "_EMPCATrainer", (PyObject*)&PyBobLearnEMEMPCATrainer_Type) >= 0;
return PyModule_AddObject(module, "EMPCATrainer", (PyObject*)&PyBobLearnEMEMPCATrainer_Type) >= 0;
}
......@@ -287,7 +287,7 @@ static auto init_g_method = bob::extension::VariableDoc(
"init_g_method",
"str",
"The method used for the initialization of :math:`$G$`.",
"Possible values are: ('RANDOM_G', 'BETWEEN_SCATTER')"
"Possible values are: ('RANDOM_G', 'WITHIN_SCATTER')"
);
PyObject* PyBobLearnEMPLDATrainer_getGMethod(PyBobLearnEMPLDATrainerObject* self, void*) {
BOB_TRY
......
......@@ -7,27 +7,32 @@
import numpy
import bob.learn.em
def train(trainer, machine, data, max_iterations = 50, convergence_threshold=None, initialize=True):
def train(trainer, machine, data, max_iterations = 50, convergence_threshold=None, initialize=True, rng=None):
"""
Trains a machine given a trainer and the proper data
**Parameters**:
trainer
trainer : one of :py:class:`KMeansTrainer`, :py:class:`MAP_GMMTrainer`, :py:class:`ML_GMMTrainer`, :py:class:`ISVTrainer`, :py:class:`IVectorTrainer`, :py:class:`PLDATrainer`, :py:class:`EMPCATrainer`
A trainer mechanism
machine
machine : one of :py:class:`KMeansMachine`, :py:class:`GMMMachine`, :py:class:`ISVBase`, :py:class:`IVectorMachine`, :py:class:`PLDAMachine`, :py:class:`bob.learn.linear.Machine`
A container machine
data
data : array_like <float, 2D>
The data to be trained
max_iterations
max_iterations : int
The maximum number of iterations to train a machine
convergence_threshold
convergence_threshold : float
The convergence threshold to train a machine. If None, the training procedure will stop with the iterations criteria
initialize
initialize : bool
If True, runs the initialization procedure
rng : :py:class:`bob.core.random.mt19937`
The Mersenne Twister mt19937 random generator used for the initialization of subspaces/arrays before the EM loop
"""
#Initialization
if initialize:
trainer.initialize(machine, data)
if rng is not None:
trainer.initialize(machine, data, rng)
else:
trainer.initialize(machine, data)
trainer.eStep(machine, data)
average_output = 0
......
......@@ -12,9 +12,11 @@ This section includes information for using the pure Python API of
Classes
-------
Trainers
........
.. autosummary::
Trainers
--------
bob.learn.em.KMeansTrainer
bob.learn.em.ML_GMMTrainer
......@@ -23,9 +25,12 @@ Classes
bob.learn.em.JFATrainer
bob.learn.em.IVectorTrainer
bob.learn.em.PLDATrainer
bob.learn.em.EMPCATrainer
Machines
--------
Machines
........
.. autosummary::
bob.learn.em.KMeansMachine
bob.learn.em.Gaussian
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment