From 1a3ac8d13831b9c8296324dc1ebe8ae08bb6b159 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Tue, 29 Apr 2014 09:13:24 +0200
Subject: [PATCH] Also test randomization using the same seed for equality

---
 xbob/learn/mlp/machine.cpp     |  2 +-
 xbob/learn/mlp/test_machine.py | 28 ++++++++++++++++++++++++----
 2 files changed, 25 insertions(+), 5 deletions(-)

diff --git a/xbob/learn/mlp/machine.cpp b/xbob/learn/mlp/machine.cpp
index ed3672d..b6ba2b1 100644
--- a/xbob/learn/mlp/machine.cpp
+++ b/xbob/learn/mlp/machine.cpp
@@ -976,7 +976,7 @@ static PyObject* PyBobLearnMLPMachine_Randomize
   PyBoostMt19937Object* rng = 0;
 
   if (!PyArg_ParseTupleAndKeywords(args, kwds, "|ddO!", kwlist,
-        &lower_bound, &upper_bound, &PyBoostMt19937_Check, &rng)) return 0;
+        &lower_bound, &upper_bound, &PyBoostMt19937_Type, &rng)) return 0;
 
   if (rng) {
     self->cxx->randomize(*rng->rng, lower_bound, upper_bound);
diff --git a/xbob/learn/mlp/test_machine.py b/xbob/learn/mlp/test_machine.py
index f8a91f4..aceef16 100644
--- a/xbob/learn/mlp/test_machine.py
+++ b/xbob/learn/mlp/test_machine.py
@@ -18,6 +18,7 @@ from .test_utils import Machine as PythonMachine
 import xbob.io
 from xbob.io.test_utils import temporary_filename
 from xbob.learn.activation import Logistic, HyperbolicTangent
+from xbob.core.random import mt19937
 
 def test_2in_1out():
 
@@ -265,18 +266,37 @@ def test_randomization_margins():
       assert (abs(k) <= 0.001).all()
       assert (k != 0).any()
 
-def test_randomness():
+def test_randomness_different():
 
   m1 = Machine((2,3,2))
   m1.randomize()
 
-  for k in range(10):
+  for k in range(3):
     time.sleep(0.1)
     m2 = Machine((2,3,2))
     m2.randomize()
 
     for w1, w2 in zip(m1.weights, m2.weights):
-      nose.tools.eq_((w1 == w2).all(), False)
+      assert not (w1 == w2).all()
+
+    for b1, b2 in zip(m1.biases, m2.biases):
+      assert not (b1 == b2).all()
+
+def test_randomness_same():
+
+  m1 = Machine((2,3,2))
+  rng = xbob.core.random.mt19937(0) #fixed seed
+  m1.randomize(rng=rng)
+
+  for k in range(3):
+    time.sleep(0.1)
+
+    m2 = Machine((2,3,2))
+    rng = xbob.core.random.mt19937(0) #fixed seed
+    m2.randomize(rng=rng)
+
+    for w1, w2 in zip(m1.weights, m2.weights):
+      assert (w1 == w2).all()
 
     for b1, b2 in zip(m1.biases, m2.biases):
-      nose.tools.eq_((b1 == b2).all(), False)
+      assert (b1 == b2).all()
-- 
GitLab