Commit bbdf4f38 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Merge branch 'rng_fix' into 'master'

Fix a bug where rng is ignored in GMM training

See merge request !18
parents 30957983 fdb710b5
Pipeline #19084 passed with stages
in 34 minutes and 44 seconds
......@@ -112,7 +112,10 @@ class GMM (Algorithm):
# Trains using the KMeansTrainer
logger.info(" -> Training K-Means")
bob.learn.em.train(self.kmeans_trainer, kmeans, array, self.kmeans_training_iterations, self.training_threshold, self.rng)
# Reseting the pseudo random number generator so we can have the same initialization for serial and parallel execution.
self.rng = bob.core.random.mt19937(self.init_seed)
bob.learn.em.train(self.kmeans_trainer, kmeans, array, self.kmeans_training_iterations, self.training_threshold, rng=self.rng)
variances, weights = kmeans.get_variances_and_weights_for_each_cluster(array)
means = kmeans.means
......@@ -125,7 +128,9 @@ class GMM (Algorithm):
# Trains the GMM
logger.info(" -> Training GMM")
bob.learn.em.train(self.ubm_trainer, self.ubm, array, self.gmm_training_iterations, self.training_threshold, self.rng)
# Reseting the pseudo random number generator so we can have the same initialization for serial and parallel execution.
self.rng = bob.core.random.mt19937(self.init_seed)
bob.learn.em.train(self.ubm_trainer, self.ubm, array, self.gmm_training_iterations, self.training_threshold, rng=self.rng)
def save_ubm(self, projector_file):
......@@ -199,7 +204,7 @@ class GMM (Algorithm):
gmm = bob.learn.em.GMMMachine(self.ubm)
gmm.set_variance_thresholds(self.variance_threshold)
bob.learn.em.train(self.enroll_trainer, gmm, array, self.gmm_enroll_iterations, self.training_threshold, self.rng)
bob.learn.em.train(self.enroll_trainer, gmm, array, self.gmm_enroll_iterations, self.training_threshold, rng=self.rng)
return gmm
def enroll(self, feature_arrays):
......
......@@ -64,6 +64,8 @@ class ISV (GMM):
logger.info(" -> Training ISV enroller")
self.isvbase = bob.learn.em.ISVBase(self.ubm, self.subspace_dimension_of_u)
# train ISV model
# Reseting the pseudo random number generator so we can have the same initialization for serial and parallel execution.
self.rng = bob.core.random.mt19937(self.init_seed)
bob.learn.em.train(self.isv_trainer, self.isvbase, data, self.isv_training_iterations, rng=self.rng)
......
......@@ -94,6 +94,9 @@ class IVector (GMM):
logger.info(" -> Training IVector enroller")
self.tv = bob.learn.em.IVectorMachine(self.ubm, self.subspace_dimension_of_t, self.variance_threshold)
# Reseting the pseudo random number generator so we can have the same initialization for serial and parallel execution.
self.rng = bob.core.random.mt19937(self.init_seed)
# train IVector model
bob.learn.em.train(self.ivector_trainer, self.tv, training_stats, self.tv_training_iterations, rng=self.rng)
......@@ -125,6 +128,10 @@ class IVector (GMM):
variance_flooring = 1e-5
training_features = [numpy.vstack(client) for client in training_features]
input_dim = training_features[0].shape[1]
# Reseting the pseudo random number generator so we can have the same initialization for serial and parallel execution.
self.rng = bob.core.random.mt19937(self.init_seed)
self.plda_base = bob.learn.em.PLDABase(input_dim, self.plda_dim_F, self.plda_dim_G, variance_flooring)
bob.learn.em.train(self.plda_trainer, self.plda_base, training_features, self.plda_training_iterations, rng=self.rng)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment