From dbe31f7601cfbe1f175f65064a003c809872bf2a Mon Sep 17 00:00:00 2001
From: Yannick DAYER <yannick.dayer@idiap.ch>
Date: Wed, 29 May 2024 11:36:57 +0200
Subject: [PATCH] fix(kmeans): initialization depends on the chunks.

Rechunk the data array so that its first and last axes each span a
single chunk, making the k_init result independent of the input chunking.
---
 src/bob/learn/em/kmeans.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/bob/learn/em/kmeans.py b/src/bob/learn/em/kmeans.py
index bc2664f..3ee0e95 100644
--- a/src/bob/learn/em/kmeans.py
+++ b/src/bob/learn/em/kmeans.py
@@ -307,11 +307,11 @@ class KMeansMachine(BaseEstimator):
         logger.debug(f"Initializing k-means means with '{self.init_method}'.")
         # k_init requires da.Array as input.
         logger.debug("Transform k-means data to dask array")
-        data = da.array(data)
-        data.rechunk(1, data.shape[-1])  # Prevents issue with large arrays.
+        init_data = da.array(data)
+        init_data = init_data.rechunk({0: data.shape[0], -1: data.shape[-1]})
         logger.debug("Get k-means centroids")
         self.centroids_ = k_init(
-            X=data,
+            X=init_data,
             n_clusters=self.n_clusters,
             init=self.init_method,
             random_state=self.random_state,
-- 
GitLab