From 8b45dc7037f970481cf39104508662f2795594cc Mon Sep 17 00:00:00 2001 From: Tiago Freitas Pereira <tiagofrepereira@gmail.com> Date: Thu, 5 Mar 2015 14:19:30 +0100 Subject: [PATCH] Fixed the documentation and added the random number generator as a parameter in the general trainer --- bob/learn/em/empca_trainer.cpp | 4 ++-- bob/learn/em/plda_trainer.cpp | 2 +- bob/learn/em/train.py | 21 +++++++++++++-------- doc/py_api.rst | 13 +++++++++---- 4 files changed, 25 insertions(+), 15 deletions(-) diff --git a/bob/learn/em/empca_trainer.cpp b/bob/learn/em/empca_trainer.cpp index a24af2f..8c74b43 100644 --- a/bob/learn/em/empca_trainer.cpp +++ b/bob/learn/em/empca_trainer.cpp @@ -14,7 +14,7 @@ /******************************************************************/ static auto EMPCATrainer_doc = bob::extension::ClassDoc( - BOB_EXT_MODULE_PREFIX "._EMPCATrainer", + BOB_EXT_MODULE_PREFIX ".EMPCATrainer", "" ).add_constructor( @@ -338,6 +338,6 @@ bool init_BobLearnEMEMPCATrainer(PyObject* module) // add the type to the module Py_INCREF(&PyBobLearnEMEMPCATrainer_Type); - return PyModule_AddObject(module, "_EMPCATrainer", (PyObject*)&PyBobLearnEMEMPCATrainer_Type) >= 0; + return PyModule_AddObject(module, "EMPCATrainer", (PyObject*)&PyBobLearnEMEMPCATrainer_Type) >= 0; } diff --git a/bob/learn/em/plda_trainer.cpp b/bob/learn/em/plda_trainer.cpp index c3a1e18..cfc25d6 100644 --- a/bob/learn/em/plda_trainer.cpp +++ b/bob/learn/em/plda_trainer.cpp @@ -287,7 +287,7 @@ static auto init_g_method = bob::extension::VariableDoc( "init_g_method", "str", "The method used for the initialization of :math:`$G$`.", - "Possible values are: ('RANDOM_G', 'BETWEEN_SCATTER')" + "Possible values are: ('RANDOM_G', 'WITHIN_SCATTER')" ); PyObject* PyBobLearnEMPLDATrainer_getGMethod(PyBobLearnEMPLDATrainerObject* self, void*) { BOB_TRY diff --git a/bob/learn/em/train.py b/bob/learn/em/train.py index 0c4e5ac..1fe9a42 100644 --- a/bob/learn/em/train.py +++ b/bob/learn/em/train.py @@ -7,27 +7,32 @@ import numpy import bob.learn.em -def train(trainer, machine, data, max_iterations = 50, convergence_threshold=None, initialize=True): +def train(trainer, machine, data, max_iterations = 50, convergence_threshold=None, initialize=True, rng=None): """ Trains a machine given a trainer and the proper data **Parameters**: - trainer + trainer : one of :py:class:`KMeansTrainer`, :py:class:`MAP_GMMTrainer`, :py:class:`ML_GMMTrainer`, :py:class:`ISVTrainer`, :py:class:`IVectorTrainer`, :py:class:`PLDATrainer`, :py:class:`EMPCATrainer` A trainer mechanism - machine + machine : one of :py:class:`KMeansMachine`, :py:class:`GMMMachine`, :py:class:`ISVBase`, :py:class:`IVectorMachine`, :py:class:`PLDAMachine`, :py:class:`bob.learn.linear.Machine` A container machine - data + data : array_like <float, 2D> The data to be trained - max_iterations + max_iterations : int The maximum number of iterations to train a machine - convergence_threshold + convergence_threshold : float The convergence threshold to train a machine. If None, the training procedure will stop with the iterations criteria - initialize + initialize : bool If True, runs the initialization procedure + rng : :py:class:`bob.core.random.mt19937` + The Mersenne Twister mt19937 random generator used for the initialization of subspaces/arrays before the EM loop """ #Initialization if initialize: - trainer.initialize(machine, data) + if rng is not None: + trainer.initialize(machine, data, rng) + else: + trainer.initialize(machine, data) trainer.eStep(machine, data) average_output = 0 diff --git a/doc/py_api.rst b/doc/py_api.rst index 3585d93..9bf055a 100644 --- a/doc/py_api.rst +++ b/doc/py_api.rst @@ -12,9 +12,11 @@ This section includes information for using the pure Python API of Classes ------- + +Trainers +........ + .. autosummary:: - Trainers - -------- bob.learn.em.KMeansTrainer bob.learn.em.ML_GMMTrainer @@ -23,9 +25,12 @@ Classes bob.learn.em.JFATrainer bob.learn.em.IVectorTrainer bob.learn.em.PLDATrainer + bob.learn.em.EMPCATrainer - Machines - -------- +Machines +........ + +.. autosummary:: bob.learn.em.KMeansMachine bob.learn.em.Gaussian -- GitLab