diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 1515754e0bef52e755e793e05c3400ecf532c4c7..1daa2300fc671ed82d211795f4abe5f019b8d7fe 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -2,12 +2,12 @@
 # See https://pre-commit.com/hooks.html for more hooks
 repos:
   - repo: https://github.com/timothycrosley/isort
-    rev: 5.9.3
+    rev: 5.10.1
     hooks:
       - id: isort
         args: [--settings-path, "pyproject.toml"]
   - repo: https://github.com/psf/black
-    rev: 21.7b0
+    rev: 22.3.0
     hooks:
       - id: black
   - repo: https://gitlab.com/pycqa/flake8
@@ -15,7 +15,7 @@ repos:
     hooks:
       - id: flake8
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.0.1
+    rev: v4.1.0
     hooks:
       - id: check-ast
       - id: check-case-conflict
diff --git a/bob/learn/em/gmm.py b/bob/learn/em/gmm.py
index 38908cc7feee4ab3fa9d6f28e853534cd0c59357..2d5b0b9340e69dda2f9f585ebc9c16c83d006d64 100644
--- a/bob/learn/em/gmm.py
+++ b/bob/learn/em/gmm.py
@@ -51,12 +51,14 @@ def log_weighted_likelihood(data, machine):
         The weighted log likelihood of each sample of each Gaussian.
     """
     # Compute the likelihood for each data point on each Gaussian
-    n_gaussians, n_samples = len(machine.means), len(data)
-    z = np.empty(shape=(n_gaussians, n_samples), like=data)
+    n_gaussians = len(machine.means)
+    z = []
     for i in range(n_gaussians):
-        z[i] = np.sum(
+        temp = np.sum(
             (data - machine.means[i]) ** 2 / machine.variances[i], axis=-1
         )
+        z.append(temp)
+    z = np.vstack(z)
     ll = -0.5 * (machine.g_norms[:, None] + z)
     log_weighted_likelihoods = machine.log_weights[:, None] + ll
     return log_weighted_likelihoods
@@ -717,17 +719,15 @@ class GMMMachine(BaseEstimator):
         )
         kmeans_machine = kmeans_machine.fit(data)
 
+        # Set the GMM machine's gaussians with the results of k-means
+        self.means = copy.deepcopy(kmeans_machine.centroids_)
         logger.debug(
             "Estimating the variance and weights of each gaussian from kmeans."
         )
         (
-            variances,
-            weights,
+            self.variances,
+            self.weights,
         ) = kmeans_machine.get_variances_and_weights_for_each_cluster(data)
-
-        # Set the GMM machine's gaussians with the results of k-means
-        self.means = copy.deepcopy(kmeans_machine.centroids_)
-        self.variances, self.weights = dask.compute(variances, weights)
         logger.debug("Done.")
 
     def log_weighted_likelihood(
@@ -962,13 +962,10 @@ def map_gmm_m_step(
     # n_threshold[statistics.n > mean_var_update_threshold] = statistics.n[
     #     statistics.n > mean_var_update_threshold
     # ]
-    new_means = (
-        np.multiply(
-            alpha[:, None],
-            (statistics.sum_px / n_threshold[:, None]),
-        )
-        + np.multiply((1 - alpha[:, None]), machine.ubm.means)
-    )
+    new_means = np.multiply(
+        alpha[:, None],
+        (statistics.sum_px / n_threshold[:, None]),
+    ) + np.multiply((1 - alpha[:, None]), machine.ubm.means)
     machine.means = np.where(
         statistics.n[:, None] < mean_var_update_threshold,
         machine.ubm.means,
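A note on the `log_weighted_likelihood` hunk above: the preallocated `np.empty(..., like=data)` buffer and the `z[i] = ...` item assignment are replaced by a plain Python list that is stacked once with `np.vstack`, presumably because dask arrays only support a limited form of item assignment. Below is a minimal, self-contained sketch of the same list-then-stack pattern with plain NumPy; the function name and shapes are illustrative only, not part of the library.

```python
# Illustrative sketch (not part of the patch): build per-component rows in a
# Python list and stack them once, instead of assigning into a preallocated
# array. This mirrors the new loop in log_weighted_likelihood above.
import numpy as np


def squared_distance_per_component(data, means, variances):
    """Return an (n_components, n_samples) array of scaled squared distances."""
    rows = []
    for mean, var in zip(means, variances):
        rows.append(np.sum((data - mean) ** 2 / var, axis=-1))
    return np.vstack(rows)


rng = np.random.default_rng(0)
data = rng.normal(size=(10, 3))   # 10 samples, 3 features
means = rng.normal(size=(2, 3))   # 2 components
variances = np.ones((2, 3))
print(squared_distance_per_component(data, means, variances).shape)  # (2, 10)
```

The same list-then-stack idea appears again in `get_centroids_distance` in the kmeans.py diff below, where the per-centroid distances are collected in a list and combined with `da.vstack`.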
+ """ return np.argmin(centroids_dist, axis=0) @@ -120,6 +134,43 @@ def m_step(stats, n_samples): return means, average_min_distance +def accumulate_indices_means_vars(data, means): + """Accumulates statistics needed to compute weights and variances of the clusters.""" + n_clusters, n_features = len(means), data.shape[1] + dist = get_centroids_distance(data, means) + closest_centroid_indices = get_closest_centroid_index(dist) + # the means_sum and variances_sum must be initialized with zero here since + # they get accumulated in the next function + means_sum = np.zeros((n_clusters, n_features), like=data) + variances_sum = np.zeros((n_clusters, n_features), like=data) + for i in range(n_clusters): + means_sum[i] = np.sum(data[closest_centroid_indices == i], axis=0) + for i in range(n_clusters): + variances_sum[i] = np.sum( + data[closest_centroid_indices == i] ** 2, axis=0 + ) + return closest_centroid_indices, means_sum, variances_sum + + +def reduce_indices_means_vars(stats): + """Computes weights and variances of the clusters given the statistics.""" + closest_centroid_indices = [s[0] for s in stats] + means_sum = [s[1] for s in stats] + variances_sum = [s[2] for s in stats] + + closest_centroid_indices = np.concatenate(closest_centroid_indices, axis=0) + means_sum = np.sum(means_sum, axis=0) + variances_sum = np.sum(variances_sum, axis=0) + + n_clusters = len(means_sum) + weights_count = np.bincount(closest_centroid_indices, minlength=n_clusters) + weights = weights_count / weights_count.sum() + means = means_sum / weights_count[:, None] + variances = (variances_sum / weights_count[:, None]) - (means**2) + + return variances, weights + + def check_and_persist_dask_input(data): # check if input is a dask array. If so, persist and rebalance data input_is_dask = False @@ -250,28 +301,24 @@ class KMeansMachine(BaseEstimator): weights: ndarray of shape (n_clusters, ) Weight (proportion of quantity of data point) of each cluster. 
""" - _, data = check_and_persist_dask_input(data) - n_clusters, n_features = self.n_clusters, data.shape[1] - dist = get_centroids_distance(data, self.centroids_) - closest_centroid_indices = get_closest_centroid_index(dist) - weights_count = np.bincount( - closest_centroid_indices, minlength=n_clusters - ) - weights = weights_count / weights_count.sum() - - # Accumulate - means_sum = np.zeros((n_clusters, n_features), like=data) - variances_sum = np.zeros((n_clusters, n_features), like=data) - for i in range(n_clusters): - means_sum[i] = np.sum(data[closest_centroid_indices == i], axis=0) - for i in range(n_clusters): - variances_sum[i] = np.sum( - data[closest_centroid_indices == i] ** 2, axis=0 - ) + input_is_dask, data = check_and_persist_dask_input(data) + data = array_to_delayed_list(data, input_is_dask) - # Reduce - means = means_sum / weights_count[:, None] - variances = (variances_sum / weights_count[:, None]) - (means ** 2) + if input_is_dask: + stats = [ + dask.delayed(accumulate_indices_means_vars)( + xx, means=self.centroids_ + ) + for xx in data + ] + variances, weights = dask.compute( + dask.delayed(reduce_indices_means_vars)(stats) + )[0] + else: + # Accumulate + stats = accumulate_indices_means_vars(data, self.centroids_) + # Reduce + variances, weights = reduce_indices_means_vars([stats]) return variances, weights @@ -336,7 +383,9 @@ class KMeansMachine(BaseEstimator): convergence_value = abs( (distance_previous - distance) / distance_previous ) - logger.debug(f"Convergence value = {convergence_value}") + logger.debug( + f"Convergence value = {convergence_value} and threshold is {self.convergence_threshold}" + ) # Terminates if converged (and threshold is set) if ( diff --git a/doc/conf.py b/doc/conf.py index dd19c99be128f833ccb7c093637a6547e7e14c94..59964034e8fac4f87cb9144f31c4ff36fc5ce89b 100755 --- a/doc/conf.py +++ b/doc/conf.py @@ -73,10 +73,10 @@ source_suffix = ".rst" master_doc = "index" # General information about the project. -project = u"bob.learn.em" +project = "bob.learn.em" import time -copyright = u"%s, Idiap Research Institute" % time.strftime("%Y") +copyright = "%s, Idiap Research Institute" % time.strftime("%Y") # Grab the setup entry distribution = pkg_resources.require(project)[0] @@ -126,8 +126,8 @@ pygments_style = "sphinx" # Some variables which are useful for generated material project_variable = project.replace(".", "_") -short_description = u"Core utilities required on all Bob modules" -owner = [u"Idiap Research Institute"] +short_description = "Core utilities required on all Bob modules" +owner = ["Idiap Research Institute"] # -- Options for HTML output --------------------------------------------------- @@ -209,7 +209,7 @@ html_favicon = "img/favicon.ico" # html_file_suffix = None # Output file base name for HTML help builder. -htmlhelp_basename = project_variable + u"_doc" +htmlhelp_basename = project_variable + "_doc" # -- Post configuration --------------------------------------------------------