diff --git a/bob/learn/em/k_means.py b/bob/learn/em/k_means.py
index 5d7c95e16f9d81af971ce61fb1d0a4cdc8731903..52f025676e481b51714a6c5c3aca62a145e6ab4b 100644
--- a/bob/learn/em/k_means.py
+++ b/bob/learn/em/k_means.py
@@ -8,10 +8,9 @@ from typing import Union
 
 import dask
 import dask.array as da
-import dask.bag
 import dask.delayed
-import distributed
 import numpy as np
+import scipy.spatial.distance
 
 from dask_ml.cluster.k_means import k_init
 from sklearn.base import BaseEstimator
@@ -26,15 +25,34 @@ def get_centroids_distance(x: np.ndarray, means: np.ndarray) -> np.ndarray:
 
     Parameters
     ----------
-    x: ndarray of shape (n_features,) or (n_samples, n_features)
-        One data point, or a series of data points.
+    x: ndarray of shape (n_samples, n_features)
+        A series of data points.
+    means: ndarray of shape (n_clusters, n_features)
+        The centroids.
 
     Returns
     -------
-    distances: ndarray of shape (n_clusters,) or (n_clusters, n_samples)
+    distances: ndarray of shape (n_clusters, n_samples)
-        For each cluster, the squared Euclidian distance (or distances) to x.
+        For each cluster, the squared Euclidean distance to each sample in x.
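+
+    Examples
+    --------
+    A quick check with two samples and one centroid:
+
+    >>> x = np.array([[0.0, 0.0], [1.0, 1.0]])
+    >>> means = np.array([[0.0, 0.0]])
+    >>> get_centroids_distance(x, means)
+    array([[0., 2.]])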
     """
-    return np.sum((means[:, None] - x[None, :]) ** 2, axis=-1)
+    x = np.atleast_2d(x)
+    if isinstance(x, da.Array):
+        # Dask input: keep the broadcasting formulation so the computation
+        # stays lazy and chunked (scipy's cdist requires in-memory arrays).
+        return np.sum((means[:, None] - x[None, :]) ** 2, axis=-1)
+    else:
+        # NumPy input: cdist is faster and avoids materializing the large
+        # (n_clusters, n_samples, n_features) broadcasting intermediate.
+        return scipy.spatial.distance.cdist(means, x, metric="sqeuclidean")
 
 
 def get_closest_centroid_index(centroids_dist: np.ndarray) -> np.ndarray:
@@ -115,6 +133,26 @@ def m_step(stats, n_samples):
     return means, average_min_distance
 
 
+def check_and_persist_dask_input(data):
+    # Check if the input is a dask array. If so, persist it so its chunks are
+    # computed only once and can be reused across the EM iterations. The
+    # persisted array must be returned and kept, otherwise persist has no effect.
+    input_is_dask = False
+    if isinstance(data, da.Array):
+        data: da.Array = data.persist()
+        input_is_dask = True
+    return input_is_dask, data
+
+
+def array_to_delayed_list(data, input_is_dask):
+    # If input is a dask array, convert to delayed chunks
+    if input_is_dask:
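+        # to_delayed() yields a grid of delayed objects (one per chunk);
+        # flatten it into a plain list of row-block chunks for the e-step.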
+        data = data.to_delayed().ravel().tolist()
+        logger.debug(f"Got {len(data)} chunks.")
+    return data
+
+
 class KMeansMachine(BaseEstimator):
     """Stores the k-means clusters parameters (centroid of each cluster).
 
@@ -265,27 +303,14 @@
     def fit(self, X, y=None):
         """Fits this machine on data samples."""
 
-        # check if input is a dask array. If so, persist and rebalance data
-        client, input_is_dask = None, False
-        if isinstance(X, da.Array):
-            X: da.Array = X.persist()
-            input_is_dask = True
-
-            try:
-                client = distributed.Client.current()
-                client.rebalance()
-            except ValueError:
-                pass
+        input_is_dask, X = check_and_persist_dask_input(X)
 
         logger.debug("Initializing trainer.")
         self.initialize(data=X)
 
         n_samples = len(X)
 
-        # If input is a dask array, convert to delayed chunks
-        if input_is_dask:
-            X = X.to_delayed()
-            logger.debug(f"Got {len(X)} chunks.")
+        X = array_to_delayed_list(X, input_is_dask)
 
         logger.info("Training k-means.")
         distance = np.inf
@@ -301,8 +326,10 @@
             # compute the e-m steps
             if input_is_dask:
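+                # Map the e-step over the delayed chunks, then reduce the
+                # per-chunk statistics in a single delayed m-step, so each EM
+                # iteration computes one small graph.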
                 stats = [
-                    dask.delayed(e_step)(xx, means=self.centroids_)
-                    for xx in X.ravel().tolist()
+                    dask.delayed(e_step)(xx, means=self.centroids_) for xx in X
                 ]
                 self.centroids_, self.average_min_distance = dask.compute(
                     dask.delayed(m_step)(stats, n_samples)
@@ -313,17 +340,6 @@
                     stats, n_samples
                 )
 
-            # scatter centroids to all workers for efficiency
-            if client is not None:
-                logger.debug("Broadcasting centroids to all workers.")
-                future = client.scatter(self.centroids_, broadcast=True)
-                self.centroids_ = da.from_delayed(
-                    future,
-                    shape=self.centroids_.shape,
-                    dtype=self.centroids_.dtype,
-                )
-                client.rebalance()
-
             distance = self.average_min_distance
 
             logger.debug(
@@ -345,6 +361,7 @@
                         "Reached convergence threshold. Training stopped."
                     )
                     break
+
         else:
             logger.info(
                 "Reached maximum step. Training stopped without convergence."
diff --git a/bob/learn/em/test/test_kmeans.py b/bob/learn/em/test/test_kmeans.py
index 539d20a950c132087fde9ef60573717fd44ed272..1400b469e909973e3a7911c7e7439a55b7d87a47 100644
--- a/bob/learn/em/test/test_kmeans.py
+++ b/bob/learn/em/test/test_kmeans.py
@@ -14,8 +14,9 @@ import copy
 
 import dask.array as da
 import numpy as np
+import scipy.spatial.distance
 
-from bob.learn.em import KMeansMachine
+from bob.learn.em import KMeansMachine, k_means
 
 
 def to_numpy(*args):
@@ -174,3 +175,31 @@ def test_kmeans_parameters():
             [0.99479125, 0.99665564, 0.97689017],
         ]
         np.testing.assert_almost_equal(centroids, expected, decimal=7)
+
+
+def test_get_centroids_distance():
+    np.random.seed(0)
+    n_features = 60
+    n_samples = 240_000
+    n_clusters = 256
+    data = np.random.normal(loc=1, size=(n_samples, n_features))
+    means = np.random.normal(loc=-1, size=(n_clusters, n_features))
+    oracle = scipy.spatial.distance.cdist(means, data, metric="sqeuclidean")
+    for transform in (to_numpy,):
+        data, means = transform(data, means)
+        dist = k_means.get_centroids_distance(data, means)
+        np.testing.assert_allclose(dist, oracle)
+        assert type(data) is type(dist), (type(data), type(dist))
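+
+
+def test_get_centroids_distance_dask():
+    # Minimal sketch of the dask code path: get_centroids_distance should
+    # stay lazy on dask input and agree with the scipy oracle once computed.
+    np.random.seed(1)
+    data_np = np.random.normal(size=(1000, 4))
+    means = np.random.normal(size=(8, 4))
+    oracle = scipy.spatial.distance.cdist(means, data_np, metric="sqeuclidean")
+    data = da.from_array(data_np, chunks=(250, 4))
+    dist = k_means.get_centroids_distance(data, means)
+    assert isinstance(dist, da.Array), type(dist)
+    np.testing.assert_allclose(dist.compute(), oracle)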