diff --git a/bob/learn/em/gmm.py b/bob/learn/em/gmm.py
index 436d169fb26a0935e680fe045cd0f6ca5ac652de..88efe81af0e7a5918cfc3f604adf3fc906aec83d 100644
--- a/bob/learn/em/gmm.py
+++ b/bob/learn/em/gmm.py
@@ -119,7 +119,7 @@ def m_step(
         relevance_factor=map_relevance_factor,
     )
     average_output = float(statistics.log_likelihood / statistics.t)
-    return average_output
+    return machine, average_output
 
 
 class GMMStats:
@@ -888,9 +888,11 @@ class GMMMachine(BaseEstimator):
                     )
                     for xx in X
                 ]
-                average_output = dask.compute(
+                new_machine, average_output = dask.compute(
                     dask.delayed(m_step_func)(self, stats)
                 )[0]
+                for attr in ["weights", "means", "variances"]:
+                    setattr(self, attr, getattr(new_machine, attr))
             else:
                 stats = [
                     e_step(
@@ -902,7 +904,7 @@ class GMMMachine(BaseEstimator):
                         log_weights=self.log_weights,
                     )
                 ]
-                average_output = m_step_func(self, stats)
+                _, average_output = m_step_func(self, stats)
 
             logger.debug(f"log likelihood = {average_output}")
             if step > 1:
diff --git a/bob/learn/em/test/test_gmm.py b/bob/learn/em/test/test_gmm.py
index 7eda1c63df3a7827ce7631eb6a0c78ea0971f5f7..3ad780a32be84b20d27b2cb14f79001309568e7d 100644
--- a/bob/learn/em/test/test_gmm.py
+++ b/bob/learn/em/test/test_gmm.py
@@ -16,11 +16,14 @@ from copy import deepcopy
 import dask.array as da
 import numpy as np
 
+from dask.distributed import Client
 from h5py import File as HDF5File
 from pkg_resources import resource_filename
 
 from bob.learn.em import GMMMachine, GMMStats, KMeansMachine
 
+from .test_kmeans import to_dask_array, to_numpy
+
 
 def load_array(filename):
     with HDF5File(filename, "r") as f:
@@ -464,13 +467,22 @@ def test_gmm_kmeans_parallel_init():
     data = np.array(
         [[1.5, 1], [1, 1.5], [-1, 0.5], [-1.5, 0], [2, 2], [2.5, 2.5]]
     )
-    machine = machine.fit(data)
-    expected_means = np.array([[1.25, 1.25], [-1.25, 0.25], [2.25, 2.25]])
-    expected_variances = np.array(
-        [[1 / 16, 1 / 16], [1 / 16, 1 / 16], [1 / 16, 1 / 16]]
-    )
-    np.testing.assert_almost_equal(machine.means, expected_means, decimal=3)
-    np.testing.assert_almost_equal(machine.variances, expected_variances)
+    with Client().as_current():
+        for transform in (to_numpy, to_dask_array):
+            data = transform(data)
+            machine = machine.fit(data)
+            expected_means = np.array(
+                [[1.25, 1.25], [-1.25, 0.25], [2.25, 2.25]]
+            )
+            expected_variances = np.array(
+                [[1 / 16, 1 / 16], [1 / 16, 1 / 16], [1 / 16, 1 / 16]]
+            )
+            np.testing.assert_almost_equal(
+                machine.means, expected_means, decimal=3
+            )
+            np.testing.assert_almost_equal(
+                machine.variances, expected_variances
+            )
 
 
 def test_likelihood():