diff --git a/bob/learn/em/cluster/__init__.py b/bob/learn/em/cluster/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee7a216cf26e2fa7f7e11c2996c2805bbae8f432
--- /dev/null
+++ b/bob/learn/em/cluster/__init__.py
@@ -0,0 +1,2 @@
+from .k_means import KMeansMachine
+from .k_means import KMeansTrainer
diff --git a/bob/learn/em/cluster/k_means.py b/bob/learn/em/cluster/k_means.py
new file mode 100644
index 0000000000000000000000000000000000000000..fcf105040abd93cd82e6c1b5e59ecbd5da179e93
--- /dev/null
+++ b/bob/learn/em/cluster/k_means.py
@@ -0,0 +1,349 @@
+#!/usr/bin/env python
+# @author: Yannick Dayer <yannick.dayer@idiap.ch>
+# @date: Tue 27 Jul 2021 11:04:10 UTC+02
+
+import logging
+from typing import Union
+from typing import Tuple
+
+import numpy as np
+import dask.array as da
+from dask_ml.cluster.k_means import k_init
+from sklearn.base import BaseEstimator
+
+logger = logging.getLogger(__name__)
+
+
+class KMeansMachine(BaseEstimator):
+    """Stores the k-means clusters parameters (centroid of each cluster).
+
+    Allows the clustering of data with the ``fit`` method.
+
+    Parameters
+    ----------
+    n_clusters: int
+        The number of represented clusters.
+
+    Attributes
+    ----------
+    centroids_: ndarray of shape (n_clusters, n_features)
+        The current clusters centroids. Available after fitting.
+
+    Example
+    -------
+    >>> data = dask.array.array([[0,-1,0],[-1,1,1],[3,2,1],[2,2,1],[1,0,2]])
+    >>> machine = KMeansMachine(2).fit(data)
+    >>> machine.centroids_.compute()
+    ... array([[0. , 0. , 1. ],
+    ...        [2.5, 2. , 1. ]])
+    """
+
+    def __init__(
+        self,
+        n_clusters: int,
+        convergence_threshold: float = 1e-5,
+        random_state: Union[int, da.random.RandomState] = 0,
+    ) -> None:
+        if n_clusters < 1:
+            raise ValueError("The Number of cluster should be greater thant 0.")
+        self.n_clusters = n_clusters
+        self.random_state = random_state
+        self.convergence_threshold = convergence_threshold
+
+    def get_centroids_distance(self, x: da.Array) -> da.Array:
+        """Returns the distance values between x and each cluster's centroid.
+
+        The returned values are squared Euclidean distances.
+
+        Parameters
+        ----------
+        x: ndarray of shape (n_features,) or (n_samples, n_features)
+            One data point, or a series of data points.
+
+        Returns
+        -------
+        distances: ndarray of shape (n_clusters,) or (n_clusters, n_samples)
+            For each cluster, the squared Euclidian distance (or distances) to x.
+        """
+        return da.sum((self.centroids_[:, None] - x[None, :]) ** 2, axis=-1)
+
+    def get_closest_centroid(self, x: da.Array) -> Tuple[int, float]:
+        """Returns the closest mean's index and squared Euclidian distance to x."""
+        dists = self.get_centroids_distance(x)
+        min_id = da.argmin(dists, axis=0)
+        min_dist = dists[min_id]
+        return min_id, min_dist
+
+    def get_closest_centroid_index(self, x: da.Array) -> da.Array:
+        """Returns the index of the closest cluster mean to x."""
+        return da.argmin(self.get_centroids_distance(x), axis=0)
+
+    def get_min_distance(self, x: da.Array) -> da.Array:
+        """Returns the smallest distance between that point and the clusters centroids.
+
+        For each point in x, the minimum distance to each cluster's mean is returned.
+
+        The returned values are squared Euclidean distances.
+        """
+        return da.min(self.get_centroids_distance(x), axis=0)
+
+    def __eq__(self, obj) -> bool:
+        if hasattr(self, "centroids_") and hasattr(obj, "centroids_"):
+            return da.allclose(self.centroids_, obj.centroids_, rtol=0, atol=0)
+        else:
+            raise ValueError("centroids_ was not set. You should call 'fit' first.")
+
+    def is_similar_to(self, obj, r_epsilon=1e-05, a_epsilon=1e-08) -> bool:
+        if hasattr(self, "centroids_") and hasattr(obj, "centroids_"):
+            return da.allclose(
+                self.centroids_, obj.centroids_, rtol=r_epsilon, atol=a_epsilon
+            )
+        else:
+            raise ValueError("centroids_ was not set. You should call 'fit' first.")
+
+    def get_variances_and_weights_for_each_cluster(self, data: da.Array):
+        """Returns the clusters variance and weight for data clustered by the machine.
+
+        For each cluster, finds the subset of the samples that is closest to that
+        centroid, and calculates:
+        1) the variance of that subset (the cluster variance)
+        2) the proportion of samples represented by that subset (the cluster weight)
+
+        Parameters
+        ----------
+        data: dask.array
+            The data to compute the variance of.
+
+        Returns
+        -------
+        Tuple of arrays:
+            variances: ndarray of shape (n_clusters, n_features)
+                For each cluster, the variance in each dimension of the data.
+            weights: ndarray of shape (n_clusters, )
+                Weight (proportion of quantity of data point) of each cluster.
+        """
+        n_cluster = self.n_clusters
+        closest_centroid_indices = self.get_closest_centroid_index(data)
+        weights_count = da.bincount(closest_centroid_indices, minlength=n_cluster)
+        weights = weights_count / weights_count.sum()
+
+        # Accumulate
+        means_sum = da.sum(
+            da.eye(n_cluster)[closest_centroid_indices][:, :, None] * data[:, None],
+            axis=0,
+        )
+        variances_sum = da.sum(
+            da.eye(n_cluster)[closest_centroid_indices][:, :, None]
+            * (data[:, None] ** 2),
+            axis=0,
+        )
+
+        # Reduce
+        means = means_sum / weights_count[:, None]
+        variances = (variances_sum / weights_count[:, None]) - (means ** 2)
+
+        return variances, weights
+
+    def fit(self, X, y=None, trainer=None):
+        """Fits this machine with a k-means trainer.
+
+        The default trainer (when None is given) uses k-means|| for init, then uses e-m
+        until it converges or the limit number of iterations is reached.
+        """
+        if trainer is None:
+            logger.info("Using default k-means trainer.")
+            trainer = KMeansTrainer(init_method="k-means||", random_state=self.random_state)
+
+        logger.debug(f"Initializing trainer.")
+        trainer.initialize(
+            machine=self,
+            data=X,
+        )
+
+        logger.info("Training k-means.")
+        distance = np.inf
+        for step in range(trainer.max_iter):
+            logger.info(f"Iteration {step:3d}/{trainer.max_iter}")
+            distance_previous = distance
+            trainer.e_step(machine=self, data=X)
+            trainer.m_step(machine=self, data=X)
+
+            distance = trainer.compute_likelihood(self)
+
+            # logger.info(f"Average squared Euclidean distance = {distance.compute()}")
+
+            if step > 0:
+                convergence_value = abs(
+                    (distance_previous - distance) / distance_previous
+                )
+                # logger.info(f"Convergence value = {convergence_value.compute()}")
+
+                # Terminates if converged (and threshold is set)
+                if (
+                    self.convergence_threshold is not None
+                    and convergence_value <= self.convergence_threshold
+                ):
+                    logger.info("Stopping Training: Convergence threshold met.")
+                    return self
+        logger.info("Stopping Training: Iterations limit reached.")
+        return self
+
+    def partial_fit(self, X, y=None, trainer=None):
+        if trainer is None:
+            logger.info("Using default k-means trainer.")
+            trainer = KMeansTrainer(init_method="k-means||")
+        if not hasattr(self, "means_"):
+            logger.debug(f"First call of 'partial_fit'. Initializing trainer.")
+            trainer.initialize(
+                machine=self,
+                data=X,
+            )
+        for step in range(trainer.max_iter):
+            logger.info(f"Iteration = {step:3d}/{trainer.max_iter}")
+            distance_previous = distance
+            trainer.e_step(machine=self, data=X)
+            trainer.m_step(machine=self, data=X)
+
+            distance = trainer.compute_likelihood(self)
+
+            logger.info(f"Average squared Euclidean distance = {distance}")
+
+            convergence_value = abs((distance_previous - distance) / distance_previous)
+            logger.info(f"Convergence value = {convergence_value}")
+
+            # Terminates if converged (and threshold is set)
+            if (
+                self.convergence_threshold is not None
+                and convergence_value <= self.convergence_threshold
+            ):
+                logger.info("Stopping Training: Convergence threshold met.")
+                return self
+        logger.info("Stopping Training: Iterations limit reached.")
+        return self
+
+    def transform(self, X):
+        """Returns all the distances between the data and each cluster's mean.
+
+        Parameters
+        ----------
+        X: ndarray of shape (n_samples, n_features)
+            Series of data points.
+
+        Returns
+        -------
+        distances: ndarray of shape (n_clusters, n_samples)
+            For each mean, for each point, the squared Euclidian distance between them.
+        """
+        return self.get_centroids_distance(X)
+
+    def predict(self, X):
+        """Returns the labels of the closest cluster centroid to the data.
+
+        Parameters
+        ----------
+        X: ndarray of shape (n_samples, n_features)
+            Series of data points.
+
+        Returns
+        -------
+        indices: ndarray of shape (n_samples)
+            The indices of the closest cluster for each data point.
+        """
+        return self.get_closest_centroid_index(X)
+
+
+class KMeansTrainer:
+    """E-M Trainer that applies k-means on a KMeansMachine.
+
+    This trainer works in two phases:
+        - An initialization (setting the initial values of the centroids)
+        - An e-m loop reducing the total distance between the data points and their
+          closest centroid.
+
+    The initialization can use an iterative process to find the best set of
+    coordinates, use random starting points, or take specified coordinates. The
+    ``init_method`` parameter specifies which of these behavior is considered.
+
+    Parameters
+    ----------
+    init_method:
+        One of: "random", "k-means++", or "k-means||", or an array with the wanted
+        starting values of the centroids.
+    init_max_iter:
+        The maximum number of iterations for the initialization part.
+    random_state:
+        A seed or RandomState used for the initialization part.
+    max_iter:
+        The maximum number of iterations for the e-m part.
+    """
+
+    def __init__(
+        self,
+        init_method: Union[str, da.Array] = "k-means||",
+        init_max_iter: Union[int, None] = None,
+        random_state: Union[int, da.random.RandomState] = 0,
+        max_iter: int = 20,
+    ):
+        self.init_method = init_method
+        self.average_min_distance = None
+        self.zeroeth_order_statistics = None
+        self.first_order_statistics = None
+        self.max_iter = max_iter
+        self.init_max_iter = init_max_iter
+        self.random_state = random_state
+
+    def initialize(
+        self,
+        machine: KMeansMachine,
+        data: da.Array,
+    ):
+        """Assigns the means to an initial value using a specified method or randomly."""
+        logger.debug(f"Initializing k-means means with '{self.init_method}'.")
+        data = da.array(data)
+        machine.centroids_ = k_init(
+            X=data,
+            n_clusters=machine.n_clusters,
+            init=self.init_method,
+            random_state=self.random_state,
+            max_iter=self.init_max_iter,
+        )
+
+    def e_step(self, machine: KMeansMachine, data: da.Array):
+        data = da.array(data)
+        closest_centroid_indices = machine.get_closest_centroid_index(data)
+        # Number of data points in each cluster
+        self.zeroeth_order_statistics = da.bincount(
+            closest_centroid_indices, minlength=machine.n_clusters
+        )
+        # Sum of data points coordinates in each cluster
+        self.first_order_statistics = da.sum(
+            da.eye(machine.n_clusters)[closest_centroid_indices][:, :, None]
+            * data[:, None],
+            axis=0,
+        )
+        self.average_min_distance = machine.get_min_distance(data).mean()
+
+    def m_step(self, machine: KMeansMachine, data: da.Array):
+        machine.centroids_ = (
+            self.first_order_statistics / self.zeroeth_order_statistics[:, None]
+        ).persist()
+
+    def compute_likelihood(self, machine: KMeansMachine):
+        if self.average_min_distance is None:
+            logger.error("compute_likelihood should be called after e_step.")
+            return 0
+        return self.average_min_distance
+
+    def copy(self):
+        new_trainer = KMeansTrainer()
+        new_trainer.average_min_distance = self.average_min_distance
+        new_trainer.zeroeth_order_statistics = self.zeroeth_order_statistics
+        new_trainer.first_order_statistics = self.first_order_statistics
+        return new_trainer
+
+    def reset_accumulators(self, machine: KMeansMachine):
+        self.average_min_distance = 0
+        self.zeroeth_order_statistics = da.zeros((machine.n_clusters,), dtype="float64")
+        self.first_order_statistics = da.zeros(
+            (machine.n_clusters, machine.n_dims), dtype="float64"
+        )
diff --git a/bob/learn/em/test/test_kmeans.py b/bob/learn/em/test/test_kmeans.py
index be80ba27a154cd549e048b0504b4c84c06df6a38..0cadba354a4e40f288c124b92f7d953830a189dd 100644
--- a/bob/learn/em/test/test_kmeans.py
+++ b/bob/learn/em/test/test_kmeans.py
@@ -8,89 +8,76 @@
 """Tests the KMeans machine
 """
 
-import os
-import numpy
-import tempfile
+import numpy as np
 
-import bob.io.base
-from bob.learn.em import KMeansMachine
+from bob.learn.em.cluster import KMeansMachine
+from bob.learn.em.cluster import KMeansTrainer
+
+import dask.array as da
 
-def equals(x, y, epsilon):
-  return (abs(x - y) < epsilon)
 
 def test_KMeansMachine():
-  # Test a KMeansMachine
-
-  means = numpy.array([[3, 70, 0], [4, 72, 0]], 'float64')
-  mean  = numpy.array([3,70,1], 'float64')
-
-  # Initializes a KMeansMachine
-  km = KMeansMachine(2,3)
-  km.means = means
-  assert km.shape == (2,3)
-
-  # Sets and gets
-  assert (km.means == means).all()
-  assert (km.get_mean(0) == means[0,:]).all()  
-  assert (km.get_mean(1) == means[1,:]).all()
-  km.set_mean(0, mean)
-  assert (km.get_mean(0) == mean).all()
-
-  # Distance and closest mean
-  eps = 1e-10
-
-  assert equals( km.get_distance_from_mean(mean, 0), 0, eps)
-  assert equals( km.get_distance_from_mean(mean, 1), 6, eps)  
-  
-  (index, dist) = km.get_closest_mean(mean)
-  
-  assert index == 0
-  assert equals( dist, 0, eps)
-  assert equals( km.get_min_distance(mean), 0, eps)
-
-  # Loads and saves
-  filename = str(tempfile.mkstemp(".hdf5")[1])
-  km.save(bob.io.base.HDF5File(filename, 'w'))
-  km_loaded = KMeansMachine(bob.io.base.HDF5File(filename))
-  assert km == km_loaded
-
-  # Resize
-  km.resize(4,5)
-  assert km.shape == (4,5)
-
-  # Copy constructor and comparison operators
-  km.resize(2,3)
-  km2 = KMeansMachine(km)
-  assert km2 == km
-  assert (km2 != km) is False
-  assert km2.is_similar_to(km)
-  means2 = numpy.array([[3, 70, 0], [4, 72, 2]], 'float64')
-  km2.means = means2
-  assert (km2 == km) is False
-  assert km2 != km
-  assert (km2.is_similar_to(km)) is False
-
-  # Clean-up
-  os.unlink(filename)
-  
-  
-def test_KMeansMachine2():
-  kmeans             = bob.learn.em.KMeansMachine(2,2)
-  kmeans.means       = numpy.array([[1.2,1.3],[0.2,-0.3]])
-
-  data               = numpy.array([
-                                  [1.,1],
-                                  [1.2, 3],
-                                  [0,0],
-                                  [0.3,0.2],
-                                  [0.2,0]
-                                 ])
-  variances, weights = kmeans.get_variances_and_weights_for_each_cluster(data)
-
-  variances_result = numpy.array([[ 0.01,1.],
-                                  [ 0.01555556, 0.00888889]])
-  weights_result = numpy.array([ 0.4, 0.6])
-  
-  assert equals(weights_result,weights, 1e-3).all()
-  assert equals(variances_result,variances,1e-3).all()
- 
+    # Test a KMeansMachine
+
+    means = np.array([[3, 70, 0], [4, 72, 0]], "float64")
+    mean = np.array([3, 70, 1], "float64")
+
+    # Initializes a KMeansMachine
+    km = KMeansMachine(2)
+    km.centroids_ = means
+
+    # Distance and closest mean
+    np.testing.assert_almost_equal(km.transform(mean)[0], 1)
+    np.testing.assert_almost_equal(km.transform(mean)[1], 6)
+
+    (index, dist) = km.get_closest_centroid(mean)
+
+    assert index == 0, index
+    np.testing.assert_almost_equal(dist, 1.0)
+    np.testing.assert_almost_equal(km.get_min_distance(mean), 1)
+
+
+def test_KMeansMachine_var_and_weight():
+    kmeans = KMeansMachine(2)
+    kmeans.centroids_ = np.array([[1.2, 1.3], [0.2, -0.3]])
+
+    data = np.array([[1.0, 1], [1.2, 3], [0, 0], [0.3, 0.2], [0.2, 0]])
+    variances, weights = kmeans.get_variances_and_weights_for_each_cluster(data)
+
+    variances_result = np.array([[0.01, 1.0], [0.01555556, 0.00888889]])
+    weights_result = np.array([0.4, 0.6])
+
+    np.testing.assert_almost_equal(variances, variances_result)
+    np.testing.assert_almost_equal(weights, weights_result)
+
+
+def test_kmeans_fit():
+    da.random.seed(0)
+    data1 = da.random.normal(loc=1, size=(2000, 3))
+    data2 = da.random.normal(loc=-1, size=(2000, 3))
+    data = da.concatenate([data1, data2], axis=0)
+    machine = KMeansMachine(2, random_state=0).fit(data)
+    expected = [[1.00426431, 1.00359693, 1.05996704], [-0.99262315, -1.05226141, -1.00525245]]
+    np.testing.assert_almost_equal(machine.centroids_, expected)
+
+
+def test_kmeans_fit_init_pp():
+    da.random.seed(0)
+    data1 = da.random.normal(loc=1, size=(2000, 3))
+    data2 = da.random.normal(loc=-1, size=(2000, 3))
+    data = da.concatenate([data1, data2], axis=0)
+    trainer = KMeansTrainer(init_method="k-means++", random_state=0)
+    machine = KMeansMachine(2).fit(data, trainer=trainer)
+    expected = [[-0.99262315, -1.05226141, -1.00525245], [1.00426431, 1.00359693, 1.05996704]]
+    np.testing.assert_almost_equal(machine.centroids_, expected)
+
+
+def test_kmeans_fit_init_random():
+    da.random.seed(0)
+    data1 = da.random.normal(loc=1, size=(2000, 3))
+    data2 = da.random.normal(loc=-1, size=(2000, 3))
+    data = da.concatenate([data1, data2], axis=0)
+    trainer = KMeansTrainer(init_method="random", random_state=0)
+    machine = KMeansMachine(2).fit(data, trainer=trainer)
+    expected = [[-0.99433738, -1.05561588, -1.01236246], [0.99800688, 0.99873325, 1.05879539]]
+    np.testing.assert_almost_equal(machine.centroids_, expected)
diff --git a/bob/learn/em/test/test_kmeans_trainer.py b/bob/learn/em/test/test_kmeans_trainer.py
deleted file mode 100644
index 537df0e9abc7b52bbec5fc19311de7bf24c1c446..0000000000000000000000000000000000000000
--- a/bob/learn/em/test/test_kmeans_trainer.py
+++ /dev/null
@@ -1,228 +0,0 @@
-#!/usr/bin/env python
-# vim: set fileencoding=utf-8 :
-# Laurent El Shafey <Laurent.El-Shafey@idiap.ch>
-# Fri Jan 18 12:46:00 2013 +0200
-#
-# Copyright (C) 2011-2014 Idiap Research Institute, Martigny, Switzerland
-
-"""Test K-Means algorithm
-"""
-import numpy
-
-import bob.core
-import bob.io
-from bob.io.base.test_utils import datafile
-
-from bob.learn.em import KMeansMachine, KMeansTrainer
-
-
-def equals(x, y, epsilon):
-    return (abs(x - y) < epsilon).all()
-
-
-def kmeans_plus_plus(machine, data, seed):
-    """Python implementation of K-Means++ (initialization)"""
-    n_data = data.shape[0]
-    rng = bob.core.random.mt19937(seed)
-    u = bob.core.random.uniform('int32', 0, n_data - 1)
-    index = u(rng)
-    machine.set_mean(0, data[index, :])
-    weights = numpy.zeros(shape=(n_data,), dtype=numpy.float64)
-
-    for m in range(1, machine.dim_c):
-        for s in range(n_data):
-            s_cur = data[s, :]
-            w_cur = machine.get_distance_from_mean(s_cur, 0)
-            for i in range(m):
-                w_cur = min(machine.get_distance_from_mean(s_cur, i), w_cur)
-            weights[s] = w_cur
-        weights *= weights
-        weights /= numpy.sum(weights)
-        d = bob.core.random.discrete('int32', weights)
-        index = d(rng)
-        machine.set_mean(m, data[index, :])
-
-
-def NormalizeStdArray(path):
-    array = bob.io.base.load(path).astype('float64')
-    std = array.std(axis=0)
-    return (array / std, std)
-
-
-def multiplyVectorsByFactors(matrix, vector):
-    for i in range(0, matrix.shape[0]):
-        for j in range(0, matrix.shape[1]):
-            matrix[i, j] *= vector[j]
-
-
-def flipRows(array):
-    if len(array.shape) == 2:
-        return numpy.array([numpy.array(array[1, :]), numpy.array(array[0, :])], 'float64')
-    elif len(array.shape) == 1:
-        return numpy.array([array[1], array[0]], 'float64')
-    else:
-        raise Exception('Input type not supportd by flipRows')
-
-
-if hasattr(KMeansTrainer, 'KMEANS_PLUS_PLUS'):
-    def test_kmeans_plus_plus():
-        # Tests the K-Means++ initialization
-        dim_c = 5
-        dim_d = 7
-        n_samples = 150
-        data = numpy.random.randn(n_samples, dim_d)
-        seed = 0
-
-        # C++ implementation
-        machine = KMeansMachine(dim_c, dim_d)
-        trainer = KMeansTrainer()
-        trainer.rng = bob.core.random.mt19937(seed)
-        trainer.initialization_method = 'KMEANS_PLUS_PLUS'
-        trainer.initialize(machine, data)
-
-        # Python implementation
-        py_machine = KMeansMachine(dim_c, dim_d)
-        kmeans_plus_plus(py_machine, data, seed)
-        assert equals(machine.means, py_machine.means, 1e-8)
-
-
-def test_kmeans_noduplicate():
-    # Data/dimensions
-    dim_c = 2
-    dim_d = 3
-    seed = 0
-    data = numpy.array([[1, 2, 3], [1, 2, 3], [1, 2, 3], [4, 5, 6.]])
-    # Defines machine and trainer
-    machine = KMeansMachine(dim_c, dim_d)
-    trainer = KMeansTrainer()
-    rng = bob.core.random.mt19937(seed)
-    trainer.initialization_method = 'RANDOM_NO_DUPLICATE'
-    trainer.initialize(machine, data, rng)
-    # Makes sure that the two initial mean vectors selected are different
-    assert equals(machine.get_mean(0), machine.get_mean(1), 1e-8) == False
-
-
-def test_kmeans_a():
-    # Trains a KMeansMachine
-    # This files contains draws from two 1D Gaussian distributions:
-    #   * 100 samples from N(-10,1)
-    #   * 100 samples from N(10,1)
-    data = bob.io.base.load(datafile("samplesFrom2G_f64.hdf5", __name__, path="../data/"))
-
-    machine = KMeansMachine(2, 1)
-
-    trainer = KMeansTrainer()
-    # trainer.train(machine, data)
-    bob.learn.em.train(trainer, machine, data)
-
-    [variances, weights] = machine.get_variances_and_weights_for_each_cluster(data)
-    variances_b = numpy.ndarray(shape=(2, 1), dtype=numpy.float64)
-    weights_b = numpy.ndarray(shape=(2,), dtype=numpy.float64)
-    machine.__get_variances_and_weights_for_each_cluster_init__(variances_b, weights_b)
-    machine.__get_variances_and_weights_for_each_cluster_acc__(data, variances_b, weights_b)
-    machine.__get_variances_and_weights_for_each_cluster_fin__(variances_b, weights_b)
-    m1 = machine.get_mean(0)
-    m2 = machine.get_mean(1)
-
-    ## Check means [-10,10] / variances [1,1] / weights [0.5,0.5]
-    if (m1 < m2):
-        means = numpy.array(([m1[0], m2[0]]), 'float64')
-    else:
-        means = numpy.array(([m2[0], m1[0]]), 'float64')
-    assert equals(means, numpy.array([-10., 10.]), 2e-1)
-    assert equals(variances, numpy.array([1., 1.]), 2e-1)
-    assert equals(weights, numpy.array([0.5, 0.5]), 1e-3)
-
-    assert equals(variances, variances_b, 1e-8)
-    assert equals(weights, weights_b, 1e-8)
-
-
-def test_kmeans_b():
-    # Trains a KMeansMachine
-    (arStd, std) = NormalizeStdArray(datafile("faithful.torch3.hdf5", __name__, path="../data/"))
-
-    machine = KMeansMachine(2, 2)
-
-    trainer = KMeansTrainer()
-    # trainer.seed = 1337
-    bob.learn.em.train(trainer, machine, arStd, convergence_threshold=0.001)
-
-    [variances, weights] = machine.get_variances_and_weights_for_each_cluster(arStd)
-
-    means = numpy.array(machine.means)
-    variances = numpy.array(variances)
-
-    multiplyVectorsByFactors(means, std)
-    multiplyVectorsByFactors(variances, std ** 2)
-
-    gmmWeights = bob.io.base.load(datafile('gmm.init_weights.hdf5', __name__, path="../data/"))
-    gmmMeans = bob.io.base.load(datafile('gmm.init_means.hdf5', __name__, path="../data/"))
-    gmmVariances = bob.io.base.load(datafile('gmm.init_variances.hdf5', __name__, path="../data/"))
-
-    if (means[0, 0] < means[1, 0]):
-        means = flipRows(means)
-        variances = flipRows(variances)
-        weights = flipRows(weights)
-
-    assert equals(means, gmmMeans, 1e-3)
-    assert equals(weights, gmmWeights, 1e-3)
-    assert equals(variances, gmmVariances, 1e-3)
-
-    # Check that there is no duplicate means during initialization
-    machine = KMeansMachine(2, 1)
-    trainer = KMeansTrainer()
-    trainer.initialization_method = 'RANDOM_NO_DUPLICATE'
-    data = numpy.array([[1.], [1.], [1.], [1.], [1.], [1.], [2.], [3.]])
-    bob.learn.em.train(trainer, machine, data)
-    assert (numpy.isnan(machine.means).any()) == False
-
-
-def test_kmeans_parallel():
-    # Trains a KMeansMachine
-    (arStd, std) = NormalizeStdArray(datafile("faithful.torch3.hdf5", __name__, path="../data/"))
-
-    machine = KMeansMachine(2, 2)
-
-    trainer = KMeansTrainer()
-    # trainer.seed = 1337
-    
-    import multiprocessing.pool
-    pool = multiprocessing.pool.ThreadPool(3)
-    bob.learn.em.train(trainer, machine, arStd, convergence_threshold=0.001, pool = pool)
-
-    [variances, weights] = machine.get_variances_and_weights_for_each_cluster(arStd)
-
-    means = numpy.array(machine.means)
-    variances = numpy.array(variances)
-
-    multiplyVectorsByFactors(means, std)
-    multiplyVectorsByFactors(variances, std ** 2)
-
-    gmmWeights = bob.io.base.load(datafile('gmm.init_weights.hdf5', __name__, path="../data/"))
-    gmmMeans = bob.io.base.load(datafile('gmm.init_means.hdf5', __name__, path="../data/"))
-    gmmVariances = bob.io.base.load(datafile('gmm.init_variances.hdf5', __name__, path="../data/"))
-
-    if (means[0, 0] < means[1, 0]):
-        means = flipRows(means)
-        variances = flipRows(variances)
-        weights = flipRows(weights)
-
-    assert equals(means, gmmMeans, 1e-3)
-    assert equals(weights, gmmWeights, 1e-3)
-    assert equals(variances, gmmVariances, 1e-3)
-
-
-def test_trainer_execption():
-    from nose.tools import assert_raises
-
-    # Testing Inf
-    machine = KMeansMachine(2, 2)
-    data = numpy.array([[1.0, 2.0], [2, 3.], [1, 1.], [2, 5.], [numpy.inf, 1.0]])
-    trainer = KMeansTrainer()
-    assert_raises(ValueError, bob.learn.em.train, trainer, machine, data, 10)
-
-    # Testing Nan
-    machine = KMeansMachine(2, 2)
-    data = numpy.array([[1.0, 2.0], [2, 3.], [1, numpy.nan], [2, 5.], [2.0, 1.0]])
-    trainer = KMeansTrainer()
-    assert_raises(ValueError, bob.learn.em.train, trainer, machine, data, 10)
diff --git a/conda/meta.yaml b/conda/meta.yaml
index 125d5691ef595927743892b3a4bff15a1c947245..eb5cdbab872ca7f7ed1173c78b3e7fcfb4f67abe 100644
--- a/conda/meta.yaml
+++ b/conda/meta.yaml
@@ -37,11 +37,15 @@ requirements:
     - libblitz {{ libblitz }}
     - boost {{ boost }}
     - numpy {{ numpy }}
+    - dask {{ dask }}
+    - dask-ml {{ dask_ml }}
   run:
     - python
     - setuptools
     - boost
     - {{ pin_compatible('numpy') }}
+    - {{ pin_compatible('dask') }}
+    - {{ pin_compatible('dask-ml') }}
 
 test:
   imports:
diff --git a/doc/plot/plot_kmeans.py b/doc/plot/plot_kmeans.py
index b56f42999e7b2b9c2515b747905b709a77258fde..e35fedbc9469f8cf0377d794b8ad325037636f15 100644
--- a/doc/plot/plot_kmeans.py
+++ b/doc/plot/plot_kmeans.py
@@ -1,4 +1,5 @@
-import bob.learn.em
+from bob.learn.em.cluster import KMeansMachine
+from bob.learn.em.cluster import KMeansTrainer
 import bob.db.iris
 import numpy
 import matplotlib.pyplot as plt
@@ -14,11 +15,12 @@ virginica = numpy.column_stack(
 data = numpy.vstack((setosa, versicolor, virginica))
 
 # Training KMeans
-# Two clusters with a feature dimensionality of 3
-machine = bob.learn.em.KMeansMachine(3, 2)
-trainer = bob.learn.em.KMeansTrainer()
-bob.learn.em.train(trainer, machine, data, max_iterations=200,
-                   convergence_threshold=1e-5)  # Train the KMeansMachine
+# 3 clusters with a feature dimensionality of 2
+machine = KMeansMachine(n_clusters=3)
+trainer = KMeansTrainer(init_method="k-means++")
+machine.fit(data, trainer=trainer)
+
+predictions = machine.predict(data)
 
 # Plotting
 figure, ax = plt.subplots()
@@ -28,8 +30,8 @@ plt.scatter(versicolor[:, 0],
             versicolor[:, 1], c="goldenrod", label="versicolor")
 plt.scatter(virginica[:, 0],
             virginica[:, 1], c="dimgrey", label="virginica")
-plt.scatter(machine.means[:, 0],
-            machine.means[:, 1], c="blue", marker="x", label="centroids",
+plt.scatter(machine.centroids_[:, 0],
+            machine.centroids_[:, 1], c="blue", marker="x", label="centroids",
             s=60)
 plt.legend()
 plt.xticks([], [])
diff --git a/requirements.txt b/requirements.txt
index 77918d31c7c8099fd269eec2e61a9268d0e40391..ddb7e6e216a56cd8784e952a1b15fce46553bedd 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,4 +6,4 @@ bob.io.base
 bob.sp
 bob.math > 2
 bob.learn.activation
-bob.learn.linear
\ No newline at end of file
+bob.learn.linear