diff --git a/bob/learn/em/mixture/gmm.py b/bob/learn/em/mixture/gmm.py index 4370d5b01822b8ca50350bd181be41ec2ea41327..fdab112a0111ff393db61307b49b2f452ef17e57 100644 --- a/bob/learn/em/mixture/gmm.py +++ b/bob/learn/em/mixture/gmm.py @@ -279,6 +279,8 @@ class GMMMachine(BaseEstimator): update_variances: bool = False, update_weights: bool = False, mean_var_update_threshold: float = EPSILON, + alpha: float = 0.5, + relevance_factor: Union[None, float] = 4, ): """ Parameters @@ -308,6 +310,14 @@ class GMMMachine(BaseEstimator): Update the Gaussians variances at every m step. update_weights Update the GMM weights at every m step. + mean_var_update_threshold: + Threshold value used when updating the means and variances. + alpha: + Ratio for MAP adaptation. Used when `trainer == "map"` and + `relevance_factor is None`) + relevance_factor: + Factor for the computation of alpha with Reyolds adaptation. (Used when + `trainer == "map"`) """ self.n_gaussians = n_gaussians @@ -346,6 +356,8 @@ class GMMMachine(BaseEstimator): ) if weights is not None: self.weights = weights + self.alpha = alpha + self.relevance_factor = relevance_factor @property def weights(self): @@ -683,6 +695,9 @@ class GMMMachine(BaseEstimator): update_variances=self.update_variances, update_weights=self.update_weights, mean_var_update_threshold=self.mean_var_update_threshold, + reynolds_adaptation=self.relevance_factor is not None, + alpha=self.alpha, + relevance_factor=self.relevance_factor, **kwargs, ) @@ -763,6 +778,7 @@ def ml_gmm_m_step( update_variances=False, update_weights=False, mean_var_update_threshold=EPSILON, + **kwargs, ): """Updates a gmm machine parameter according to the e-step statistics.""" logger.debug("ML GMM Trainer m-step") diff --git a/bob/learn/em/test/test_gmm.py b/bob/learn/em/test/test_gmm.py index 89e96bfc98a7601c5e431d8ae30345619884ad9b..8e39e228bd69626d0fbda514e18a17f2eec7b090 100644 --- a/bob/learn/em/test/test_gmm.py +++ b/bob/learn/em/test/test_gmm.py @@ -752,7 +752,8 @@ def test_gmm_MAP_3(): update_means=True, update_variances=False, update_weights=False, - mean_var_update_threshold=accuracy + mean_var_update_threshold=accuracy, + relevance_factor=None, ) gmm.variance_thresholds = threshold