Skip to content
Snippets Groups Projects
Commit 1a3ac8d1 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

Also test randomization using the same seed for equality

parent b768b2f2
No related branches found
No related tags found
No related merge requests found
......@@ -976,7 +976,7 @@ static PyObject* PyBobLearnMLPMachine_Randomize
PyBoostMt19937Object* rng = 0;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "|ddO!", kwlist,
&lower_bound, &upper_bound, &PyBoostMt19937_Check, &rng)) return 0;
&lower_bound, &upper_bound, &PyBoostMt19937_Type, &rng)) return 0;
if (rng) {
self->cxx->randomize(*rng->rng, lower_bound, upper_bound);
......
......@@ -18,6 +18,7 @@ from .test_utils import Machine as PythonMachine
import xbob.io
from xbob.io.test_utils import temporary_filename
from xbob.learn.activation import Logistic, HyperbolicTangent
from xbob.core.random import mt19937
def test_2in_1out():
......@@ -265,18 +266,37 @@ def test_randomization_margins():
assert (abs(k) <= 0.001).all()
assert (k != 0).any()
def test_randomness():
def test_randomness_different():
m1 = Machine((2,3,2))
m1.randomize()
for k in range(10):
for k in range(3):
time.sleep(0.1)
m2 = Machine((2,3,2))
m2.randomize()
for w1, w2 in zip(m1.weights, m2.weights):
nose.tools.eq_((w1 == w2).all(), False)
assert not (w1 == w2).all()
for b1, b2 in zip(m1.biases, m2.biases):
assert not (b1 == b2).all()
def test_randomness_same():
m1 = Machine((2,3,2))
rng = xbob.core.random.mt19937(0) #fixed seed
m1.randomize(rng=rng)
for k in range(3):
time.sleep(0.1)
m2 = Machine((2,3,2))
rng = xbob.core.random.mt19937(0) #fixed seed
m2.randomize(rng=rng)
for w1, w2 in zip(m1.weights, m2.weights):
assert (w1 == w2).all()
for b1, b2 in zip(m1.biases, m2.biases):
nose.tools.eq_((b1 == b2).all(), False)
assert (b1 == b2).all()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment