Commit dce267f6 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

Merge branch...

Merge branch '25-ml_gmmtrainer-initialize-does-not-accept-a-random-generator-argument' into 'master'

Resolve "ML_GMMTrainer.initialize() does not accept a random generator argument"

Closes #25

See merge request !32
parents 262d7b62 c7350c6b
Pipeline #11934 passed with stages
in 19 minutes and 7 seconds
......@@ -55,6 +55,12 @@ def test_gmm_ML_1():
ar = bob.io.base.load(datafile("faithful.torch3_f64.hdf5", __name__, path="../data/"))
gmm = loadGMM()
# test rng handling
ml_gmmtrainer = ML_GMMTrainer(True, True, True)
rng = bob.core.random.mt19937(12345)
bob.learn.em.train(ml_gmmtrainer, gmm, ar, convergence_threshold=0.001, rng=rng)
gmm = loadGMM()
ml_gmmtrainer = ML_GMMTrainer(True, True, True)
#ml_gmmtrainer.train(gmm, ar)
bob.learn.em.train(ml_gmmtrainer, gmm, ar, convergence_threshold=0.001)
......@@ -114,6 +120,13 @@ def test_gmm_MAP_1():
ar = bob.io.base.load(datafile('faithful.torch3_f64.hdf5', __name__, path="../data/"))
# test with rng
rng = bob.core.random.mt19937(12345)
gmm = GMMMachine(bob.io.base.HDF5File(datafile("gmm_ML.hdf5", __name__, path="../data/")))
gmmprior = GMMMachine(bob.io.base.HDF5File(datafile("gmm_ML.hdf5", __name__, path="../data/")))
map_gmmtrainer = MAP_GMMTrainer(update_means=True, update_variances=False, update_weights=False, prior_gmm=gmmprior, relevance_factor=4.)
bob.learn.em.train(map_gmmtrainer, gmm, ar, rng=rng)
gmm = GMMMachine(bob.io.base.HDF5File(datafile("gmm_ML.hdf5", __name__, path="../data/")))
gmmprior = GMMMachine(bob.io.base.HDF5File(datafile("gmm_ML.hdf5", __name__, path="../data/")))
......@@ -253,9 +266,9 @@ def test_custom_trainer():
for i in range(0, 2):
assert (ar[i+1] == machine.means[i, :]).all()
def test_EMPCA():
# Tests our Probabilistic PCA trainer for linear machines for a simple
......@@ -294,5 +307,5 @@ def test_EMPCA():
T.e_step(m, ar)
T.m_step(m, ar)
llh2 = T.compute_likelihood(m)
assert abs(exp_llh2 - llh2) < 2e-4
assert abs(exp_llh2 - llh2) < 2e-4
......@@ -45,7 +45,9 @@ def train(trainer, machine, data, max_iterations=50, convergence_threshold=None,
# Initialization
if initialize:
if rng is not None:
if rng is not None and \
(not isinstance(trainer, (bob.learn.em.ML_GMMTrainer,
bob.learn.em.MAP_GMMTrainer))):
trainer.initialize(machine, data, rng)
else:
trainer.initialize(machine, data)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment