Skip to content
Snippets Groups Projects
Commit 1a3ac8d1 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

Also test randomization using the same seed for equality

parent b768b2f2
No related branches found
No related tags found
No related merge requests found
...@@ -976,7 +976,7 @@ static PyObject* PyBobLearnMLPMachine_Randomize ...@@ -976,7 +976,7 @@ static PyObject* PyBobLearnMLPMachine_Randomize
PyBoostMt19937Object* rng = 0; PyBoostMt19937Object* rng = 0;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "|ddO!", kwlist, if (!PyArg_ParseTupleAndKeywords(args, kwds, "|ddO!", kwlist,
&lower_bound, &upper_bound, &PyBoostMt19937_Check, &rng)) return 0; &lower_bound, &upper_bound, &PyBoostMt19937_Type, &rng)) return 0;
if (rng) { if (rng) {
self->cxx->randomize(*rng->rng, lower_bound, upper_bound); self->cxx->randomize(*rng->rng, lower_bound, upper_bound);
......
...@@ -18,6 +18,7 @@ from .test_utils import Machine as PythonMachine ...@@ -18,6 +18,7 @@ from .test_utils import Machine as PythonMachine
import xbob.io import xbob.io
from xbob.io.test_utils import temporary_filename from xbob.io.test_utils import temporary_filename
from xbob.learn.activation import Logistic, HyperbolicTangent from xbob.learn.activation import Logistic, HyperbolicTangent
from xbob.core.random import mt19937
def test_2in_1out(): def test_2in_1out():
...@@ -265,18 +266,37 @@ def test_randomization_margins(): ...@@ -265,18 +266,37 @@ def test_randomization_margins():
assert (abs(k) <= 0.001).all() assert (abs(k) <= 0.001).all()
assert (k != 0).any() assert (k != 0).any()
def test_randomness(): def test_randomness_different():
m1 = Machine((2,3,2)) m1 = Machine((2,3,2))
m1.randomize() m1.randomize()
for k in range(10): for k in range(3):
time.sleep(0.1) time.sleep(0.1)
m2 = Machine((2,3,2)) m2 = Machine((2,3,2))
m2.randomize() m2.randomize()
for w1, w2 in zip(m1.weights, m2.weights): for w1, w2 in zip(m1.weights, m2.weights):
nose.tools.eq_((w1 == w2).all(), False) assert not (w1 == w2).all()
for b1, b2 in zip(m1.biases, m2.biases):
assert not (b1 == b2).all()
def test_randomness_same():
m1 = Machine((2,3,2))
rng = xbob.core.random.mt19937(0) #fixed seed
m1.randomize(rng=rng)
for k in range(3):
time.sleep(0.1)
m2 = Machine((2,3,2))
rng = xbob.core.random.mt19937(0) #fixed seed
m2.randomize(rng=rng)
for w1, w2 in zip(m1.weights, m2.weights):
assert (w1 == w2).all()
for b1, b2 in zip(m1.biases, m2.biases): for b1, b2 in zip(m1.biases, m2.biases):
nose.tools.eq_((b1 == b2).all(), False) assert (b1 == b2).all()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment