From e72e3598cd5a03ef297c898d7b63f881f1164b19 Mon Sep 17 00:00:00 2001 From: Yannick DAYER <yannick.dayer@idiap.ch> Date: Mon, 29 Nov 2021 12:03:29 +0100 Subject: [PATCH] Allow skipping Reynolds adaptation for MAP GMM --- bob/learn/em/mixture/gmm.py | 16 ++++++++++++++++ bob/learn/em/test/test_gmm.py | 3 ++- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/bob/learn/em/mixture/gmm.py b/bob/learn/em/mixture/gmm.py index 4370d5b..fdab112 100644 --- a/bob/learn/em/mixture/gmm.py +++ b/bob/learn/em/mixture/gmm.py @@ -279,6 +279,8 @@ class GMMMachine(BaseEstimator): update_variances: bool = False, update_weights: bool = False, mean_var_update_threshold: float = EPSILON, + alpha: float = 0.5, + relevance_factor: Union[None, float] = 4, ): """ Parameters @@ -308,6 +310,14 @@ class GMMMachine(BaseEstimator): Update the Gaussians variances at every m step. update_weights Update the GMM weights at every m step. + mean_var_update_threshold: + Threshold value used when updating the means and variances. + alpha: + Ratio for MAP adaptation. Used when `trainer == "map"` and + `relevance_factor is None`) + relevance_factor: + Factor for the computation of alpha with Reyolds adaptation. (Used when + `trainer == "map"`) """ self.n_gaussians = n_gaussians @@ -346,6 +356,8 @@ class GMMMachine(BaseEstimator): ) if weights is not None: self.weights = weights + self.alpha = alpha + self.relevance_factor = relevance_factor @property def weights(self): @@ -683,6 +695,9 @@ class GMMMachine(BaseEstimator): update_variances=self.update_variances, update_weights=self.update_weights, mean_var_update_threshold=self.mean_var_update_threshold, + reynolds_adaptation=self.relevance_factor is not None, + alpha=self.alpha, + relevance_factor=self.relevance_factor, **kwargs, ) @@ -763,6 +778,7 @@ def ml_gmm_m_step( update_variances=False, update_weights=False, mean_var_update_threshold=EPSILON, + **kwargs, ): """Updates a gmm machine parameter according to the e-step statistics.""" logger.debug("ML GMM Trainer m-step") diff --git a/bob/learn/em/test/test_gmm.py b/bob/learn/em/test/test_gmm.py index 89e96bf..8e39e22 100644 --- a/bob/learn/em/test/test_gmm.py +++ b/bob/learn/em/test/test_gmm.py @@ -752,7 +752,8 @@ def test_gmm_MAP_3(): update_means=True, update_variances=False, update_weights=False, - mean_var_update_threshold=accuracy + mean_var_update_threshold=accuracy, + relevance_factor=None, ) gmm.variance_thresholds = threshold -- GitLab