Commit 47a7c3b4 authored by Manuel Günther's avatar Manuel Günther
Browse files

Adapted RNG C-interface to new bob.core interface

parent 8823b15f
......@@ -166,7 +166,7 @@ static auto initialize = bob::extension::FunctionDoc(
"",
true
)
.add_prototype("linear_machine,data")
.add_prototype("linear_machine, data, [rng]")
.add_parameter("linear_machine", ":py:class:`bob.learn.linear.Machine`", "LinearMachine Object")
.add_parameter("data", "array_like <float, 2D>", "Input data")
.add_parameter("rng", ":py:class:`bob.core.random.mt19937`", "The Mersenne Twister mt19937 random generator used for the initialization of subspaces/arrays before the EM loop.");
......@@ -186,8 +186,7 @@ static PyObject* PyBobLearnEMEMPCATrainer_initialize(PyBobLearnEMEMPCATrainerObj
auto data_ = make_safe(data);
if(rng){
boost::shared_ptr<boost::mt19937> rng_cpy = (boost::shared_ptr<boost::mt19937>)new boost::mt19937(*rng->rng);
self->cxx->setRng(rng_cpy);
self->cxx->setRng(rng->rng);
}
......
......@@ -409,7 +409,7 @@ static auto initialize = bob::extension::FunctionDoc(
"",
true
)
.add_prototype("isv_base, stats, rng")
.add_prototype("isv_base, stats, [rng]")
.add_parameter("isv_base", ":py:class:`bob.learn.em.ISVBase`", "ISVBase Object")
.add_parameter("stats", ":py:class:`bob.learn.em.GMMStats`", "GMMStats Object")
.add_parameter("rng", ":py:class:`bob.core.random.mt19937`", "The Mersenne Twister mt19937 random generator used for the initialization of subspaces/arrays before the EM loop.");
......@@ -428,8 +428,7 @@ static PyObject* PyBobLearnEMISVTrainer_initialize(PyBobLearnEMISVTrainerObject*
&PyBoostMt19937_Type, &rng)) return 0;
if(rng){
boost::shared_ptr<boost::mt19937> rng_cpy = (boost::shared_ptr<boost::mt19937>)new boost::mt19937(*rng->rng);
self->cxx->setRng(rng_cpy);
self->cxx->setRng(rng->rng);
}
std::vector<std::vector<boost::shared_ptr<bob::learn::em::GMMStats> > > training_data;
......
......@@ -332,8 +332,7 @@ static PyObject* PyBobLearnEMIVectorTrainer_initialize(PyBobLearnEMIVectorTraine
&PyBoostMt19937_Type, &rng)) return 0;
if(rng){
boost::shared_ptr<boost::mt19937> rng_cpy = (boost::shared_ptr<boost::mt19937>)new boost::mt19937(*rng->rng);
self->cxx->setRng(rng_cpy);
self->cxx->setRng(rng->rng);
}
self->cxx->initialize(*ivector_machine->cxx);
......
......@@ -620,7 +620,7 @@ static auto initialize = bob::extension::FunctionDoc(
"",
true
)
.add_prototype("jfa_base,stats,rng")
.add_prototype("jfa_base, stats, [rng]")
.add_parameter("jfa_base", ":py:class:`bob.learn.em.JFABase`", "JFABase Object")
.add_parameter("stats", ":py:class:`bob.learn.em.GMMStats`", "GMMStats Object")
.add_parameter("rng", ":py:class:`bob.core.random.mt19937`", "The Mersenne Twister mt19937 random generator used for the initialization of subspaces/arrays before the EM loop.");
......@@ -639,8 +639,7 @@ static PyObject* PyBobLearnEMJFATrainer_initialize(PyBobLearnEMJFATrainerObject*
&PyBoostMt19937_Type, &rng)) return 0;
if(rng){
boost::shared_ptr<boost::mt19937> rng_cpy = (boost::shared_ptr<boost::mt19937>)new boost::mt19937(*rng->rng);
self->cxx->setRng(rng_cpy);
self->cxx->setRng(rng->rng);
}
std::vector<std::vector<boost::shared_ptr<bob::learn::em::GMMStats> > > training_data;
......
......@@ -313,7 +313,7 @@ static auto initialize = bob::extension::FunctionDoc(
"Data is split into as many chunks as there are means, then each mean is set to a random example within each chunk.",
true
)
.add_prototype("kmeans_machine,data, rng")
.add_prototype("kmeans_machine, data, [rng]")
.add_parameter("kmeans_machine", ":py:class:`bob.learn.em.KMeansMachine`", "KMeansMachine Object")
.add_parameter("data", "array_like <float, 2D>", "Input data")
.add_parameter("rng", ":py:class:`bob.core.random.mt19937`", "The Mersenne Twister mt19937 random generator used for the initialization of subspaces/arrays before the EM loop.");
......@@ -349,8 +349,7 @@ static PyObject* PyBobLearnEMKMeansTrainer_initialize(PyBobLearnEMKMeansTrainerO
}
if(rng){
boost::shared_ptr<boost::mt19937> rng_cpy = (boost::shared_ptr<boost::mt19937>)new boost::mt19937(*rng->rng);
self->cxx->setRng(rng_cpy);
self->cxx->setRng(rng->rng);
}
self->cxx->initialize(*kmeans_machine->cxx, *PyBlitzArrayCxx_AsBlitz<double,2>(data));
......
......@@ -424,7 +424,7 @@ static auto initialize = bob::extension::FunctionDoc(
"",
true
)
.add_prototype("plda_base,data,rng")
.add_prototype("plda_base, data, [rng]")
.add_parameter("plda_base", ":py:class:`bob.learn.em.PLDABase`", "PLDAMachine Object")
.add_parameter("data", "list", "")
.add_parameter("rng", ":py:class:`bob.core.random.mt19937`", "The Mersenne Twister mt19937 random generator used for the initialization of subspaces/arrays before the EM loop.");
......@@ -445,8 +445,7 @@ static PyObject* PyBobLearnEMPLDATrainer_initialize(PyBobLearnEMPLDATrainerObjec
std::vector<blitz::Array<double,2> > data_vector;
if(list_as_vector(data ,data_vector)==0){
if(rng){
boost::shared_ptr<boost::mt19937> rng_cpy = (boost::shared_ptr<boost::mt19937>)new boost::mt19937(*rng->rng);
self->cxx->setRng(rng_cpy);
self->cxx->setRng(rng->rng);
}
self->cxx->initialize(*plda_base->cxx, data_vector);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment