Skip to content
Snippets Groups Projects
Commit f4c7441e authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Merge branch 'threading' into 'master'

Adds a threading option to the GMM algorithm

See merge request !22
parents ff4dd87b 0b830148
No related branches found
No related tags found
1 merge request!22Adds a threading option to the GMM algorithm
Pipeline #30791 passed
...@@ -10,6 +10,7 @@ import bob.learn.em ...@@ -10,6 +10,7 @@ import bob.learn.em
import numpy import numpy
from bob.bio.base.algorithm import Algorithm from bob.bio.base.algorithm import Algorithm
from multiprocessing.pool import ThreadPool
import logging import logging
logger = logging.getLogger("bob.bio.gmm") logger = logging.getLogger("bob.bio.gmm")
...@@ -36,7 +37,8 @@ class GMM (Algorithm): ...@@ -36,7 +37,8 @@ class GMM (Algorithm):
responsibility_threshold = 0, # If set, the weight of a particular Gaussian will at least be greater than this threshold. In the case the real weight is lower, the prior mean value will be used to estimate the current mean and variance. responsibility_threshold = 0, # If set, the weight of a particular Gaussian will at least be greater than this threshold. In the case the real weight is lower, the prior mean value will be used to estimate the current mean and variance.
INIT_SEED = 5489, INIT_SEED = 5489,
# scoring # scoring
scoring_function = bob.learn.em.linear_scoring scoring_function = bob.learn.em.linear_scoring,
n_threads=None,
): ):
"""Initializes the local UBM-GMM tool chain with the given file selector object""" """Initializes the local UBM-GMM tool chain with the given file selector object"""
...@@ -79,6 +81,8 @@ class GMM (Algorithm): ...@@ -79,6 +81,8 @@ class GMM (Algorithm):
self.rng = bob.core.random.mt19937(self.init_seed) self.rng = bob.core.random.mt19937(self.init_seed)
self.responsibility_threshold = responsibility_threshold self.responsibility_threshold = responsibility_threshold
self.scoring_function = scoring_function self.scoring_function = scoring_function
self.n_threads = n_threads
self.pool = None
self.ubm = None self.ubm = None
self.kmeans_trainer = bob.learn.em.KMeansTrainer() self.kmeans_trainer = bob.learn.em.KMeansTrainer()
...@@ -101,6 +105,8 @@ class GMM (Algorithm): ...@@ -101,6 +105,8 @@ class GMM (Algorithm):
def train_ubm(self, array): def train_ubm(self, array):
logger.debug(" .... Training with %d feature vectors", array.shape[0]) logger.debug(" .... Training with %d feature vectors", array.shape[0])
if self.n_threads is not None:
self.pool = ThreadPool(self.n_threads)
# Computes input size # Computes input size
input_size = array.shape[1] input_size = array.shape[1]
...@@ -113,9 +119,12 @@ class GMM (Algorithm): ...@@ -113,9 +119,12 @@ class GMM (Algorithm):
# Trains using the KMeansTrainer # Trains using the KMeansTrainer
logger.info(" -> Training K-Means") logger.info(" -> Training K-Means")
# Reseting the pseudo random number generator so we can have the same initialization for serial and parallel execution. # Reseting the pseudo random number generator so we can have the same initialization for serial and parallel execution.
self.rng = bob.core.random.mt19937(self.init_seed) self.rng = bob.core.random.mt19937(self.init_seed)
bob.learn.em.train(self.kmeans_trainer, kmeans, array, self.kmeans_training_iterations, self.training_threshold, rng=self.rng) bob.learn.em.train(
self.kmeans_trainer, kmeans, array, self.kmeans_training_iterations,
self.training_threshold, rng=self.rng, pool=self.pool,
)
variances, weights = kmeans.get_variances_and_weights_for_each_cluster(array) variances, weights = kmeans.get_variances_and_weights_for_each_cluster(array)
means = kmeans.means means = kmeans.means
...@@ -128,10 +137,12 @@ class GMM (Algorithm): ...@@ -128,10 +137,12 @@ class GMM (Algorithm):
# Trains the GMM # Trains the GMM
logger.info(" -> Training GMM") logger.info(" -> Training GMM")
# Reseting the pseudo random number generator so we can have the same initialization for serial and parallel execution. # Reseting the pseudo random number generator so we can have the same initialization for serial and parallel execution.
self.rng = bob.core.random.mt19937(self.init_seed) self.rng = bob.core.random.mt19937(self.init_seed)
bob.learn.em.train(self.ubm_trainer, self.ubm, array, self.gmm_training_iterations, self.training_threshold, rng=self.rng) bob.learn.em.train(
self.ubm_trainer, self.ubm, array, self.gmm_training_iterations,
self.training_threshold, rng=self.rng, pool=self.pool,
)
def save_ubm(self, projector_file): def save_ubm(self, projector_file):
"""Save projector to file""" """Save projector to file"""
...@@ -140,7 +151,6 @@ class GMM (Algorithm): ...@@ -140,7 +151,6 @@ class GMM (Algorithm):
hdf5 = projector_file if isinstance(projector_file, bob.io.base.HDF5File) else bob.io.base.HDF5File(projector_file, 'w') hdf5 = projector_file if isinstance(projector_file, bob.io.base.HDF5File) else bob.io.base.HDF5File(projector_file, 'w')
self.ubm.save(hdf5) self.ubm.save(hdf5)
def train_projector(self, train_features, projector_file): def train_projector(self, train_features, projector_file):
"""Computes the Universal Background Model from the training ("world") data""" """Computes the Universal Background Model from the training ("world") data"""
[self._check_feature(feature) for feature in train_features] [self._check_feature(feature) for feature in train_features]
...@@ -154,7 +164,6 @@ class GMM (Algorithm): ...@@ -154,7 +164,6 @@ class GMM (Algorithm):
self.save_ubm(projector_file) self.save_ubm(projector_file)
####################################################### #######################################################
############## GMM training using UBM ################# ############## GMM training using UBM #################
...@@ -204,7 +213,10 @@ class GMM (Algorithm): ...@@ -204,7 +213,10 @@ class GMM (Algorithm):
gmm = bob.learn.em.GMMMachine(self.ubm) gmm = bob.learn.em.GMMMachine(self.ubm)
gmm.set_variance_thresholds(self.variance_threshold) gmm.set_variance_thresholds(self.variance_threshold)
bob.learn.em.train(self.enroll_trainer, gmm, array, self.gmm_enroll_iterations, self.training_threshold, rng=self.rng) bob.learn.em.train(
self.enroll_trainer, gmm, array, self.gmm_enroll_iterations,
self.training_threshold, rng=self.rng, pool=self.pool,
)
return gmm return gmm
def enroll(self, feature_arrays): def enroll(self, feature_arrays):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment