Skip to content
Snippets Groups Projects
Commit c397fa83 authored by Yannick DAYER's avatar Yannick DAYER
Browse files

Add a parameter to set kmean's init iterations

parent 56bc342f
Branches
No related tags found
1 merge request!28Adapt to `bob.learn.em` API changes
......@@ -13,6 +13,7 @@ This adds the notions of models, probes, enrollment, and scores to GMM.
import copy
import logging
from typing import Union
from typing import Callable
import dask
......@@ -49,6 +50,7 @@ class GMM(BioAlgorithm, BaseEstimator):
number_of_gaussians: int,
# parameters of UBM training
kmeans_training_iterations: int = 25, # Maximum number of iterations for K-Means
kmeans_init_iterations: Union[int,None] = None, # Maximum number of iterations for K-Means init
ubm_training_iterations: int = 25, # Maximum number of iterations for GMM Training
training_threshold: float = 5e-4, # Threshold to end the ML training
variance_threshold: float = 5e-4, # Minimum value that a variance can reach
......@@ -75,6 +77,9 @@ class GMM(BioAlgorithm, BaseEstimator):
The number of Gaussians used in the UBM and the models.
kmeans_training_iterations
Number of e-m iterations to train k-means initializing the UBM.
kmeans_init_iterations
Number of iterations used for setting the k-means initial centroids.
if None, will use the same as kmeans_training_iterations.
ubm_training_iterations
Number of e-m iterations for training the UBM.
training_threshold
......@@ -110,6 +115,7 @@ class GMM(BioAlgorithm, BaseEstimator):
# Copy parameters
self.number_of_gaussians = number_of_gaussians
self.kmeans_training_iterations = kmeans_training_iterations
self.kmeans_init_iterations = kmeans_init_iterations or kmeans_training_iterations
self.ubm_training_iterations = ubm_training_iterations
self.training_threshold = training_threshold
self.variance_threshold = variance_threshold
......@@ -199,6 +205,13 @@ class GMM(BioAlgorithm, BaseEstimator):
update_variances=self.enroll_update_variances,
update_weights=self.enroll_update_weights,
mean_var_update_threshold=self.variance_threshold,
k_means_trainer= KMeansMachine(
n_clusters=self.number_of_gaussians,
init_method="k-means||",
max_iter=self.kmeans_training_iterations,
init_max_iter=self.kmeans_init_iterations,
convergence_threshold=self.training_threshold,
)
)
gmm.fit(array)
return gmm
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment