Commit 3bc86678 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Fix a bug where rng is ignored in GMM training

parent e6a49266
Pipeline #18406 failed with stage
in 18 minutes and 32 seconds
......@@ -112,7 +112,7 @@ class GMM (Algorithm):
# Trains using the KMeansTrainer
logger.info(" -> Training K-Means")
bob.learn.em.train(self.kmeans_trainer, kmeans, array, self.kmeans_training_iterations, self.training_threshold, self.rng)
bob.learn.em.train(self.kmeans_trainer, kmeans, array, self.kmeans_training_iterations, self.training_threshold, rng=self.rng)
variances, weights = kmeans.get_variances_and_weights_for_each_cluster(array)
means = kmeans.means
......@@ -125,7 +125,7 @@ class GMM (Algorithm):
# Trains the GMM
logger.info(" -> Training GMM")
bob.learn.em.train(self.ubm_trainer, self.ubm, array, self.gmm_training_iterations, self.training_threshold, self.rng)
bob.learn.em.train(self.ubm_trainer, self.ubm, array, self.gmm_training_iterations, self.training_threshold, rng=self.rng)
def save_ubm(self, projector_file):
......@@ -199,7 +199,7 @@ class GMM (Algorithm):
gmm = bob.learn.em.GMMMachine(self.ubm)
gmm.set_variance_thresholds(self.variance_threshold)
bob.learn.em.train(self.enroll_trainer, gmm, array, self.gmm_enroll_iterations, self.training_threshold, self.rng)
bob.learn.em.train(self.enroll_trainer, gmm, array, self.gmm_enroll_iterations, self.training_threshold, rng=self.rng)
return gmm
def enroll(self, feature_arrays):
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment