diff --git a/bob/learn/em/gmm.py b/bob/learn/em/gmm.py
index 38908cc7feee4ab3fa9d6f28e853534cd0c59357..e4d617205bd6d0c25dfb5c3358c6563891b5f23a 100644
--- a/bob/learn/em/gmm.py
+++ b/bob/learn/em/gmm.py
@@ -51,12 +51,14 @@ def log_weighted_likelihood(data, machine):
         The weighted log likelihood of each sample of each Gaussian.
     """
     # Compute the likelihood for each data point on each Gaussian
-    n_gaussians, n_samples = len(machine.means), len(data)
-    z = np.empty(shape=(n_gaussians, n_samples), like=data)
+    n_gaussians = len(machine.means)
+    z = []
     for i in range(n_gaussians):
-        z[i] = np.sum(
+        temp = np.sum(
             (data - machine.means[i]) ** 2 / machine.variances[i], axis=-1
         )
+        z.append(temp)
+    z = np.vstack(z)
     ll = -0.5 * (machine.g_norms[:, None] + z)
     log_weighted_likelihoods = machine.log_weights[:, None] + ll
     return log_weighted_likelihoods
@@ -727,6 +729,7 @@ class GMMMachine(BaseEstimator):
 
             # Set the GMM machine's gaussians with the results of k-means
             self.means = copy.deepcopy(kmeans_machine.centroids_)
+            # TODO: remove this compute
             self.variances, self.weights = dask.compute(variances, weights)
             logger.debug("Done.")
 
diff --git a/bob/learn/em/kmeans.py b/bob/learn/em/kmeans.py
index b5b6094988109f32fecd9084e7baa9ee5883ab8d..b48fa2812cd2df6f5bed665f2f115cbcb76982a1 100644
--- a/bob/learn/em/kmeans.py
+++ b/bob/learn/em/kmeans.py
@@ -37,7 +37,10 @@ def get_centroids_distance(x: np.ndarray, means: np.ndarray) -> np.ndarray:
     """
     x = np.atleast_2d(x)
     if isinstance(x, da.Array):
-        return np.sum((means[:, None] - x[None, :]) ** 2, axis=-1)
+        distances = []
+        for i in range(means.shape[0]):
+            distances.append(np.sum((means[i] - x) ** 2, axis=-1))
+        return da.vstack(distances)
     else:
         return scipy.spatial.distance.cdist(means, x, metric="sqeuclidean")
 
@@ -251,25 +254,27 @@ class KMeansMachine(BaseEstimator):
                 Weight (proportion of quantity of data point) of each cluster.
         """
         _, data = check_and_persist_dask_input(data)
-        n_clusters, n_features = self.n_clusters, data.shape[1]
+
+        # TODO: parallelize this like e_step
+        # Accumulate
         dist = get_centroids_distance(data, self.centroids_)
         closest_centroid_indices = get_closest_centroid_index(dist)
+
+        means_sum, variances_sum = [], []
+        for i in range(self.n_clusters):
+            cluster_data = data[closest_centroid_indices == i]
+            means_sum.append(np.sum(cluster_data, axis=0))
+            variances_sum.append(np.sum(cluster_data ** 2, axis=0))
+
+        means_sum, variances_sum = np.vstack(means_sum), np.vstack(
+            variances_sum
+        )
+
+        # Reduce (similar to m_step)
         weights_count = np.bincount(
-            closest_centroid_indices, minlength=n_clusters
+            closest_centroid_indices, minlength=self.n_clusters
         )
         weights = weights_count / weights_count.sum()
-
-        # Accumulate
-        means_sum = np.zeros((n_clusters, n_features), like=data)
-        variances_sum = np.zeros((n_clusters, n_features), like=data)
-        for i in range(n_clusters):
-            means_sum[i] = np.sum(data[closest_centroid_indices == i], axis=0)
-        for i in range(n_clusters):
-            variances_sum[i] = np.sum(
-                data[closest_centroid_indices == i] ** 2, axis=0
-            )
-
-        # Reduce
         means = means_sum / weights_count[:, None]
         variances = (variances_sum / weights_count[:, None]) - (means ** 2)
 
@@ -336,7 +341,9 @@ class KMeansMachine(BaseEstimator):
                 convergence_value = abs(
                     (distance_previous - distance) / distance_previous
                 )
-                logger.debug(f"Convergence value = {convergence_value}")
+                logger.debug(
+                    f"Convergence value = {convergence_value} and threshold is {self.convergence_threshold}"
+                )
 
                 # Terminates if converged (and threshold is set)
                 if (