Tested kmeans parallel

parent a191b9c4
Pipeline #29929 passed with stage
in 12 minutes and 15 seconds
...@@ -177,6 +177,41 @@ def test_kmeans_b(): ...@@ -177,6 +177,41 @@ def test_kmeans_b():
assert (numpy.isnan(machine.means).any()) == False assert (numpy.isnan(machine.means).any()) == False
def test_kmeans_parallel():
    """Train a KMeansMachine with a thread pool and compare the de-normalized
    means/variances/weights against the stored GMM initialization references.
    """
    # Load the Old Faithful data, normalized to unit std; `std` lets us undo it.
    (arStd, std) = NormalizeStdArray(datafile("faithful.torch3.hdf5", __name__, path="../data/"))

    machine = KMeansMachine(2, 2)
    trainer = KMeansTrainer()
    # trainer.seed = 1337

    import multiprocessing.pool
    pool = multiprocessing.pool.ThreadPool(3)
    try:
        bob.learn.em.train(trainer, machine, arStd, convergence_threshold=0.001, pool=pool)
    finally:
        # Fix: the pool was previously never shut down, leaking worker threads.
        pool.close()
        pool.join()

    [variances, weights] = machine.get_variances_and_weights_for_each_cluster(arStd)
    means = numpy.array(machine.means)
    variances = numpy.array(variances)

    # Undo the std normalization so results are comparable to the references.
    multiplyVectorsByFactors(means, std)
    multiplyVectorsByFactors(variances, std ** 2)

    # Reference values stored alongside the test data.
    gmmWeights = bob.io.base.load(datafile('gmm.init_weights.hdf5', __name__, path="../data/"))
    gmmMeans = bob.io.base.load(datafile('gmm.init_means.hdf5', __name__, path="../data/"))
    gmmVariances = bob.io.base.load(datafile('gmm.init_variances.hdf5', __name__, path="../data/"))

    # K-means cluster labeling is arbitrary; flip rows so the cluster order
    # matches the reference ordering before comparing.
    if (means[0, 0] < means[1, 0]):
        means = flipRows(means)
        variances = flipRows(variances)
        weights = flipRows(weights)

    assert equals(means, gmmMeans, 1e-3)
    assert equals(weights, gmmWeights, 1e-3)
    assert equals(variances, gmmVariances, 1e-3)
def test_trainer_execption(): def test_trainer_execption():
from nose.tools import assert_raises from nose.tools import assert_raises
......
...@@ -11,7 +11,7 @@ import logging ...@@ -11,7 +11,7 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _set_average(trainer, trainers, data): def _set_average(trainer, trainers, machine, data):
"""_set_average(trainer, data) -> None """_set_average(trainer, data) -> None
This function computes the average of the given data and sets it to the given machine. This function computes the average of the given data and sets it to the given machine.
...@@ -35,15 +35,14 @@ def _set_average(trainer, trainers, data): ...@@ -35,15 +35,14 @@ def _set_average(trainer, trainers, data):
if isinstance(trainer, KMeansTrainer): if isinstance(trainer, KMeansTrainer):
# K-Means statistics # K-Means statistics
trainer.zeroeth_order_statistics = numpy.zeros(trainer.zeroeth_order_statistics.shape) trainer.reset_accumulators(machine)
trainer.first_order_statistics = numpy.zeros(trainer.first_order_statistics.shape) for t in trainers:
trainer.average_min_distance = 0. trainer.zeroeth_order_statistics = trainer.zeroeth_order_statistics + t.zeroeth_order_statistics
trainer.first_order_statistics = trainer.first_order_statistics + t.first_order_statistics
trainer.average_min_distance = trainer.average_min_distance + t.average_min_distance
for t in trainer: #trainer.average_min_distance /= sum(d.shape[0] for d in data)
trainer.zeroeth_order_statistics += t.zeroeth_order_statistics trainer.average_min_distance /= data.shape[0]
trainer.first_order_statistics += t.first_order_statistics
trainer.average_min_distance += trainer.average_min_distance
trainer.average_min_distance /= sum(d.shape[0] for d in data)
elif isinstance(trainer, (ML_GMMTrainer, MAP_GMMTrainer)): elif isinstance(trainer, (ML_GMMTrainer, MAP_GMMTrainer)):
# GMM statistics # GMM statistics
...@@ -128,7 +127,7 @@ def train(trainer, machine, data, max_iterations=50, convergence_threshold=None, ...@@ -128,7 +127,7 @@ def train(trainer, machine, data, max_iterations=50, convergence_threshold=None,
# call the parallel processes # call the parallel processes
pool.map(_parallel_e_step, zip(trainers, machines, split_data)) pool.map(_parallel_e_step, zip(trainers, machines, split_data))
# update the trainer with the data of the other trainers # update the trainer with the data of the other trainers
_set_average(trainer, trainers, data) _set_average(trainer, trainers, machine, data)
# Initialization # Initialization
if initialize: if initialize:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment