Commit 1eab675d authored by Manuel Günther's avatar Manuel Günther
Browse files

Added rng to IvectorTrainer initialize method

parent 7dbafb34
......@@ -121,6 +121,10 @@ class IVectorTrainer
{ bob::core::array::assertSameShape(acc, m_acc_Snormij);
m_acc_Snormij = acc; }
void setRng(boost::shared_ptr<boost::mt19937> rng){
m_rng = rng;
};
protected:
// Attributes
bool m_update_sigma;
......@@ -140,11 +144,11 @@ class IVectorTrainer
mutable blitz::Array<double,2> m_tmp_dt1;
mutable blitz::Array<double,2> m_tmp_tt1;
mutable blitz::Array<double,2> m_tmp_tt2;
/**
* @brief The random number generator for the inialization
*/
boost::shared_ptr<boost::mt19937> m_rng;
boost::shared_ptr<boost::mt19937> m_rng;
};
} } } // namespaces
......
......@@ -313,9 +313,10 @@ static auto initialize = bob::extension::FunctionDoc(
"",
true
)
.add_prototype("ivector_machine, stats")
.add_parameter("ivector_machine", ":py:class:`bob.learn.em.ISVBase`", "IVectorMachine Object")
.add_parameter("stats", ":py:class:`bob.learn.em.GMMStats`", "Ignored");
.add_prototype("ivector_machine, [stats], [rng]")
.add_parameter("ivector_machine", ":py:class:`bob.learn.em.IVectorMachine`", "IVectorMachine Object")
.add_parameter("stats", ":py:class:`bob.learn.em.GMMStats`", "Ignored")
.add_parameter("rng", ":py:class:`bob.core.random.mt19937`", "The Mersenne Twister mt19937 random generator used for the initialization of subspaces/arrays before the EM loop.");
static PyObject* PyBobLearnEMIVectorTrainer_initialize(PyBobLearnEMIVectorTrainerObject* self, PyObject* args, PyObject* kwargs) {
BOB_TRY
......@@ -324,9 +325,15 @@ static PyObject* PyBobLearnEMIVectorTrainer_initialize(PyBobLearnEMIVectorTraine
PyBobLearnEMIVectorMachineObject* ivector_machine = 0;
PyObject* stats = 0;
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!|O!", kwlist, &PyBobLearnEMIVectorMachine_Type, &ivector_machine,
&PyList_Type, &stats)) return 0;
PyBoostMt19937Object* rng = 0;
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!|O!O!", kwlist, &PyBobLearnEMIVectorMachine_Type, &ivector_machine,
&PyList_Type, &stats,
&PyBoostMt19937_Type, &rng)) return 0;
if(rng){
boost::shared_ptr<boost::mt19937> rng_cpy = (boost::shared_ptr<boost::mt19937>)new boost::mt19937(*rng->rng);
self->cxx->setRng(rng_cpy);
}
self->cxx->initialize(*ivector_machine->cxx);
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment