From c397fa8336c5fad0ffc0214489baedb8f5073162 Mon Sep 17 00:00:00 2001 From: Yannick DAYER <yannick.dayer@idiap.ch> Date: Wed, 23 Feb 2022 15:15:17 +0100 Subject: [PATCH] Add a parameter to set kmean's init iterations --- bob/bio/gmm/algorithm/GMM.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/bob/bio/gmm/algorithm/GMM.py b/bob/bio/gmm/algorithm/GMM.py index bc6c143..e33fc22 100644 --- a/bob/bio/gmm/algorithm/GMM.py +++ b/bob/bio/gmm/algorithm/GMM.py @@ -13,6 +13,7 @@ This adds the notions of models, probes, enrollment, and scores to GMM. import copy import logging +from typing import Union from typing import Callable import dask @@ -49,6 +50,7 @@ class GMM(BioAlgorithm, BaseEstimator): number_of_gaussians: int, # parameters of UBM training kmeans_training_iterations: int = 25, # Maximum number of iterations for K-Means + kmeans_init_iterations: Union[int,None] = None, # Maximum number of iterations for K-Means init ubm_training_iterations: int = 25, # Maximum number of iterations for GMM Training training_threshold: float = 5e-4, # Threshold to end the ML training variance_threshold: float = 5e-4, # Minimum value that a variance can reach @@ -75,6 +77,9 @@ class GMM(BioAlgorithm, BaseEstimator): The number of Gaussians used in the UBM and the models. kmeans_training_iterations Number of e-m iterations to train k-means initializing the UBM. + kmeans_init_iterations + Number of iterations used for setting the k-means initial centroids. + if None, will use the same as kmeans_training_iterations. ubm_training_iterations Number of e-m iterations for training the UBM. training_threshold @@ -110,6 +115,7 @@ class GMM(BioAlgorithm, BaseEstimator): # Copy parameters self.number_of_gaussians = number_of_gaussians self.kmeans_training_iterations = kmeans_training_iterations + self.kmeans_init_iterations = kmeans_init_iterations or kmeans_training_iterations self.ubm_training_iterations = ubm_training_iterations self.training_threshold = training_threshold self.variance_threshold = variance_threshold @@ -199,6 +205,13 @@ class GMM(BioAlgorithm, BaseEstimator): update_variances=self.enroll_update_variances, update_weights=self.enroll_update_weights, mean_var_update_threshold=self.variance_threshold, + k_means_trainer= KMeansMachine( + n_clusters=self.number_of_gaussians, + init_method="k-means||", + max_iter=self.kmeans_training_iterations, + init_max_iter=self.kmeans_init_iterations, + convergence_threshold=self.training_threshold, + ) ) gmm.fit(array) return gmm -- GitLab