diff --git a/bob/learn/em/mixture/gmm.py b/bob/learn/em/mixture/gmm.py index fdab112a0111ff393db61307b49b2f452ef17e57..5c8c0c076280f0581b1ecd22d94f3f124fecc731 100644 --- a/bob/learn/em/mixture/gmm.py +++ b/bob/learn/em/mixture/gmm.py @@ -667,14 +667,14 @@ class GMMMachine(BaseEstimator): # Count of samples [int] statistics.t += data.shape[0] # Responsibilities [array of shape (n_gaussians,)] - statistics.n += responsibility.sum(axis=-1) + statistics.n = statistics.n + responsibility.sum(axis=-1) # p * x [array of shape (n_gaussians, n_samples, n_features)] px = np.multiply(responsibility[:, :, None], data[None, :, :]) # First order stats [array of shape (n_gaussians, n_features)] - statistics.sum_px += px.sum(axis=1) + statistics.sum_px = statistics.sum_px + px.sum(axis=1) # Second order stats [array of shape (n_gaussians, n_features)] pxx = np.multiply(px[:, :, :], data[None, :, :]) - statistics.sum_pxx += pxx.sum(axis=1) + statistics.sum_pxx = statistics.sum_pxx + pxx.sum(axis=1) return statistics diff --git a/bob/learn/em/test/test_gmm.py b/bob/learn/em/test/test_gmm.py index 8e39e228bd69626d0fbda514e18a17f2eec7b090..43a9449de845647d932fa086eed4513c5041d017 100644 --- a/bob/learn/em/test/test_gmm.py +++ b/bob/learn/em/test/test_gmm.py @@ -621,7 +621,7 @@ def test_gmm_ML_2(): def test_gmm_ML_parallel(): - """Trains a GMMMachine with ML_GMMTrainer; compares to an old reference""" + """Trains a GMMMachine with ML_GMMTrainer; compares to a reference""" ar = da.array(load_array(resource_filename("bob.learn.em", "data/dataNormalized.hdf5")))