diff --git a/src/bob/learn/em/kmeans.py b/src/bob/learn/em/kmeans.py
index bc2664fd2eb679a8fb277408ffa0df218c371f07..3ee0e9570e3b3e618be705ff0e8b8973579cfefe 100644
--- a/src/bob/learn/em/kmeans.py
+++ b/src/bob/learn/em/kmeans.py
@@ -307,11 +307,11 @@ class KMeansMachine(BaseEstimator):
         logger.debug(f"Initializing k-means means with '{self.init_method}'.")
         # k_init requires da.Array as input.
         logger.debug("Transform k-means data to dask array")
-        data = da.array(data)
-        data.rechunk(1, data.shape[-1])  # Prevents issue with large arrays.
+        init_data = da.array(data)
+        init_data = init_data.rechunk({0: data.shape[0], -1: data.shape[-1]})
         logger.debug("Get k-means centroids")
         self.centroids_ = k_init(
-            X=data,
+            X=init_data,
             n_clusters=self.n_clusters,
             init=self.init_method,
             random_state=self.random_state,