diff --git a/bob/bio/gmm/algorithm/GMM.py b/bob/bio/gmm/algorithm/GMM.py
index 747d45bb5b6410800a5010b8b2c5aee5daec5df8..a4a0450367f23d7daa2653d5512072b2ed76b12b 100644
--- a/bob/bio/gmm/algorithm/GMM.py
+++ b/bob/bio/gmm/algorithm/GMM.py
@@ -5,22 +5,54 @@
 import logging
 
-from multiprocessing.pool import ThreadPool
-
 import numpy
 
+from dask_ml.cluster import KMeans as DistributedKMeans
+from sklearn.base import BaseEstimator
+from sklearn.base import TransformerMixin
+
 import bob.core
 import bob.io.base
 import bob.learn.em
 
 from bob.bio.base.algorithm import Algorithm
+from bob.bio.gmm.algorithm.utils import get_variances_and_weights_of_clusters
+
+logger = logging.getLogger(__name__)
+
 
-logger = logging.getLogger("bob.bio.gmm")
+class KMeans(BaseEstimator, TransformerMixin):
+    def transform(self, X, **kwargs):
+        pass
+
+    def fit(self, X, y=None):
+        pass
 
 
 class GMM(Algorithm):
-    """Algorithm for computing Universal Background Models and Gaussian Mixture Models of the features.
-    Features must be normalized to zero mean and unit standard deviation."""
+    """Algorithm for computing Universal Background Models and Gaussian Mixture Models
+    of the features.
+
+    Features must be normalized to zero mean and unit standard deviation.
+
+    Parameters
+    ----------
+    number_of_gaussians: int
+        Number of Gaussian components (number of clusters).
+    kmeans_training_iterations: int
+        Maximum number of iterations for the K-Means initialization step.
+    kmeans_training_threshold: float
+        Convergence threshold below which the K-Means training stops.
+    gmm_training_iterations: int
+        Maximum number of E-M steps for the GMM training.
+    training_threshold: float
+        Convergence threshold below which the E-M algorithm will be stopped.
+    variance_threshold: float
+        Minimum value that a variance can reach.
+    update_weights: bool
+        Whether the weights are updated during the GMM M-step.
+    update_means: bool
+        Whether the means are updated during the GMM M-step.
+    update_variances: bool
+        Whether the variances are updated during the GMM M-step.
+    """
 
     def __init__(
         self,
@@ -28,6 +60,7 @@ class GMM(Algorithm):
         number_of_gaussians,  # parameters of UBM training
         kmeans_training_iterations=25,  # Maximum number of iterations for K-Means
+        kmeans_training_threshold=5e-4,  # Threshold to end the K-Means training
         gmm_training_iterations=25,  # Maximum number of iterations for ML GMM Training
         training_threshold=5e-4,  # Threshold to end the ML training
         variance_threshold=5e-4,  # Minimum value that a variance can reach
@@ -50,26 +83,15 @@
             self,
             performs_projection=True,
             use_projected_features_for_enrollment=False,
-            number_of_gaussians=number_of_gaussians,
-            kmeans_training_iterations=kmeans_training_iterations,
-            gmm_training_iterations=gmm_training_iterations,
-            training_threshold=training_threshold,
-            variance_threshold=variance_threshold,
-            update_weights=update_weights,
-            update_means=update_means,
-            update_variances=update_variances,
-            relevance_factor=relevance_factor,
-            gmm_enroll_iterations=gmm_enroll_iterations,
-            responsibility_threshold=responsibility_threshold,
-            INIT_SEED=INIT_SEED,
             scoring_function=str(scoring_function),
             multiple_model_scoring=None,
             multiple_probe_scoring="average",
         )
 
         # copy parameters
-        self.gaussians = number_of_gaussians
+        self.number_of_gaussians = number_of_gaussians
         self.kmeans_training_iterations = kmeans_training_iterations
+        self.kmeans_training_threshold = kmeans_training_threshold
         self.gmm_training_iterations = gmm_training_iterations
         self.training_threshold = training_threshold
         self.variance_threshold = variance_threshold
@@ -86,7 +108,14 @@ class GMM(Algorithm):
         self.pool = None
         self.ubm = None
-        self.kmeans_trainer = bob.learn.em.KMeansTrainer()
+        self.kmeans_trainer = DistributedKMeans(
+            n_clusters=self.number_of_gaussians,
+            init="k-means||",  # TODO: switch to "k-means++" if the data fits in memory
+            init_max_iter=self.kmeans_training_iterations,
+            max_iter=self.kmeans_training_iterations,
+            tol=self.kmeans_training_threshold,
+            random_state=self.init_seed,
+        )
         self.ubm_trainer = bob.learn.em.ML_GMMTrainer(
             self.update_means,
             self.update_variances,
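Aside, not part of the patch: `dask_ml.cluster.KMeans` fits lazily on a dask array, `transform` returns each sample's distance to every cluster center, and the fitted centers are exposed as `cluster_centers_`. A minimal sketch of that API, with made-up shapes and an assumed seed of 5489:

    import dask.array
    from dask_ml.cluster import KMeans as DistributedKMeans

    features = dask.array.random.random((10000, 60), chunks=(1000, 60))
    trainer = DistributedKMeans(n_clusters=512, init="k-means||", random_state=5489)
    trainer = trainer.fit(features)          # k-means|| init followed by Lloyd iterations
    distances = trainer.transform(features)  # dask array, shape (10000, 512)
    centers = trainer.cluster_centers_       # numpy array, shape (512, 60)
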
init="k-means||", # TODO switch to "k-means++" if data fits in memory + init_max_iter=self.kmeans_training_iterations, + max_iter=self.kmeans_training_iterations, + tol=self.kmeans_training_threshold, + random_state=self.init_seed, + ) self.ubm_trainer = bob.learn.em.ML_GMMTrainer( self.update_means, self.update_variances, @@ -101,7 +130,7 @@ class GMM(Algorithm): or feature.ndim != 2 or feature.dtype != numpy.float64 ): - raise ValueError("The given feature is not appropriate") + raise ValueError(f"The given feature is not appropriate: {feature}") if self.ubm is not None and feature.shape[1] != self.ubm.shape[1]: raise ValueError( "The given feature is expected to have %d elements, but it has %d" @@ -115,36 +144,29 @@ class GMM(Algorithm): logger.debug(" .... Training with %d feature vectors", array.shape[0]) if self.n_threads is not None: - self.pool = ThreadPool(self.n_threads) + raise ValueError("n_threads is not supported") + # self.pool = ThreadPool(self.n_threads) # Computes input size input_size = array.shape[1] # Creates the machines (KMeans and GMM) logger.debug(" .... Creating machines") - kmeans = bob.learn.em.KMeansMachine(self.gaussians, input_size) - self.ubm = bob.learn.em.GMMMachine(self.gaussians, input_size) + # kmeans = bob.learn.em.KMeansMachine(self.number_of_gaussians, input_size) + self.ubm = bob.learn.em.GMMMachine(self.number_of_gaussians, input_size) - # Trains using the KMeansTrainer logger.info(" -> Training K-Means") + self.kmeans_trainer = self.kmeans_trainer.fit(array) + distances_to_means = self.kmeans_trainer.transform(array) - # Reseting the pseudo random number generator so we can have the same initialization for serial and parallel execution. - self.rng = bob.core.random.mt19937(self.init_seed) - bob.learn.em.train( - self.kmeans_trainer, - kmeans, - array, - self.kmeans_training_iterations, - self.training_threshold, - rng=self.rng, - pool=self.pool, - ) + logger.debug("Compute K-Means variances and weights") - variances, weights = kmeans.get_variances_and_weights_for_each_cluster(array) - means = kmeans.means + variances, weights = get_variances_and_weights_of_clusters( + array, distances_to_means, self.number_of_gaussians + ) # Initializes the GMM - self.ubm.means = means + self.ubm.means = self.kmeans_trainer.cluster_centers_ self.ubm.variances = variances self.ubm.weights = weights self.ubm.set_variance_thresholds(self.variance_threshold) @@ -215,7 +237,7 @@ class GMM(Algorithm): relevance_factor=self.relevance_factor, update_means=True, update_variances=False, - **kwargs + **kwargs, ) self.rng = bob.core.random.mt19937(self.init_seed) diff --git a/bob/bio/gmm/algorithm/utils.py b/bob/bio/gmm/algorithm/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..013733daa855e436231ca3588735bb5c85460d87 --- /dev/null +++ b/bob/bio/gmm/algorithm/utils.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python + +import logging + +import dask.array +import numpy + +logger = logging.getLogger(__name__) + + +def delayed_to_xr_dataset(delayed, meta=None): + """Converts one dask.delayed object to a dask.array""" + if meta is None: + meta = numpy.array(delayed.data.compute()) + print(meta.shape) + + da = dask.array.from_delayed(delayed.data, meta.shape, dtype=meta.dtype, name=False) + return da, meta + + +def delayed_samples_to_dask_arrays(delayed_samples, meta=None): + output = [] + for ds in delayed_samples: + da, meta = delayed_to_xr_dataset(ds, meta) + output.append(da) + return output, meta + + +def delayeds_to_xr_dataset(delayeds, meta=None): + 
"""Converts a set of dask.delayed to a list of dask.array""" + output = [] + for d in delayeds: + da, meta = delayed_samples_to_dask_arrays(d, meta) + output.extend(da) + return output + + +def get_variances_and_weights_of_clusters( + data, distances_to_clusters_means, cluster_count +): + """Computes and returns the variances and weights of clustered data. + + Used by GMM to initialize the Mixtures after K-Means. + + Parameters + ---------- + data: 2D dask.array + The data to compute the variance of. + distances_to_clusters_means: 3D dask.array + The distance of each point in data to each cluster's mean. + cluster_count: int + The number of clusters. + + Returns + ------- + variances: 2D dask.array + weights: 1D dask.array + """ + closest_means_indices = distances_to_clusters_means.argmin(axis=1) + weights_count = dask.array.bincount(closest_means_indices, minlength=cluster_count) + weights = weights_count / weights_count.sum() + + # Accumulate + means_sum = dask.array.array( + [data[closest_means_indices == i].sum(axis=0) for i in range(cluster_count)] + ) + variances_sum = dask.array.array( + [ + (data[closest_means_indices == i] ** 2).sum(axis=0) + for i in range(cluster_count) + ] + ) + + # Reduce + means = means_sum / weights_count[:, None] + variances = (variances_sum / weights_count[:, None]) - (means ** 2) + + logger.debug( + f"get_variances_and_weights_of_clusters: var: {variances.compute()}, weights: {weights.compute()}" + ) + return variances, weights