Skip to content
Snippets Groups Projects
Commit 0b830148 authored by Amir MOHAMMADI
Browse files

Adds a threading option to the GMM algorithm

parent ff4dd87b
No related branches found
No related tags found
1 merge request: !22 Adds a threading option to the GMM algorithm
Pipeline #30733 passed
......@@ -10,6 +10,7 @@ import bob.learn.em
import numpy
from bob.bio.base.algorithm import Algorithm
from multiprocessing.pool import ThreadPool
import logging
logger = logging.getLogger("bob.bio.gmm")
......@@ -36,7 +37,8 @@ class GMM (Algorithm):
responsibility_threshold = 0, # If set, the weight of a particular Gaussian will at least be greater than this threshold. In the case the real weight is lower, the prior mean value will be used to estimate the current mean and variance.
INIT_SEED = 5489,
# scoring
scoring_function = bob.learn.em.linear_scoring
scoring_function = bob.learn.em.linear_scoring,
n_threads=None,
):
"""Initializes the local UBM-GMM tool chain with the given file selector object"""
......@@ -79,6 +81,8 @@ class GMM (Algorithm):
self.rng = bob.core.random.mt19937(self.init_seed)
self.responsibility_threshold = responsibility_threshold
self.scoring_function = scoring_function
self.n_threads = n_threads
self.pool = None
self.ubm = None
self.kmeans_trainer = bob.learn.em.KMeansTrainer()
......@@ -101,6 +105,8 @@ class GMM (Algorithm):
def train_ubm(self, array):
logger.debug(" .... Training with %d feature vectors", array.shape[0])
if self.n_threads is not None:
self.pool = ThreadPool(self.n_threads)
# Computes input size
input_size = array.shape[1]
......@@ -113,9 +119,12 @@ class GMM (Algorithm):
# Trains using the KMeansTrainer
logger.info(" -> Training K-Means")
# Resetting the pseudo random number generator so we can have the same initialization for serial and parallel execution.
# Resetting the pseudo random number generator so we can have the same initialization for serial and parallel execution.
self.rng = bob.core.random.mt19937(self.init_seed)
bob.learn.em.train(self.kmeans_trainer, kmeans, array, self.kmeans_training_iterations, self.training_threshold, rng=self.rng)
bob.learn.em.train(
self.kmeans_trainer, kmeans, array, self.kmeans_training_iterations,
self.training_threshold, rng=self.rng, pool=self.pool,
)
variances, weights = kmeans.get_variances_and_weights_for_each_cluster(array)
means = kmeans.means
......@@ -128,10 +137,12 @@ class GMM (Algorithm):
# Trains the GMM
logger.info(" -> Training GMM")
# Resetting the pseudo random number generator so we can have the same initialization for serial and parallel execution.
# Resetting the pseudo random number generator so we can have the same initialization for serial and parallel execution.
self.rng = bob.core.random.mt19937(self.init_seed)
bob.learn.em.train(self.ubm_trainer, self.ubm, array, self.gmm_training_iterations, self.training_threshold, rng=self.rng)
bob.learn.em.train(
self.ubm_trainer, self.ubm, array, self.gmm_training_iterations,
self.training_threshold, rng=self.rng, pool=self.pool,
)
def save_ubm(self, projector_file):
"""Save projector to file"""
......@@ -140,7 +151,6 @@ class GMM (Algorithm):
hdf5 = projector_file if isinstance(projector_file, bob.io.base.HDF5File) else bob.io.base.HDF5File(projector_file, 'w')
self.ubm.save(hdf5)
def train_projector(self, train_features, projector_file):
"""Computes the Universal Background Model from the training ("world") data"""
[self._check_feature(feature) for feature in train_features]
......@@ -154,7 +164,6 @@ class GMM (Algorithm):
self.save_ubm(projector_file)
#######################################################
############## GMM training using UBM #################
......@@ -204,7 +213,10 @@ class GMM (Algorithm):
gmm = bob.learn.em.GMMMachine(self.ubm)
gmm.set_variance_thresholds(self.variance_threshold)
bob.learn.em.train(self.enroll_trainer, gmm, array, self.gmm_enroll_iterations, self.training_threshold, rng=self.rng)
bob.learn.em.train(
self.enroll_trainer, gmm, array, self.gmm_enroll_iterations,
self.training_threshold, rng=self.rng, pool=self.pool,
)
return gmm
def enroll(self, feature_arrays):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment