From 1eab675dbd5142ab50fb870d062a280cc2e08d62 Mon Sep 17 00:00:00 2001
From: Manuel Guenther <manuel.guenther@idiap.ch>
Date: Fri, 6 Mar 2015 12:05:33 +0100
Subject: [PATCH] Added rng to IvectorTrainer initialize method

---
 .../em/include/bob.learn.em/IVectorTrainer.h  |  8 ++++++--
 bob/learn/em/ivector_trainer.cpp              | 19 +++++++++++++------
 2 files changed, 19 insertions(+), 8 deletions(-)

diff --git a/bob/learn/em/include/bob.learn.em/IVectorTrainer.h b/bob/learn/em/include/bob.learn.em/IVectorTrainer.h
index 4d90627..4f18226 100644
--- a/bob/learn/em/include/bob.learn.em/IVectorTrainer.h
+++ b/bob/learn/em/include/bob.learn.em/IVectorTrainer.h
@@ -121,6 +121,10 @@ class IVectorTrainer
     { bob::core::array::assertSameShape(acc, m_acc_Snormij);
       m_acc_Snormij = acc; }
 
+    void setRng(boost::shared_ptr<boost::mt19937> rng){
+      m_rng = rng;
+    };
+
   protected:
     // Attributes
     bool m_update_sigma;
@@ -140,11 +144,11 @@ class IVectorTrainer
     mutable blitz::Array<double,2> m_tmp_dt1;
     mutable blitz::Array<double,2> m_tmp_tt1;
     mutable blitz::Array<double,2> m_tmp_tt2;
-    
+
     /**
      * @brief The random number generator for the inialization
      */
-    boost::shared_ptr<boost::mt19937> m_rng;    
+    boost::shared_ptr<boost::mt19937> m_rng;
 };
 
 } } } // namespaces
diff --git a/bob/learn/em/ivector_trainer.cpp b/bob/learn/em/ivector_trainer.cpp
index bb8aa28..72c2e40 100644
--- a/bob/learn/em/ivector_trainer.cpp
+++ b/bob/learn/em/ivector_trainer.cpp
@@ -313,9 +313,10 @@ static auto initialize = bob::extension::FunctionDoc(
   "",
   true
 )
-.add_prototype("ivector_machine, stats")
-.add_parameter("ivector_machine", ":py:class:`bob.learn.em.ISVBase`", "IVectorMachine Object")
-.add_parameter("stats", ":py:class:`bob.learn.em.GMMStats`", "Ignored");
+.add_prototype("ivector_machine, [stats], [rng]")
+.add_parameter("ivector_machine", ":py:class:`bob.learn.em.IVectorMachine`", "IVectorMachine Object")
+.add_parameter("stats", ":py:class:`bob.learn.em.GMMStats`", "Ignored")
+.add_parameter("rng", ":py:class:`bob.core.random.mt19937`", "The Mersenne Twister mt19937 random generator used for the initialization of subspaces/arrays before the EM loop.");
 static PyObject* PyBobLearnEMIVectorTrainer_initialize(PyBobLearnEMIVectorTrainerObject* self, PyObject* args, PyObject* kwargs) {
   BOB_TRY
 
@@ -324,9 +325,15 @@ static PyObject* PyBobLearnEMIVectorTrainer_initialize(PyBobLearnEMIVectorTraine
 
   PyBobLearnEMIVectorMachineObject* ivector_machine = 0;
   PyObject* stats = 0;
-
-  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!|O!", kwlist, &PyBobLearnEMIVectorMachine_Type, &ivector_machine,
-                                                                 &PyList_Type, &stats)) return 0;
+  PyBoostMt19937Object* rng = 0;
+
+  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!|O!O!", kwlist, &PyBobLearnEMIVectorMachine_Type, &ivector_machine,
+                                                                    &PyList_Type, &stats,
+                                                                    &PyBoostMt19937_Type, &rng)) return 0;
+  if(rng){
+    boost::shared_ptr<boost::mt19937> rng_cpy = (boost::shared_ptr<boost::mt19937>)new boost::mt19937(*rng->rng);
+    self->cxx->setRng(rng_cpy);
+  }
 
   self->cxx->initialize(*ivector_machine->cxx);
 
-- 
GitLab