Commit ee4ca4aa authored by Manuel Günther's avatar Manuel Günther Committed by Tiago de Freitas Pereira

Trial to implement EM with multiprocessing

parent 709dd06b
......@@ -19,7 +19,7 @@ from bob.learn.em import KMeansMachine, GMMMachine, KMeansTrainer, \
import bob.learn.em
import bob.core
logger = bob.core.log.setup("bob.learn.em")
#, MAP_GMMTrainer
......@@ -113,6 +113,46 @@ def test_gmm_ML_2():
assert equals(gmm.weights, weightsML_ref, 1e-4)
def test_gmm_ML_parallel():
# Trains a GMMMachine with ML_GMMTrainer; compares to an old reference
ar ='dataNormalized.hdf5', __name__, path="../data/"))
# Initialize GMMMachine
gmm = GMMMachine(5, 45)
gmm.means ='meansAfterKMeans.hdf5', __name__, path="../data/")).astype('float64')
gmm.variances ='variancesAfterKMeans.hdf5', __name__, path="../data/")).astype('float64')
gmm.weights = numpy.exp('weightsAfterKMeans.hdf5', __name__, path="../data/")).astype('float64'))
threshold = 0.001
# Initialize ML Trainer
prior = 0.001
max_iter_gmm = 25
accuracy = 0.00001
ml_gmmtrainer = ML_GMMTrainer(True, True, True, prior)
# Run ML
import multiprocessing.pool
pool = multiprocessing.pool.ThreadPool(1)
# pool = multiprocessing.Pool(1)
bob.learn.em.train(ml_gmmtrainer, gmm, ar, max_iterations = max_iter_gmm, convergence_threshold=accuracy, pool=pool)
# Test results
# Load torch3vision reference
meansML_ref ='meansAfterML.hdf5', __name__, path="../data/"))
variancesML_ref ='variancesAfterML.hdf5', __name__, path="../data/"))
weightsML_ref ='weightsAfterML.hdf5', __name__, path="../data/"))
# Compare to current results
assert equals(gmm.means, meansML_ref, 3e-3)
assert equals(gmm.variances, variancesML_ref, 3e-3)
assert equals(gmm.weights, weightsML_ref, 1e-4)
def test_gmm_MAP_1():
This diff is collapsed.
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment