diff --git a/bob/learn/em/__init__.py b/bob/learn/em/__init__.py
index 6d67218a8049049a3e44c11076f6f8a61a1e09dc..6c50a0a6c1a429550f2e377665a21b5d13431ce3 100644
--- a/bob/learn/em/__init__.py
+++ b/bob/learn/em/__init__.py
@@ -1,7 +1,7 @@
 import bob.extension
 
 from .gmm import GMMMachine, GMMStats
-from .k_means import KMeansMachine
+from .kmeans import KMeansMachine
 from .linear_scoring import linear_scoring  # noqa: F401
 from .wccn import WCCN
 from .whitening import Whitening
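The rename from `k_means.py` to `kmeans.py` changes the module path, so any downstream code that imported the file directly has to follow. A minimal migration sketch (the package-level re-export is unchanged, so only direct module imports are affected):

```python
# Before this change (now raises ModuleNotFoundError):
# from bob.learn.em.k_means import KMeansMachine

# After this change:
from bob.learn.em.kmeans import KMeansMachine

# The package-level import works the same before and after:
from bob.learn.em import KMeansMachine  # noqa: F811
```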
diff --git a/bob/learn/em/gmm.py b/bob/learn/em/gmm.py
index 88efe81af0e7a5918cfc3f604adf3fc906aec83d..78ab69fe2de3f1506529c220c89c5bc57f1fa805 100644
--- a/bob/learn/em/gmm.py
+++ b/bob/learn/em/gmm.py
@@ -18,7 +18,7 @@ import numpy as np
 from h5py import File as HDF5File
 from sklearn.base import BaseEstimator
 
-from .k_means import (
+from .kmeans import (
     KMeansMachine,
     array_to_delayed_list,
     check_and_persist_dask_input,
diff --git a/bob/learn/em/k_means.py b/bob/learn/em/kmeans.py
similarity index 94%
rename from bob/learn/em/k_means.py
rename to bob/learn/em/kmeans.py
index 52f025676e481b51714a6c5c3aca62a145e6ab4b..329cc7cc47798d70eb0880147e8101b3f83a1ae9 100644
--- a/bob/learn/em/k_means.py
+++ b/bob/learn/em/kmeans.py
@@ -74,10 +74,10 @@ def e_step(data, means):
     zeroeth_order_statistics = np.bincount(
         closest_k_indices, minlength=n_clusters
     )
-    first_order_statistics = np.sum(
-        np.eye(n_clusters)[closest_k_indices][:, :, None] * data[:, None],
-        axis=0,
-    )
+    # Compute first_order_statistics in a memory-efficient way, one cluster at a time.
+    first_order_statistics = np.zeros((n_clusters, data.shape[1]))
+    for i in range(n_clusters):
+        first_order_statistics[i] = np.sum(data[closest_k_indices == i], axis=0)
     min_distance = np.min(distances, axis=0)
     average_min_distance = min_distance.mean()
     return (
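The loop above replaces a dense one-hot formulation: the removed expression materializes an intermediate of shape `(n_samples, n_clusters, n_features)` before reducing over samples, while the loop never allocates more than `(n_clusters, n_features)`. A standalone sketch of the equivalence, with toy shapes that are not taken from the library:

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_features, n_clusters = 1000, 4, 3
data = rng.normal(size=(n_samples, n_features))
closest_k_indices = rng.integers(0, n_clusters, size=n_samples)

# Removed approach: builds an (n_samples, n_clusters, n_features) temporary.
dense = np.sum(
    np.eye(n_clusters)[closest_k_indices][:, :, None] * data[:, None],
    axis=0,
)

# New approach: peak extra memory is just (n_clusters, n_features).
loop = np.zeros((n_clusters, n_features))
for i in range(n_clusters):
    loop[i] = np.sum(data[closest_k_indices == i], axis=0)

np.testing.assert_allclose(dense, loop, atol=1e-8)
```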
@@ -241,28 +241,21 @@ class KMeansMachine(BaseEstimator):
             weights: ndarray of shape (n_clusters, )
                 Weight (proportion of quantity of data point) of each cluster.
         """
-        n_cluster = self.n_clusters
+        n_clusters = self.n_clusters
         dist = get_centroids_distance(data, self.centroids_)
         closest_centroid_indices = get_closest_centroid_index(dist)
         weights_count = np.bincount(
-            closest_centroid_indices, minlength=n_cluster
+            closest_centroid_indices, minlength=n_clusters
         )
         weights = weights_count / weights_count.sum()
 
-        # FIX for `too many indices for array` error if using `np.eye(n_cluster)` alone:
-        dask_compatible_eye = np.eye(n_cluster) * np.array(1, like=data)
-
         # Accumulate
-        means_sum = np.sum(
-            dask_compatible_eye[closest_centroid_indices][:, :, None]
-            * data[:, None],
-            axis=0,
-        )
-        variances_sum = np.sum(
-            dask_compatible_eye[closest_centroid_indices][:, :, None]
-            * (data[:, None] ** 2),
-            axis=0,
-        )
+        means_sum = np.zeros((n_clusters, data.shape[1]))
+        variances_sum = np.zeros((n_clusters, data.shape[1]))
+        for i in range(n_clusters):
+            cluster_data = data[closest_centroid_indices == i]
+            means_sum[i] = np.sum(cluster_data, axis=0)
+            variances_sum[i] = np.sum(cluster_data**2, axis=0)
 
         # Reduce
         means = means_sum / weights_count[:, None]
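`get_variances_and_weights_for_each_cluster` gets the same per-cluster treatment, which also makes the `np.eye(n_cluster) * np.array(1, like=data)` dask workaround unnecessary, since boolean masking along the first axis works on both numpy and dask arrays. A rough worked example of what the accumulate/reduce steps compute, on toy numpy data (the variance reduce line is outside this hunk, so the E[x^2] - E[x]^2 formula below is an assumption):

```python
import numpy as np

data = np.array([[1.0, 2.0], [1.5, 2.5], [10.0, 10.0]])
closest_centroid_indices = np.array([0, 0, 1])
n_clusters = 2

weights_count = np.bincount(closest_centroid_indices, minlength=n_clusters)
weights = weights_count / weights_count.sum()  # -> [2/3, 1/3]

# Accumulate per cluster, as in the rewritten method.
means_sum = np.zeros((n_clusters, data.shape[1]))
variances_sum = np.zeros((n_clusters, data.shape[1]))
for i in range(n_clusters):
    cluster_data = data[closest_centroid_indices == i]
    means_sum[i] = np.sum(cluster_data, axis=0)
    variances_sum[i] = np.sum(cluster_data**2, axis=0)

# Reduce: means as shown in the hunk; variances assumed to follow as
# E[x^2] - E[x]^2 (that line is not part of this diff).
means = means_sum / weights_count[:, None]
variances = variances_sum / weights_count[:, None] - means**2
```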
diff --git a/bob/learn/em/test/test_gmm.py b/bob/learn/em/test/test_gmm.py
index 3ad780a32be84b20d27b2cb14f79001309568e7d..b62da1abce52ea7cb96e1aa26f55ccee35a2d373 100644
--- a/bob/learn/em/test/test_gmm.py
+++ b/bob/learn/em/test/test_gmm.py
@@ -8,6 +8,7 @@
 """Tests the GMM machine and the GMMStats container
 """
 
+import contextlib
 import os
 import tempfile
 
@@ -32,6 +33,16 @@ def load_array(filename):
     return np.squeeze(array)
 
 
+@contextlib.contextmanager
+def multiprocess_dask_client():
+    client = Client()
+    try:
+        with client.as_current():
+            yield client
+    finally:
+        client.close()
+
+
 def test_GMMStats():
     # Test a GMMStats
     # Initializes a GMMStats
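The helper guarantees that the dask `Client` is closed even when the test body raises, which a bare `with Client().as_current():` does not (exiting `as_current()` only deregisters the client as current). A hypothetical usage sketch, with names taken from this diff:

```python
import dask.array as da
import numpy as np

def test_kmeans_fit_with_dask_client():
    # Hypothetical test body; KMeansMachine and multiprocess_dask_client
    # come from the modules touched in this diff.
    machine = KMeansMachine(2)
    data = np.array([[1.5, 1], [1, 1.5], [-1, 0.5], [-1.5, 0]])
    with multiprocess_dask_client():
        machine = machine.fit(da.from_array(data, chunks=(2, 2)))
    # The Client is closed here even if fit() raised.
```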
@@ -467,7 +478,7 @@ def test_gmm_kmeans_parallel_init():
     data = np.array(
         [[1.5, 1], [1, 1.5], [-1, 0.5], [-1.5, 0], [2, 2], [2.5, 2.5]]
     )
-    with Client().as_current():
+    with multiprocess_dask_client():
         for transform in (to_numpy, to_dask_array):
             data = transform(data)
             machine = machine.fit(data)
@@ -778,7 +789,7 @@ def test_gmm_ML_1():
 
 
 def test_gmm_ML_2():
-    """Trains a GMMMachine with ML_GMMTrainer; compares to a reference"""
+    # Trains a GMMMachine with ML_GMMTrainer; compares to a reference
     ar = load_array(
         resource_filename("bob.learn.em", "data/dataNormalized.hdf5")
     )
@@ -829,7 +840,7 @@ def test_gmm_ML_2():
 
 
 def test_gmm_MAP_1():
-    """Train a GMMMachine with MAP_GMMTrainer"""
+    # Train a GMMMachine with MAP_GMMTrainer
     ar = load_array(
         resource_filename("bob.learn.em", "data/faithful.torch3_f64.hdf5")
     )
@@ -875,7 +886,7 @@ def test_gmm_MAP_1():
 
 
 def test_gmm_MAP_2():
-    """Train a GMMMachine with MAP_GMMTrainer and compare with matlab reference"""
+    # Train a GMMMachine with MAP_GMMTrainer and compare with matlab reference
 
     data = load_array(resource_filename("bob.learn.em", "data/data.hdf5"))
     data = data.reshape((1, -1))  # make a 2D array out of it
@@ -915,7 +926,7 @@ def test_gmm_MAP_2():
 
 
 def test_gmm_MAP_3():
-    """Train a GMMMachine with MAP_GMMTrainer; compares to old reference"""
+    # Train a GMMMachine with MAP_GMMTrainer; compares to old reference
     ar = load_array(resource_filename("bob.learn.em", "data/dataforMAP.hdf5"))
 
     # Initialize GMMMachine
@@ -976,7 +987,7 @@ def test_gmm_MAP_3():
 
 
 def test_gmm_test():
-    """Tests a GMMMachine by computing scores against a model and comparing to a reference"""
+    # Tests a GMMMachine by computing scores against a model and comparing to a reference
 
     ar = load_array(resource_filename("bob.learn.em", "data/dataforMAP.hdf5"))
 
@@ -1006,7 +1017,7 @@ def test_gmm_test():
 
 
 def test_gmm_ML_dask():
-    """Trains a GMMMachine with dask array data; compares to a reference"""
+    # Trains a GMMMachine with dask array data; compares to a reference
 
     ar = da.array(
         load_array(
@@ -1061,7 +1072,7 @@ def test_gmm_ML_dask():
 
 
 def test_gmm_MAP_dask():
-    """Test a GMMMachine for MAP with a dask array as data."""
+    # Test a GMMMachine for MAP with a dask array as data.
     ar = da.array(
         load_array(resource_filename("bob.learn.em", "data/dataforMAP.hdf5"))
     )
diff --git a/bob/learn/em/test/test_kmeans.py b/bob/learn/em/test/test_kmeans.py
index 1400b469e909973e3a7911c7e7439a55b7d87a47..8edc205ab9ecf2f6f15a8c9e4213f57cb219f804 100644
--- a/bob/learn/em/test/test_kmeans.py
+++ b/bob/learn/em/test/test_kmeans.py
@@ -16,7 +16,7 @@ import dask.array as da
 import numpy as np
 import scipy.spatial.distance
 
-from bob.learn.em import KMeansMachine, k_means
+from bob.learn.em import KMeansMachine, kmeans
 
 
 def to_numpy(*args):
@@ -187,6 +187,6 @@ def test_get_centroids_distance():
     oracle = scipy.spatial.distance.cdist(means, data, metric="sqeuclidean")
     for transform in (to_numpy,):
         data, means = transform(data, means)
-        dist = k_means.get_centroids_distance(data, means)
+        dist = kmeans.get_centroids_distance(data, means)
         np.testing.assert_allclose(dist, oracle)
         assert type(data) is type(dist), (type(data), type(dist))