Skip to content
Snippets Groups Projects
Commit 3bc86678 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Fix a bug where rng is ignored in GMM training

parent e6a49266
Branches
Tags
1 merge request!18Fix a bug where rng is ignored in GMM training
Pipeline #
Showing
with 4003 additions and 4003 deletions
...@@ -112,7 +112,7 @@ class GMM (Algorithm): ...@@ -112,7 +112,7 @@ class GMM (Algorithm):
# Trains using the KMeansTrainer # Trains using the KMeansTrainer
logger.info(" -> Training K-Means") logger.info(" -> Training K-Means")
bob.learn.em.train(self.kmeans_trainer, kmeans, array, self.kmeans_training_iterations, self.training_threshold, self.rng) bob.learn.em.train(self.kmeans_trainer, kmeans, array, self.kmeans_training_iterations, self.training_threshold, rng=self.rng)
variances, weights = kmeans.get_variances_and_weights_for_each_cluster(array) variances, weights = kmeans.get_variances_and_weights_for_each_cluster(array)
means = kmeans.means means = kmeans.means
...@@ -125,7 +125,7 @@ class GMM (Algorithm): ...@@ -125,7 +125,7 @@ class GMM (Algorithm):
# Trains the GMM # Trains the GMM
logger.info(" -> Training GMM") logger.info(" -> Training GMM")
bob.learn.em.train(self.ubm_trainer, self.ubm, array, self.gmm_training_iterations, self.training_threshold, self.rng) bob.learn.em.train(self.ubm_trainer, self.ubm, array, self.gmm_training_iterations, self.training_threshold, rng=self.rng)
def save_ubm(self, projector_file): def save_ubm(self, projector_file):
...@@ -199,7 +199,7 @@ class GMM (Algorithm): ...@@ -199,7 +199,7 @@ class GMM (Algorithm):
gmm = bob.learn.em.GMMMachine(self.ubm) gmm = bob.learn.em.GMMMachine(self.ubm)
gmm.set_variance_thresholds(self.variance_threshold) gmm.set_variance_thresholds(self.variance_threshold)
bob.learn.em.train(self.enroll_trainer, gmm, array, self.gmm_enroll_iterations, self.training_threshold, self.rng) bob.learn.em.train(self.enroll_trainer, gmm, array, self.gmm_enroll_iterations, self.training_threshold, rng=self.rng)
return gmm return gmm
def enroll(self, feature_arrays): def enroll(self, feature_arrays):
......
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
This diff is collapsed.
This diff is collapsed.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment