diff --git a/bob/learn/em/test/test_em.py b/bob/learn/em/test/test_em.py index 9acd5ce0c1ee3edd749cd360f345f069c03e883f..0b524f21a3b22c761de9549ac1114aab666845da 100644 --- a/bob/learn/em/test/test_em.py +++ b/bob/learn/em/test/test_em.py @@ -55,6 +55,12 @@ def test_gmm_ML_1(): ar = bob.io.base.load(datafile("faithful.torch3_f64.hdf5", __name__, path="../data/")) gmm = loadGMM() + # test rng handling + ml_gmmtrainer = ML_GMMTrainer(True, True, True) + rng = bob.core.random.mt19937(12345) + bob.learn.em.train(ml_gmmtrainer, gmm, ar, convergence_threshold=0.001, rng=rng) + + gmm = loadGMM() ml_gmmtrainer = ML_GMMTrainer(True, True, True) #ml_gmmtrainer.train(gmm, ar) bob.learn.em.train(ml_gmmtrainer, gmm, ar, convergence_threshold=0.001) @@ -114,6 +120,13 @@ def test_gmm_MAP_1(): ar = bob.io.base.load(datafile('faithful.torch3_f64.hdf5', __name__, path="../data/")) + # test with rng + rng = bob.core.random.mt19937(12345) + gmm = GMMMachine(bob.io.base.HDF5File(datafile("gmm_ML.hdf5", __name__, path="../data/"))) + gmmprior = GMMMachine(bob.io.base.HDF5File(datafile("gmm_ML.hdf5", __name__, path="../data/"))) + map_gmmtrainer = MAP_GMMTrainer(update_means=True, update_variances=False, update_weights=False, prior_gmm=gmmprior, relevance_factor=4.) + bob.learn.em.train(map_gmmtrainer, gmm, ar, rng=rng) + gmm = GMMMachine(bob.io.base.HDF5File(datafile("gmm_ML.hdf5", __name__, path="../data/"))) gmmprior = GMMMachine(bob.io.base.HDF5File(datafile("gmm_ML.hdf5", __name__, path="../data/"))) @@ -253,9 +266,9 @@ def test_custom_trainer(): for i in range(0, 2): assert (ar[i+1] == machine.means[i, :]).all() - - - + + + def test_EMPCA(): # Tests our Probabilistic PCA trainer for linear machines for a simple @@ -294,5 +307,5 @@ def test_EMPCA(): T.e_step(m, ar) T.m_step(m, ar) llh2 = T.compute_likelihood(m) - assert abs(exp_llh2 - llh2) < 2e-4 - + assert abs(exp_llh2 - llh2) < 2e-4 + diff --git a/bob/learn/em/train.py b/bob/learn/em/train.py index beb93373d7aeff03594ee81dd37fca6cf505f1b2..41114cca50bc21f43085a65dee3595f7332d2dac 100644 --- a/bob/learn/em/train.py +++ b/bob/learn/em/train.py @@ -45,7 +45,9 @@ def train(trainer, machine, data, max_iterations=50, convergence_threshold=None, # Initialization if initialize: - if rng is not None: + if rng is not None and \ + (not isinstance(trainer, (bob.learn.em.ML_GMMTrainer, + bob.learn.em.MAP_GMMTrainer))): trainer.initialize(machine, data, rng) else: trainer.initialize(machine, data)