diff --git a/src/bob/learn/em/kmeans.py b/src/bob/learn/em/kmeans.py
index 8f968eb61457c840f04d6101f08b14f2a6aae5a7..bc2664fd2eb679a8fb277408ffa0df218c371f07 100644
--- a/src/bob/learn/em/kmeans.py
+++ b/src/bob/learn/em/kmeans.py
@@ -308,6 +308,7 @@ class KMeansMachine(BaseEstimator):
         # k_init requires da.Array as input.
         logger.debug("Transform k-means data to dask array")
         data = da.array(data)
+        data.rechunk(1, data.shape[-1])  # Prevents issue with large arrays.
         logger.debug("Get k-means centroids")
         self.centroids_ = k_init(
             X=data,