Skip to content
Snippets Groups Projects
Commit 1eab675d authored by Manuel Günther's avatar Manuel Günther
Browse files

Added rng to IvectorTrainer initialize method

parent 7dbafb34
No related branches found
No related tags found
No related merge requests found
......@@ -121,6 +121,10 @@ class IVectorTrainer
{ bob::core::array::assertSameShape(acc, m_acc_Snormij);
m_acc_Snormij = acc; }
void setRng(boost::shared_ptr<boost::mt19937> rng){
m_rng = rng;
};
protected:
// Attributes
bool m_update_sigma;
......@@ -140,11 +144,11 @@ class IVectorTrainer
mutable blitz::Array<double,2> m_tmp_dt1;
mutable blitz::Array<double,2> m_tmp_tt1;
mutable blitz::Array<double,2> m_tmp_tt2;
/**
* @brief The random number generator for the inialization
*/
boost::shared_ptr<boost::mt19937> m_rng;
boost::shared_ptr<boost::mt19937> m_rng;
};
} } } // namespaces
......
......@@ -313,9 +313,10 @@ static auto initialize = bob::extension::FunctionDoc(
"",
true
)
.add_prototype("ivector_machine, stats")
.add_parameter("ivector_machine", ":py:class:`bob.learn.em.ISVBase`", "IVectorMachine Object")
.add_parameter("stats", ":py:class:`bob.learn.em.GMMStats`", "Ignored");
.add_prototype("ivector_machine, [stats], [rng]")
.add_parameter("ivector_machine", ":py:class:`bob.learn.em.IVectorMachine`", "IVectorMachine Object")
.add_parameter("stats", ":py:class:`bob.learn.em.GMMStats`", "Ignored")
.add_parameter("rng", ":py:class:`bob.core.random.mt19937`", "The Mersenne Twister mt19937 random generator used for the initialization of subspaces/arrays before the EM loop.");
static PyObject* PyBobLearnEMIVectorTrainer_initialize(PyBobLearnEMIVectorTrainerObject* self, PyObject* args, PyObject* kwargs) {
BOB_TRY
......@@ -324,9 +325,15 @@ static PyObject* PyBobLearnEMIVectorTrainer_initialize(PyBobLearnEMIVectorTraine
PyBobLearnEMIVectorMachineObject* ivector_machine = 0;
PyObject* stats = 0;
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!|O!", kwlist, &PyBobLearnEMIVectorMachine_Type, &ivector_machine,
&PyList_Type, &stats)) return 0;
PyBoostMt19937Object* rng = 0;
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!|O!O!", kwlist, &PyBobLearnEMIVectorMachine_Type, &ivector_machine,
&PyList_Type, &stats,
&PyBoostMt19937_Type, &rng)) return 0;
if(rng){
boost::shared_ptr<boost::mt19937> rng_cpy = (boost::shared_ptr<boost::mt19937>)new boost::mt19937(*rng->rng);
self->cxx->setRng(rng_cpy);
}
self->cxx->initialize(*ivector_machine->cxx);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment