Skip to content
Snippets Groups Projects

Resolve "ML_GMMTrainer.initialize() does not accept a random generator argument"

2 files
+ 21
6
Compare changes
  • Side-by-side
  • Inline
Files
2
@@ -55,6 +55,12 @@ def test_gmm_ML_1():
ar = bob.io.base.load(datafile("faithful.torch3_f64.hdf5", __name__, path="../data/"))
gmm = loadGMM()
# test rng handling
ml_gmmtrainer = ML_GMMTrainer(True, True, True)
rng = bob.core.random.mt19937(12345)
bob.learn.em.train(ml_gmmtrainer, gmm, ar, convergence_threshold=0.001, rng=rng)
gmm = loadGMM()
ml_gmmtrainer = ML_GMMTrainer(True, True, True)
#ml_gmmtrainer.train(gmm, ar)
bob.learn.em.train(ml_gmmtrainer, gmm, ar, convergence_threshold=0.001)
@@ -114,6 +120,13 @@ def test_gmm_MAP_1():
ar = bob.io.base.load(datafile('faithful.torch3_f64.hdf5', __name__, path="../data/"))
# test with rng
rng = bob.core.random.mt19937(12345)
gmm = GMMMachine(bob.io.base.HDF5File(datafile("gmm_ML.hdf5", __name__, path="../data/")))
gmmprior = GMMMachine(bob.io.base.HDF5File(datafile("gmm_ML.hdf5", __name__, path="../data/")))
map_gmmtrainer = MAP_GMMTrainer(update_means=True, update_variances=False, update_weights=False, prior_gmm=gmmprior, relevance_factor=4.)
bob.learn.em.train(map_gmmtrainer, gmm, ar, rng=rng)
gmm = GMMMachine(bob.io.base.HDF5File(datafile("gmm_ML.hdf5", __name__, path="../data/")))
gmmprior = GMMMachine(bob.io.base.HDF5File(datafile("gmm_ML.hdf5", __name__, path="../data/")))
@@ -253,9 +266,9 @@ def test_custom_trainer():
for i in range(0, 2):
assert (ar[i+1] == machine.means[i, :]).all()
def test_EMPCA():
# Tests our Probabilistic PCA trainer for linear machines for a simple
@@ -294,5 +307,5 @@ def test_EMPCA():
T.e_step(m, ar)
T.m_step(m, ar)
llh2 = T.compute_likelihood(m)
assert abs(exp_llh2 - llh2) < 2e-4
assert abs(exp_llh2 - llh2) < 2e-4
Loading