Skip to content
Snippets Groups Projects
Commit c397fa83 authored by Yannick DAYER's avatar Yannick DAYER
Browse files

Add a parameter to set kmean's init iterations

parent 56bc342f
Branches
No related tags found
1 merge request!28Adapt to `bob.learn.em` API changes
...@@ -13,6 +13,7 @@ This adds the notions of models, probes, enrollment, and scores to GMM. ...@@ -13,6 +13,7 @@ This adds the notions of models, probes, enrollment, and scores to GMM.
import copy import copy
import logging import logging
from typing import Union
from typing import Callable from typing import Callable
import dask import dask
...@@ -49,6 +50,7 @@ class GMM(BioAlgorithm, BaseEstimator): ...@@ -49,6 +50,7 @@ class GMM(BioAlgorithm, BaseEstimator):
number_of_gaussians: int, number_of_gaussians: int,
# parameters of UBM training # parameters of UBM training
kmeans_training_iterations: int = 25, # Maximum number of iterations for K-Means kmeans_training_iterations: int = 25, # Maximum number of iterations for K-Means
kmeans_init_iterations: Union[int,None] = None, # Maximum number of iterations for K-Means init
ubm_training_iterations: int = 25, # Maximum number of iterations for GMM Training ubm_training_iterations: int = 25, # Maximum number of iterations for GMM Training
training_threshold: float = 5e-4, # Threshold to end the ML training training_threshold: float = 5e-4, # Threshold to end the ML training
variance_threshold: float = 5e-4, # Minimum value that a variance can reach variance_threshold: float = 5e-4, # Minimum value that a variance can reach
...@@ -75,6 +77,9 @@ class GMM(BioAlgorithm, BaseEstimator): ...@@ -75,6 +77,9 @@ class GMM(BioAlgorithm, BaseEstimator):
The number of Gaussians used in the UBM and the models. The number of Gaussians used in the UBM and the models.
kmeans_training_iterations kmeans_training_iterations
Number of e-m iterations to train k-means initializing the UBM. Number of e-m iterations to train k-means initializing the UBM.
kmeans_init_iterations
Number of iterations used for setting the k-means initial centroids.
if None, will use the same as kmeans_training_iterations.
ubm_training_iterations ubm_training_iterations
Number of e-m iterations for training the UBM. Number of e-m iterations for training the UBM.
training_threshold training_threshold
...@@ -110,6 +115,7 @@ class GMM(BioAlgorithm, BaseEstimator): ...@@ -110,6 +115,7 @@ class GMM(BioAlgorithm, BaseEstimator):
# Copy parameters # Copy parameters
self.number_of_gaussians = number_of_gaussians self.number_of_gaussians = number_of_gaussians
self.kmeans_training_iterations = kmeans_training_iterations self.kmeans_training_iterations = kmeans_training_iterations
self.kmeans_init_iterations = kmeans_init_iterations or kmeans_training_iterations
self.ubm_training_iterations = ubm_training_iterations self.ubm_training_iterations = ubm_training_iterations
self.training_threshold = training_threshold self.training_threshold = training_threshold
self.variance_threshold = variance_threshold self.variance_threshold = variance_threshold
...@@ -199,6 +205,13 @@ class GMM(BioAlgorithm, BaseEstimator): ...@@ -199,6 +205,13 @@ class GMM(BioAlgorithm, BaseEstimator):
update_variances=self.enroll_update_variances, update_variances=self.enroll_update_variances,
update_weights=self.enroll_update_weights, update_weights=self.enroll_update_weights,
mean_var_update_threshold=self.variance_threshold, mean_var_update_threshold=self.variance_threshold,
k_means_trainer= KMeansMachine(
n_clusters=self.number_of_gaussians,
init_method="k-means||",
max_iter=self.kmeans_training_iterations,
init_max_iter=self.kmeans_init_iterations,
convergence_threshold=self.training_threshold,
)
) )
gmm.fit(array) gmm.fit(array)
return gmm return gmm
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment