From c397fa8336c5fad0ffc0214489baedb8f5073162 Mon Sep 17 00:00:00 2001
From: Yannick DAYER <yannick.dayer@idiap.ch>
Date: Wed, 23 Feb 2022 15:15:17 +0100
Subject: [PATCH] Add a parameter to set kmean's init iterations

---
 bob/bio/gmm/algorithm/GMM.py | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/bob/bio/gmm/algorithm/GMM.py b/bob/bio/gmm/algorithm/GMM.py
index bc6c143..e33fc22 100644
--- a/bob/bio/gmm/algorithm/GMM.py
+++ b/bob/bio/gmm/algorithm/GMM.py
@@ -13,6 +13,7 @@ This adds the notions of models, probes, enrollment, and scores to GMM.
 import copy
 import logging
 
+from typing import Union
 from typing import Callable
 
 import dask
@@ -49,6 +50,7 @@ class GMM(BioAlgorithm, BaseEstimator):
         number_of_gaussians: int,
         # parameters of UBM training
         kmeans_training_iterations: int = 25,  # Maximum number of iterations for K-Means
+        kmeans_init_iterations: Union[int,None] = None,  # Maximum number of iterations for K-Means init
         ubm_training_iterations: int = 25,  # Maximum number of iterations for GMM Training
         training_threshold: float = 5e-4,  # Threshold to end the ML training
         variance_threshold: float = 5e-4,  # Minimum value that a variance can reach
@@ -75,6 +77,9 @@ class GMM(BioAlgorithm, BaseEstimator):
             The number of Gaussians used in the UBM and the models.
         kmeans_training_iterations
             Number of e-m iterations to train k-means initializing the UBM.
+        kmeans_init_iterations
+            Number of iterations used for setting the k-means initial centroids.
+            if None, will use the same as kmeans_training_iterations.
         ubm_training_iterations
             Number of e-m iterations for training the UBM.
         training_threshold
@@ -110,6 +115,7 @@ class GMM(BioAlgorithm, BaseEstimator):
         # Copy parameters
         self.number_of_gaussians = number_of_gaussians
         self.kmeans_training_iterations = kmeans_training_iterations
+        self.kmeans_init_iterations = kmeans_init_iterations or kmeans_training_iterations
         self.ubm_training_iterations = ubm_training_iterations
         self.training_threshold = training_threshold
         self.variance_threshold = variance_threshold
@@ -199,6 +205,13 @@ class GMM(BioAlgorithm, BaseEstimator):
                 update_variances=self.enroll_update_variances,
                 update_weights=self.enroll_update_weights,
                 mean_var_update_threshold=self.variance_threshold,
+                k_means_trainer= KMeansMachine(
+                    n_clusters=self.number_of_gaussians,
+                    init_method="k-means||",
+                    max_iter=self.kmeans_training_iterations,
+                    init_max_iter=self.kmeans_init_iterations,
+                    convergence_threshold=self.training_threshold,
+                )
             )
             gmm.fit(array)
         return gmm
-- 
GitLab