diff --git a/src/bob/learn/em/kmeans.py b/src/bob/learn/em/kmeans.py index bc2664fd2eb679a8fb277408ffa0df218c371f07..3ee0e9570e3b3e618be705ff0e8b8973579cfefe 100644 --- a/src/bob/learn/em/kmeans.py +++ b/src/bob/learn/em/kmeans.py @@ -307,11 +307,11 @@ class KMeansMachine(BaseEstimator): logger.debug(f"Initializing k-means means with '{self.init_method}'.") # k_init requires da.Array as input. logger.debug("Transform k-means data to dask array") - data = da.array(data) - data.rechunk(1, data.shape[-1]) # Prevents issue with large arrays. + init_data = da.array(data) + init_data = init_data.rechunk({0: data.shape[0], -1: data.shape[-1]}) logger.debug("Get k-means centroids") self.centroids_ = k_init( - X=data, + X=init_data, n_clusters=self.n_clusters, init=self.init_method, random_state=self.random_state,