From 1eab675dbd5142ab50fb870d062a280cc2e08d62 Mon Sep 17 00:00:00 2001 From: Manuel Guenther <manuel.guenther@idiap.ch> Date: Fri, 6 Mar 2015 12:05:33 +0100 Subject: [PATCH] Added rng to IvectorTrainer initialize method --- .../em/include/bob.learn.em/IVectorTrainer.h | 8 ++++++-- bob/learn/em/ivector_trainer.cpp | 19 +++++++++++++------ 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/bob/learn/em/include/bob.learn.em/IVectorTrainer.h b/bob/learn/em/include/bob.learn.em/IVectorTrainer.h index 4d90627..4f18226 100644 --- a/bob/learn/em/include/bob.learn.em/IVectorTrainer.h +++ b/bob/learn/em/include/bob.learn.em/IVectorTrainer.h @@ -121,6 +121,10 @@ class IVectorTrainer { bob::core::array::assertSameShape(acc, m_acc_Snormij); m_acc_Snormij = acc; } + void setRng(boost::shared_ptr<boost::mt19937> rng){ + m_rng = rng; + }; + protected: // Attributes bool m_update_sigma; @@ -140,11 +144,11 @@ class IVectorTrainer mutable blitz::Array<double,2> m_tmp_dt1; mutable blitz::Array<double,2> m_tmp_tt1; mutable blitz::Array<double,2> m_tmp_tt2; - + /** * @brief The random number generator for the inialization */ - boost::shared_ptr<boost::mt19937> m_rng; + boost::shared_ptr<boost::mt19937> m_rng; }; } } } // namespaces diff --git a/bob/learn/em/ivector_trainer.cpp b/bob/learn/em/ivector_trainer.cpp index bb8aa28..72c2e40 100644 --- a/bob/learn/em/ivector_trainer.cpp +++ b/bob/learn/em/ivector_trainer.cpp @@ -313,9 +313,10 @@ static auto initialize = bob::extension::FunctionDoc( "", true ) -.add_prototype("ivector_machine, stats") -.add_parameter("ivector_machine", ":py:class:`bob.learn.em.ISVBase`", "IVectorMachine Object") -.add_parameter("stats", ":py:class:`bob.learn.em.GMMStats`", "Ignored"); +.add_prototype("ivector_machine, [stats], [rng]") +.add_parameter("ivector_machine", ":py:class:`bob.learn.em.IVectorMachine`", "IVectorMachine Object") +.add_parameter("stats", ":py:class:`bob.learn.em.GMMStats`", "Ignored") +.add_parameter("rng", ":py:class:`bob.core.random.mt19937`", "The Mersenne Twister mt19937 random generator used for the initialization of subspaces/arrays before the EM loop."); static PyObject* PyBobLearnEMIVectorTrainer_initialize(PyBobLearnEMIVectorTrainerObject* self, PyObject* args, PyObject* kwargs) { BOB_TRY @@ -324,9 +325,15 @@ static PyObject* PyBobLearnEMIVectorTrainer_initialize(PyBobLearnEMIVectorTraine PyBobLearnEMIVectorMachineObject* ivector_machine = 0; PyObject* stats = 0; - - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!|O!", kwlist, &PyBobLearnEMIVectorMachine_Type, &ivector_machine, - &PyList_Type, &stats)) return 0; + PyBoostMt19937Object* rng = 0; + + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!|O!O!", kwlist, &PyBobLearnEMIVectorMachine_Type, &ivector_machine, + &PyList_Type, &stats, + &PyBoostMt19937_Type, &rng)) return 0; + if(rng){ + boost::shared_ptr<boost::mt19937> rng_cpy = (boost::shared_ptr<boost::mt19937>)new boost::mt19937(*rng->rng); + self->cxx->setRng(rng_cpy); + } self->cxx->initialize(*ivector_machine->cxx); -- GitLab