diff --git a/bob/learn/em/k_means.py b/bob/learn/em/k_means.py
index 52f025676e481b51714a6c5c3aca62a145e6ab4b..329cc7cc47798d70eb0880147e8101b3f83a1ae9 100644
--- a/bob/learn/em/k_means.py
+++ b/bob/learn/em/k_means.py
@@ -74,10 +74,10 @@ def e_step(data, means):
     zeroeth_order_statistics = np.bincount(
         closest_k_indices, minlength=n_clusters
     )
-    first_order_statistics = np.sum(
-        np.eye(n_clusters)[closest_k_indices][:, :, None] * data[:, None],
-        axis=0,
-    )
+    # Compute first_order_statistics in a memory-efficient way, one cluster at a time
+    first_order_statistics = np.zeros((n_clusters, data.shape[1]))
+    for i in range(n_clusters):
+        first_order_statistics[i] = np.sum(data[closest_k_indices == i], axis=0)
     min_distance = np.min(distances, axis=0)
     average_min_distance = min_distance.mean()
     return (
@@ -241,28 +241,21 @@ class KMeansMachine(BaseEstimator):
             weights: ndarray of shape (n_clusters, )
                 Weight (proportion of quantity of data point) of each cluster.
         """
-        n_cluster = self.n_clusters
+        n_clusters = self.n_clusters
         dist = get_centroids_distance(data, self.centroids_)
         closest_centroid_indices = get_closest_centroid_index(dist)
         weights_count = np.bincount(
-            closest_centroid_indices, minlength=n_cluster
+            closest_centroid_indices, minlength=n_clusters
         )
         weights = weights_count / weights_count.sum()
 
-        # FIX for `too many indices for array` error if using `np.eye(n_cluster)` alone:
-        dask_compatible_eye = np.eye(n_cluster) * np.array(1, like=data)
-
         # Accumulate
-        means_sum = np.sum(
-            dask_compatible_eye[closest_centroid_indices][:, :, None]
-            * data[:, None],
-            axis=0,
-        )
-        variances_sum = np.sum(
-            dask_compatible_eye[closest_centroid_indices][:, :, None]
-            * (data[:, None] ** 2),
-            axis=0,
-        )
+        means_sum = np.zeros((n_clusters, data.shape[1]))
+        variances_sum = np.zeros((n_clusters, data.shape[1]))
+        for i in range(n_clusters):
+            cluster_data = data[closest_centroid_indices == i]
+            means_sum[i] = np.sum(cluster_data, axis=0)
+            variances_sum[i] = np.sum(cluster_data**2, axis=0)
 
         # Reduce
         means = means_sum / weights_count[:, None]
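
As a sanity check on the rewrite, here is a minimal standalone sketch (not part of the patch; the toy shapes, cluster count, and RNG seed below are made up for illustration) verifying that the per-cluster loop yields the same first-order statistics as the one-hot broadcasting it replaces, without building the (n_samples, n_clusters, n_features) intermediate:

    import numpy as np

    rng = np.random.default_rng(0)
    n_clusters, n_features = 3, 2
    data = rng.normal(size=(1000, n_features))
    closest_k_indices = rng.integers(0, n_clusters, size=len(data))

    # Old formulation: one-hot selection broadcast against the data,
    # which materializes a (n_samples, n_clusters, n_features) array.
    dense = np.sum(
        np.eye(n_clusters)[closest_k_indices][:, :, None] * data[:, None],
        axis=0,
    )

    # New formulation: accumulate one cluster at a time.
    loop = np.zeros((n_clusters, n_features))
    for i in range(n_clusters):
        loop[i] = np.sum(data[closest_k_indices == i], axis=0)

    assert np.allclose(dense, loop)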