From e72e3598cd5a03ef297c898d7b63f881f1164b19 Mon Sep 17 00:00:00 2001
From: Yannick DAYER <yannick.dayer@idiap.ch>
Date: Mon, 29 Nov 2021 12:03:29 +0100
Subject: [PATCH] Allow skipping Reynolds adaptation for MAP GMM

---
 bob/learn/em/mixture/gmm.py   | 16 ++++++++++++++++
 bob/learn/em/test/test_gmm.py |  3 ++-
 2 files changed, 18 insertions(+), 1 deletion(-)

diff --git a/bob/learn/em/mixture/gmm.py b/bob/learn/em/mixture/gmm.py
index 4370d5b..fdab112 100644
--- a/bob/learn/em/mixture/gmm.py
+++ b/bob/learn/em/mixture/gmm.py
@@ -279,6 +279,8 @@ class GMMMachine(BaseEstimator):
         update_variances: bool = False,
         update_weights: bool = False,
         mean_var_update_threshold: float = EPSILON,
+        alpha: float = 0.5,
+        relevance_factor: Union[None, float] = 4,
     ):
         """
         Parameters
@@ -308,6 +310,14 @@ class GMMMachine(BaseEstimator):
             Update the Gaussians variances at every m step.
         update_weights
             Update the GMM weights at every m step.
+        mean_var_update_threshold:
+            Threshold value used when updating the means and variances.
+        alpha:
+            Ratio for MAP adaptation. Used when `trainer == "map"` and
+            `relevance_factor is None`)
+        relevance_factor:
+            Factor for the computation of alpha with Reyolds adaptation. (Used when
+            `trainer == "map"`)
         """
 
         self.n_gaussians = n_gaussians
@@ -346,6 +356,8 @@ class GMMMachine(BaseEstimator):
             )
         if weights is not None:
             self.weights = weights
+        self.alpha = alpha
+        self.relevance_factor = relevance_factor 
 
     @property
     def weights(self):
@@ -683,6 +695,9 @@ class GMMMachine(BaseEstimator):
             update_variances=self.update_variances,
             update_weights=self.update_weights,
             mean_var_update_threshold=self.mean_var_update_threshold,
+            reynolds_adaptation=self.relevance_factor is not None,
+            alpha=self.alpha,
+            relevance_factor=self.relevance_factor,
             **kwargs,
         )
 
@@ -763,6 +778,7 @@ def ml_gmm_m_step(
     update_variances=False,
     update_weights=False,
     mean_var_update_threshold=EPSILON,
+    **kwargs,
 ):
     """Updates a gmm machine parameter according to the e-step statistics."""
     logger.debug("ML GMM Trainer m-step")
diff --git a/bob/learn/em/test/test_gmm.py b/bob/learn/em/test/test_gmm.py
index 89e96bf..8e39e22 100644
--- a/bob/learn/em/test/test_gmm.py
+++ b/bob/learn/em/test/test_gmm.py
@@ -752,7 +752,8 @@ def test_gmm_MAP_3():
         update_means=True,
         update_variances=False,
         update_weights=False,
-        mean_var_update_threshold=accuracy
+        mean_var_update_threshold=accuracy,
+        relevance_factor=None,
     )
     gmm.variance_thresholds = threshold
 
-- 
GitLab