diff --git a/bob/learn/em/mixture/gmm.py b/bob/learn/em/mixture/gmm.py index 65c8137ec7f5e4d600ea2e64a1de54006b50189c..0be5a1772bc19184cd8534312f8901192f96c575 100644 --- a/bob/learn/em/mixture/gmm.py +++ b/bob/learn/em/mixture/gmm.py @@ -862,10 +862,20 @@ def map_gmm_m_step( # Equation 12 of Reynolds et al., "Speaker Verification Using Adapted # Gaussian Mixture Models", Digital Signal Processing, 2000 if update_means: + # Apply threshold to prevent divide by zero below + n_threshold = np.where( + statistics.n < mean_var_update_threshold, + mean_var_update_threshold, + statistics.n, + ) + # n_threshold = np.full(statistics.n.shape, fill_value=mean_var_update_threshold) + # n_threshold[statistics.n > mean_var_update_threshold] = statistics.n[ + # statistics.n > mean_var_update_threshold + # ] new_means = ( np.multiply( alpha[:, None], - (statistics.sum_px / statistics.n[:, None]), + (statistics.sum_px / n_threshold[:, None]), ) + np.multiply((1 - alpha[:, None]), machine.ubm.means) )