diff --git a/bob/learn/em/gmm.py b/bob/learn/em/gmm.py
index 78ab69fe2de3f1506529c220c89c5bc57f1fa805..ac850dd88623271d556ea70764837ed3d080c7e7 100644
--- a/bob/learn/em/gmm.py
+++ b/bob/learn/em/gmm.py
@@ -672,6 +672,9 @@ class GMMMachine(BaseEstimator):
             )
             kmeans_machine = kmeans_machine.fit(data)
 
+            logger.debug(
+                "Estimating the variances and weights of each Gaussian from the k-means result."
+            )
             (
                 variances,
                 weights,
@@ -680,6 +683,7 @@ class GMMMachine(BaseEstimator):
             # Set the GMM machine's gaussians with the results of k-means
             self.means = copy.deepcopy(kmeans_machine.centroids_)
             self.variances, self.weights = dask.compute(variances, weights)
+            logger.debug("Done.")
 
     def log_weighted_likelihood(
         self,
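
[Reviewer note] `dask.compute(variances, weights)` above evaluates both lazy results in one call, so dask can share the chunks of the underlying data instead of traversing them once per result. A minimal sketch of the pattern, with illustrative reductions standing in for the real k-means statistics:

```python
import dask
import dask.array as da

x = da.random.random((10_000, 64), chunks=(1_000, 64))

variances = x.var(axis=0)  # lazy graph over the chunks of x
weights = x.mean(axis=0)   # second lazy graph over the same chunks

# One compute() evaluates both graphs together and reuses the shared chunks;
# calling .compute() on each result separately would read x twice.
variances, weights = dask.compute(variances, weights)
```
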
@@ -833,7 +837,7 @@ class GMMMachine(BaseEstimator):
     def fit(self, X, y=None):
         """Trains the GMM on data until convergence or maximum step is reached."""
 
-        input_is_dask = check_and_persist_dask_input(X)
+        input_is_dask, X = check_and_persist_dask_input(X)
 
         if self._means is None:
             self.initialize_gaussians(X)
@@ -912,7 +916,9 @@ class GMMMachine(BaseEstimator):
                     (average_output_previous - average_output)
                     / average_output_previous
                 )
-                logger.debug(f"convergence val = {convergence_value}")
+                logger.debug(
+                    f"convergence value = {convergence_value}, threshold = {self.convergence_threshold}"
+                )
 
                 # Terminates if converged (and likelihood computation is set)
                 if (
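
[Reviewer note] The expanded log message above prints both sides of the stopping rule. For reference, the criterion is the relative change of the average log-likelihood between two EM steps; the sketch below restates it as a standalone function (`em_converged` is a hypothetical name, and the `<=` comparison against the threshold is an assumption, since the body of the `if` falls outside this hunk):

```python
def em_converged(average_output_previous, average_output, convergence_threshold):
    # Relative change of the average log-likelihood between two EM steps,
    # matching the `convergence_value` expression in the hunk above.
    convergence_value = abs(
        (average_output_previous - average_output) / average_output_previous
    )
    return (
        convergence_threshold is not None
        and convergence_value <= convergence_threshold
    )


# Example: the average log-likelihood moved from -12.0 to -11.9994; the
# relative change is 5e-5, below a 1e-3 threshold, so training would stop.
assert em_converged(-12.0, -11.9994, 1e-3)
```
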
diff --git a/bob/learn/em/kmeans.py b/bob/learn/em/kmeans.py
index 329cc7cc47798d70eb0880147e8101b3f83a1ae9..b5b6094988109f32fecd9084e7baa9ee5883ab8d 100644
--- a/bob/learn/em/kmeans.py
+++ b/bob/learn/em/kmeans.py
@@ -126,7 +126,16 @@ def check_and_persist_dask_input(data):
     if isinstance(data, da.Array):
         data: da.Array = data.persist()
         input_is_dask = True
-    return input_is_dask
+        # If a dask.distributed client is active, rebalance data across the workers
+        try:
+            client = dask.distributed.Client.current()
+            client.rebalance()
+        except ValueError:  # no distributed client has been started
+            pass
+
+    else:
+        data = np.asarray(data)
+    return input_is_dask, data
 
 
 def array_to_delayed_list(data, input_is_dask):
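
[Reviewer note] The new guard relies on `dask.distributed.Client.current()`, which raises `ValueError` when no client has been started in the process, so rebalancing silently degrades to a no-op under the local scheduler. Note this assumes `dask.distributed` is importable as an attribute of `dask`, i.e. that the module imports `dask.distributed` somewhere; `import dask` alone does not pull in the submodule. The same pattern in isolation (the helper name is hypothetical):

```python
import dask.distributed


def rebalance_if_distributed():
    """Spread persisted chunks evenly across workers, if any."""
    try:
        client = dask.distributed.Client.current()
    except ValueError:
        # No distributed client in this process (e.g. the default threaded
        # scheduler): nothing to rebalance.
        return
    client.rebalance()
```
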
@@ -241,7 +250,8 @@ class KMeansMachine(BaseEstimator):
             weights: ndarray of shape (n_clusters, )
                 Weight (proportion of data points) of each cluster.
         """
-        n_clusters = self.n_clusters
+        _, data = check_and_persist_dask_input(data)
+        n_clusters, n_features = self.n_clusters, data.shape[1]
         dist = get_centroids_distance(data, self.centroids_)
         closest_centroid_indices = get_closest_centroid_index(dist)
         weights_count = np.bincount(
@@ -250,10 +260,10 @@ class KMeansMachine(BaseEstimator):
         weights = weights_count / weights_count.sum()
 
         # Accumulate
-        means_sum = np.zeros((n_clusters, data.shape[1]))
+        means_sum = np.zeros((n_clusters, n_features), like=data)
+        variances_sum = np.zeros((n_clusters, n_features), like=data)
         for i in range(n_clusters):
             means_sum[i] = np.sum(data[closest_centroid_indices == i], axis=0)
-        variances_sum = np.zeros((n_clusters, data.shape[1]))
         for i in range(n_clusters):
             variances_sum[i] = np.sum(
                 data[closest_centroid_indices == i] ** 2, axis=0
@@ -282,7 +292,7 @@ class KMeansMachine(BaseEstimator):
     def fit(self, X, y=None):
         """Fits this machine on data samples."""
 
-        input_is_dask = check_and_persist_dask_input(X)
+        input_is_dask, X = check_and_persist_dask_input(X)
 
         logger.debug("Initializing trainer.")
         self.initialize(data=X)
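
[Reviewer note] With `check_and_persist_dask_input` now returning the (possibly persisted) data alongside the flag, `fit` accepts NumPy and dask inputs interchangeably and normalizes them once up front. A hedged end-to-end sketch; the import path follows the file layout in this diff, and the cluster sizing and constructor arguments are illustrative assumptions:

```python
import dask.array as da
from dask.distributed import Client

from bob.learn.em import KMeansMachine

# A small local cluster for illustration; fit() persists the input and
# rebalances it across these workers before training.
client = Client(n_workers=2)

data = da.random.random((100_000, 64), chunks=(10_000, 64))

machine = KMeansMachine(n_clusters=8).fit(data)
print(machine.centroids_.shape)  # (8, 64)
```
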