Commit f4c7441e authored by Amir MOHAMMADI
Browse files

Merge branch 'threading' into 'master'

Adds a threading option to the GMM algorithm

See merge request !22
parents ff4dd87b 0b830148
Pipeline #30791 passed with stages
in 25 minutes and 40 seconds
......@@ -10,6 +10,7 @@ import bob.learn.em
import numpy
from bob.bio.base.algorithm import Algorithm
from multiprocessing.pool import ThreadPool
import logging
logger = logging.getLogger("bob.bio.gmm")
......@@ -36,7 +37,8 @@ class GMM (Algorithm):
responsibility_threshold = 0, # If set, the weight of a particular Gaussian will at least be greater than this threshold. In the case the real weight is lower, the prior mean value will be used to estimate the current mean and variance.
INIT_SEED = 5489,
# scoring
scoring_function = bob.learn.em.linear_scoring
scoring_function = bob.learn.em.linear_scoring,
n_threads=None,
):
"""Initializes the local UBM-GMM tool chain with the given file selector object"""
......@@ -79,6 +81,8 @@ class GMM (Algorithm):
self.rng = bob.core.random.mt19937(self.init_seed)
self.responsibility_threshold = responsibility_threshold
self.scoring_function = scoring_function
self.n_threads = n_threads
self.pool = None
self.ubm = None
self.kmeans_trainer = bob.learn.em.KMeansTrainer()
......@@ -101,6 +105,8 @@ class GMM (Algorithm):
def train_ubm(self, array):
logger.debug(" .... Training with %d feature vectors", array.shape[0])
if self.n_threads is not None:
self.pool = ThreadPool(self.n_threads)
# Computes input size
input_size = array.shape[1]
......@@ -113,9 +119,12 @@ class GMM (Algorithm):
# Trains using the KMeansTrainer
logger.info(" -> Training K-Means")
# Resetting the pseudo random number generator so we can have the same initialization for serial and parallel execution.
# Resetting the pseudo random number generator so we can have the same initialization for serial and parallel execution.
self.rng = bob.core.random.mt19937(self.init_seed)
bob.learn.em.train(self.kmeans_trainer, kmeans, array, self.kmeans_training_iterations, self.training_threshold, rng=self.rng)
bob.learn.em.train(
self.kmeans_trainer, kmeans, array, self.kmeans_training_iterations,
self.training_threshold, rng=self.rng, pool=self.pool,
)
variances, weights = kmeans.get_variances_and_weights_for_each_cluster(array)
means = kmeans.means
......@@ -128,10 +137,12 @@ class GMM (Algorithm):
# Trains the GMM
logger.info(" -> Training GMM")
# Resetting the pseudo random number generator so we can have the same initialization for serial and parallel execution.
# Resetting the pseudo random number generator so we can have the same initialization for serial and parallel execution.
self.rng = bob.core.random.mt19937(self.init_seed)
bob.learn.em.train(self.ubm_trainer, self.ubm, array, self.gmm_training_iterations, self.training_threshold, rng=self.rng)
bob.learn.em.train(
self.ubm_trainer, self.ubm, array, self.gmm_training_iterations,
self.training_threshold, rng=self.rng, pool=self.pool,
)
def save_ubm(self, projector_file):
"""Save projector to file"""
......@@ -140,7 +151,6 @@ class GMM (Algorithm):
hdf5 = projector_file if isinstance(projector_file, bob.io.base.HDF5File) else bob.io.base.HDF5File(projector_file, 'w')
self.ubm.save(hdf5)
def train_projector(self, train_features, projector_file):
"""Computes the Universal Background Model from the training ("world") data"""
[self._check_feature(feature) for feature in train_features]
......@@ -154,7 +164,6 @@ class GMM (Algorithm):
self.save_ubm(projector_file)
#######################################################
############## GMM training using UBM #################
......@@ -204,7 +213,10 @@ class GMM (Algorithm):
gmm = bob.learn.em.GMMMachine(self.ubm)
gmm.set_variance_thresholds(self.variance_threshold)
bob.learn.em.train(self.enroll_trainer, gmm, array, self.gmm_enroll_iterations, self.training_threshold, rng=self.rng)
bob.learn.em.train(
self.enroll_trainer, gmm, array, self.gmm_enroll_iterations,
self.training_threshold, rng=self.rng, pool=self.pool,
)
return gmm
def enroll(self, feature_arrays):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment