diff --git a/bob/learn/em/gmm.py b/bob/learn/em/gmm.py
index c2ebe217fc5884a47e5be9c4b837b4f2ef12d411..88efe81af0e7a5918cfc3f604adf3fc906aec83d 100644
--- a/bob/learn/em/gmm.py
+++ b/bob/learn/em/gmm.py
@@ -5,23 +5,123 @@
 """This module provides classes and functions for the training and usage of GMM."""
 
 import copy
+import functools
 import logging
+import operator
 
 from typing import Union
 
+import dask
 import dask.array as da
 import numpy as np
 
 from h5py import File as HDF5File
 from sklearn.base import BaseEstimator
 
-from .k_means import KMeansMachine
+from .k_means import (
+    KMeansMachine,
+    array_to_delayed_list,
+    check_and_persist_dask_input,
+)
 
 logger = logging.getLogger(__name__)
 
 EPSILON = np.finfo(float).eps
 
 
+def logaddexp_reduce(array, axis=0, keepdims=False):
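+    # Reduction helper used for both the ``chunk`` and ``aggregate`` steps of
+    # ``da.reduction`` below: it sums log-probabilities along ``axis`` with
+    # ``np.logaddexp`` to avoid underflow.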
+    return np.logaddexp.reduce(
+        array, axis=axis, keepdims=keepdims, initial=-np.inf
+    )
+
+
+def e_step(data, weights, means, variances, g_norms, log_weights):
+    # Ensure data is a series of samples (2D array)
+    data = np.atleast_2d(data)
+
+    n_gaussians = len(weights)
+
+    # Statistics accumulator for this chunk of data (always starts from zero)
+    statistics = GMMStats(n_gaussians, data.shape[-1])
+
+    # Log weighted Gaussian likelihoods [array of shape (n_gaussians, n_samples)]
+    z = np.empty_like(data, shape=(n_gaussians, len(data)))
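+    # Squared "distance" term of each diagonal-covariance gaussian's log-pdf:
+    # sum over the features of (x - mean)^2 / variance, per gaussian and sample.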
+    for i in range(n_gaussians):
+        z[i] = np.sum((data - means[i]) ** 2 / variances[i], axis=-1)
+    ll = -0.5 * (g_norms[:, None] + z)
+    log_weighted_likelihoods = log_weights[:, None] + ll
+
+    # Log likelihood [array of shape (n_samples,)]
+    if isinstance(log_weighted_likelihoods, np.ndarray):
+        log_likelihood = logaddexp_reduce(log_weighted_likelihoods)
+    else:
+        # Sum along gaussians axis (using logAddExp to prevent underflow)
+        log_likelihood = da.reduction(
+            x=log_weighted_likelihoods,
+            chunk=logaddexp_reduce,
+            aggregate=logaddexp_reduce,
+            axis=0,
+            dtype=float,
+            keepdims=False,
+        )
+
+    # Responsibility P [array of shape (n_gaussians, n_samples)]
+    responsibility = np.exp(log_weighted_likelihoods - log_likelihood[None, :])
+
+    # Accumulate
+
+    # Total likelihood [float]
+    statistics.log_likelihood += log_likelihood.sum()
+    # Count of samples [int]
+    statistics.t += data.shape[0]
+    # Responsibilities [array of shape (n_gaussians,)]
+    statistics.n = statistics.n + responsibility.sum(axis=-1)
+    for i in range(n_gaussians):
+        # p * x for gaussian i [array of shape (n_samples, n_features)]
+        px = responsibility[i, :, None] * data
+        # First order stats [array of shape (n_gaussians, n_features)]
+        statistics.sum_px[i] = statistics.sum_px[i] + np.sum(px, axis=0)
+        # Second order stats [array of shape (n_gaussians, n_features)]
+        statistics.sum_pxx[i] = statistics.sum_pxx[i] + np.sum(
+            px * data, axis=0
+        )
+
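+    # Vectorized alternative to the loop above, kept for reference; it builds
+    # the full (n_gaussians, n_samples, n_features) intermediate array and so
+    # uses more memory: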
+    # px = np.multiply(responsibility[:, :, None], data[None, :, :])
+    # statistics.sum_px = statistics.sum_px + px.sum(axis=1)
+    # pxx = np.multiply(px[:, :, :], data[None, :, :])
+    # statistics.sum_pxx = statistics.sum_pxx + pxx.sum(axis=1)
+
+    return statistics
+
+
+def m_step(
+    machine,
+    statistics,
+    update_means,
+    update_variances,
+    update_weights,
+    mean_var_update_threshold,
+    map_relevance_factor,
+    map_alpha,
+    trainer,
+):
+    m_step_func = map_gmm_m_step if trainer == "map" else ml_gmm_m_step
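+    # Accumulate the per-chunk statistics into a single GMMStats object via
+    # in-place addition.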
+    statistics = functools.reduce(operator.iadd, statistics)
+    m_step_func(
+        machine,
+        statistics=statistics,
+        update_means=update_means,
+        update_variances=update_variances,
+        update_weights=update_weights,
+        mean_var_update_threshold=mean_var_update_threshold,
+        reynolds_adaptation=map_relevance_factor is not None,
+        alpha=map_alpha,
+        relevance_factor=map_relevance_factor,
+    )
+    average_output = float(statistics.log_likelihood / statistics.t)
+    return machine, average_output
+
+
 class GMMStats:
     """Stores accumulated statistics of a GMM.
 
@@ -403,9 +503,7 @@ class GMMMachine(BaseEstimator):
         self._variances = np.maximum(self.variance_thresholds, variances)
         # Recompute g_norm for each gaussian [array of shape (n_gaussians,)]
         n_log_2pi = self._variances.shape[-1] * np.log(2 * np.pi)
-        self._g_norms = np.array(
-            n_log_2pi + np.log(self._variances).sum(axis=-1)
-        )
+        self._g_norms = n_log_2pi + np.log(self._variances).sum(axis=-1)
 
     @property
     def variance_thresholds(self):
@@ -580,9 +678,8 @@ class GMMMachine(BaseEstimator):
             ) = kmeans_machine.get_variances_and_weights_for_each_cluster(data)
 
             # Set the GMM machine's gaussians with the results of k-means
-            self.means = np.array(copy.deepcopy(kmeans_machine.centroids_))
-            self.variances = np.array(copy.deepcopy(variances))
-            self.weights = np.array(copy.deepcopy(weights))
+            self.means = copy.deepcopy(kmeans_machine.centroids_)
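+            # dask.compute is a pass-through for numpy inputs and materializes
+            # the values when k-means ran on dask arrays.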
+            self.variances, self.weights = dask.compute(variances, weights)
 
     def log_weighted_likelihood(
         self,
@@ -733,8 +830,11 @@ class GMMMachine(BaseEstimator):
             **kwargs,
         )
 
-    def fit(self, X, y=None, **kwargs):
+    def fit(self, X, y=None):
         """Trains the GMM on data until convergence or maximum step is reached."""
+
+        input_is_dask, X = check_and_persist_dask_input(X)
+
         if self._means is None:
             self.initialize_gaussians(X)
         else:
@@ -746,6 +846,19 @@ class GMMMachine(BaseEstimator):
             )
             self.variances = np.ones_like(self.means)
 
+        m_step_func = functools.partial(
+            m_step,
+            update_means=self.update_means,
+            update_variances=self.update_variances,
+            update_weights=self.update_weights,
+            mean_var_update_threshold=self.mean_var_update_threshold,
+            map_relevance_factor=self.map_relevance_factor,
+            map_alpha=self.map_alpha,
+            trainer=self.trainer,
+        )
+
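+        # When the input is a dask array, turn it into a list of delayed chunks
+        # so that each chunk can be processed by e_step in parallel.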
+        X = array_to_delayed_list(X, input_is_dask)
+
         average_output = 0
         logger.info("Training GMM...")
         step = 0
@@ -761,23 +874,39 @@ class GMMMachine(BaseEstimator):
             )
 
             average_output_previous = average_output
-            stats = self.e_step(X)
-            self.m_step(
-                stats=stats,
-            )
 
-            # if we're running in dask, persist weights, means, and variances so
-            # we don't recompute each step.
-            for attr in ["weights", "means", "variances"]:
-                arr = getattr(self, attr)
-                if isinstance(arr, da.Array):
-                    setattr(self, attr, arr.persist())
+            # compute the e-m steps
+            if input_is_dask:
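+                # Build one delayed e_step per chunk and a single delayed m_step
+                # that consumes them all, then compute the whole graph at once.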
+                stats = [
+                    dask.delayed(e_step)(
+                        data=xx,
+                        weights=self.weights,
+                        means=self.means,
+                        variances=self.variances,
+                        g_norms=self.g_norms,
+                        log_weights=self.log_weights,
+                    )
+                    for xx in X
+                ]
+                new_machine, average_output = dask.compute(
+                    dask.delayed(m_step_func)(self, stats)
+                )[0]
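+                # Copy the updated parameters back onto this machine.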
+                for attr in ["weights", "means", "variances"]:
+                    setattr(self, attr, getattr(new_machine, attr))
+            else:
+                stats = [
+                    e_step(
+                        data=X,
+                        weights=self.weights,
+                        means=self.means,
+                        variances=self.variances,
+                        g_norms=self.g_norms,
+                        log_weights=self.log_weights,
+                    )
+                ]
+                _, average_output = m_step_func(self, stats)
 
-            # Note: Uses the stats from before m_step, leading to an additional m_step
-            # (which is not bad because it will always converge)
-            average_output = float(stats.log_likelihood / stats.t)
             logger.debug(f"log likelihood = {average_output}")
-
             if step > 1:
                 convergence_value = abs(
                     (average_output_previous - average_output)
@@ -794,6 +923,7 @@ class GMMMachine(BaseEstimator):
                         "Reached convergence threshold. Training stopped."
                     )
                     break
+
         else:
             logger.info(
                 "Reached maximum step. Training stopped without convergence."
diff --git a/bob/learn/em/k_means.py b/bob/learn/em/k_means.py
index 3b73a610fd1beee0fbbd2f1a6844818d4965caa7..52f025676e481b51714a6c5c3aca62a145e6ab4b 100644
--- a/bob/learn/em/k_means.py
+++ b/bob/learn/em/k_means.py
@@ -4,10 +4,13 @@
 
 import logging
 
-from typing import Tuple, Union
+from typing import Union
 
+import dask
 import dask.array as da
+import dask.delayed
 import numpy as np
+import scipy.spatial.distance
 
 from dask_ml.cluster.k_means import k_init
 from sklearn.base import BaseEstimator
@@ -15,6 +18,125 @@ from sklearn.base import BaseEstimator
 logger = logging.getLogger(__name__)
 
 
+def get_centroids_distance(x: np.ndarray, means: np.ndarray) -> np.ndarray:
+    """Returns the distance values between x and each cluster's centroid.
+
+    The returned values are squared Euclidean distances.
+
+    Parameters
+    ----------
+    x: ndarray of shape (n_samples, n_features)
+        A series of data points.
+    means: ndarray of shape (n_clusters, n_features)
+        The centroids.
+
+    Returns
+    -------
+    distances: ndarray of shape (n_clusters, n_samples)
+        For each cluster, the squared Euclidean distance (or distances) to x.
+    """
+    x = np.atleast_2d(x)
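+    # Broadcasting keeps the computation lazy for dask arrays; for in-memory
+    # numpy arrays, scipy's cdist avoids the large broadcast intermediate.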
+    if isinstance(x, da.Array):
+        return np.sum((means[:, None] - x[None, :]) ** 2, axis=-1)
+    else:
+        return scipy.spatial.distance.cdist(means, x, metric="sqeuclidean")
+
+
+def get_closest_centroid_index(centroids_dist: np.ndarray) -> np.ndarray:
+    """Returns the index of the closest cluster mean to x."""
+    return np.argmin(centroids_dist, axis=0)
+
+
+def e_step(data, means):
+    """Computes the zero-th and first order statistics and average min distance
+    for each data point.
+
+    Parameters
+    ----------
+    data : array-like, shape (n_samples, n_features)
+        The data.
+
+    means : array-like, shape (n_clusters, n_features)
+        The cluster centers.
+
+
+    Returns
+    -------
+    zeroeth_order_statistics : array-like, shape (n_clusters,)
+        The zero-th order statistics (number of samples assigned to each cluster).
+    first_order_statistics : array-like, shape (n_clusters, n_features)
+        The first order statistics (sum of the samples assigned to each cluster).
+    avg_min_dist : float
+        The average of the squared distances of each sample to its closest centroid.
+    """
+    n_clusters = len(means)
+    distances = get_centroids_distance(data, means)
+    closest_k_indices = get_closest_centroid_index(distances)
+    zeroeth_order_statistics = np.bincount(
+        closest_k_indices, minlength=n_clusters
+    )
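+    # Sum of the samples assigned to each cluster, computed by one-hot encoding
+    # the closest-centroid indices and broadcasting against the data.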
+    first_order_statistics = np.sum(
+        np.eye(n_clusters)[closest_k_indices][:, :, None] * data[:, None],
+        axis=0,
+    )
+    min_distance = np.min(distances, axis=0)
+    average_min_distance = min_distance.mean()
+    return (
+        zeroeth_order_statistics,
+        first_order_statistics,
+        average_min_distance,
+    )
+
+
+def m_step(stats, n_samples):
+    """Computes the cluster centers and average minimum distance.
+
+    Parameters
+    ----------
+    stats : list
+        A list of the results of the :any:`e_step` function applied to each
+        chunk of data.
+    n_samples : int
+        The total number of samples.
+
+    Returns
+    -------
+    means : array-like, shape (n_clusters, n_features)
+        The cluster centers.
+    avg_min_dist : float
+        The average of the minimum squared distances of the samples to their
+        closest centroid.
+    """
+    (
+        zeroeth_order_statistics,
+        first_order_statistics,
+        average_min_distance,
+    ) = (0, 0, 0)
+    for zeroeth_, first_, average_ in stats:
+        zeroeth_order_statistics += zeroeth_
+        first_order_statistics += first_
+        # Weight each chunk's average by its sample count (the sum of its
+        # zero-th order statistics) so that the division by the total number of
+        # samples below yields the overall average minimum distance.
+        average_min_distance += average_ * zeroeth_.sum()
+    average_min_distance /= n_samples
+
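+    # New centroids: per-cluster sum of samples divided by per-cluster counts.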
+    means = first_order_statistics / zeroeth_order_statistics[:, None]
+    return means, average_min_distance
+
+
+def check_and_persist_dask_input(data):
+    # Check if the input is a dask array. If so, persist it so that the graph
+    # is not recomputed from scratch at every e-m step, and hand the persisted
+    # collection back to the caller (otherwise the persisted results are lost).
+    input_is_dask = False
+    if isinstance(data, da.Array):
+        data = data.persist()
+        input_is_dask = True
+    return input_is_dask, data
+
+
+def array_to_delayed_list(data, input_is_dask):
+    # If input is a dask array, convert to delayed chunks
+    if input_is_dask:
+        data = data.to_delayed().ravel().tolist()
+        logger.debug(f"Got {len(data)} chunks.")
+    return data
+
+
 class KMeansMachine(BaseEstimator):
     """Stores the k-means clusters parameters (centroid of each cluster).
 
@@ -84,43 +206,6 @@ class KMeansMachine(BaseEstimator):
     def means(self, value: np.ndarray):
         self.centroids_ = value
 
-    def get_centroids_distance(self, x: np.ndarray) -> np.ndarray:
-        """Returns the distance values between x and each cluster's centroid.
-
-        The returned values are squared Euclidean distances.
-
-        Parameters
-        ----------
-        x: ndarray of shape (n_features,) or (n_samples, n_features)
-            One data point, or a series of data points.
-
-        Returns
-        -------
-        distances: ndarray of shape (n_clusters,) or (n_clusters, n_samples)
-            For each cluster, the squared Euclidian distance (or distances) to x.
-        """
-        return np.sum((self.centroids_[:, None] - x[None, :]) ** 2, axis=-1)
-
-    def get_closest_centroid(self, x: np.ndarray) -> Tuple[int, float]:
-        """Returns the closest mean's index and squared Euclidian distance to x."""
-        dists = self.get_centroids_distance(x)
-        min_id = np.argmin(dists, axis=0)
-        min_dist = dists[min_id]
-        return min_id, min_dist
-
-    def get_closest_centroid_index(self, x: np.ndarray) -> np.ndarray:
-        """Returns the index of the closest cluster mean to x."""
-        return np.argmin(self.get_centroids_distance(x), axis=0)
-
-    def get_min_distance(self, x: np.ndarray) -> np.ndarray:
-        """Returns the smallest distance between that point and the clusters centroids.
-
-        For each point in x, the minimum distance to each cluster's mean is returned.
-
-        The returned values are squared Euclidean distances.
-        """
-        return np.min(self.get_centroids_distance(x), axis=0)
-
     def __eq__(self, obj) -> bool:
         return self.is_similar_to(obj, r_epsilon=0, a_epsilon=0)
 
@@ -157,7 +242,8 @@ class KMeansMachine(BaseEstimator):
                 Weight (proportion of quantity of data point) of each cluster.
         """
         n_cluster = self.n_clusters
-        closest_centroid_indices = self.get_closest_centroid_index(data)
+        dist = get_centroids_distance(data, self.centroids_)
+        closest_centroid_indices = get_closest_centroid_index(dist)
         weights_count = np.bincount(
             closest_centroid_indices, minlength=n_cluster
         )
@@ -198,30 +284,18 @@ class KMeansMachine(BaseEstimator):
             oversampling_factor=self.oversampling_factor,
         )
 
-    def e_step(self, data: np.ndarray):
-        closest_k_indices = self.get_closest_centroid_index(data)
-        # Number of data points in each cluster
-        self.zeroeth_order_statistics = np.bincount(
-            closest_k_indices, minlength=self.n_clusters
-        )
-        # Sum of data points coordinates in each cluster
-        self.first_order_statistics = np.sum(
-            np.eye(self.n_clusters)[closest_k_indices][:, :, None]
-            * data[:, None],
-            axis=0,
-        )
-        self.average_min_distance = self.get_min_distance(data).mean()
-
-    def m_step(self, data: np.ndarray):
-        self.centroids_ = (
-            self.first_order_statistics / self.zeroeth_order_statistics[:, None]
-        )
-
     def fit(self, X, y=None):
         """Fits this machine on data samples."""
+
+        input_is_dask, X = check_and_persist_dask_input(X)
+
         logger.debug("Initializing trainer.")
         self.initialize(data=X)
 
+        n_samples = len(X)
+
+        X = array_to_delayed_list(X, input_is_dask)
+
         logger.info("Training k-means.")
         distance = np.inf
         step = 0
@@ -232,17 +306,22 @@ class KMeansMachine(BaseEstimator):
                 + (f"/{self.max_iter:3d}" if self.max_iter else "")
             )
             distance_previous = distance
-            self.e_step(data=X)
-            self.m_step(data=X)
 
-            # If we're running in dask, persist the centroids so we don't recompute them
-            # from the start of the graph at every step.
-            for attr in ("centroids_",):
-                arr = getattr(self, attr)
-                if isinstance(arr, da.Array):
-                    setattr(self, attr, arr.persist())
+            # compute the e-m steps
+            if input_is_dask:
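+                # One delayed e_step per chunk, reduced by a single delayed
+                # m_step and computed in one pass.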
+                stats = [
+                    dask.delayed(e_step)(xx, means=self.centroids_) for xx in X
+                ]
+                self.centroids_, self.average_min_distance = dask.compute(
+                    dask.delayed(m_step)(stats, n_samples)
+                )[0]
+            else:
+                stats = [e_step(X, means=self.centroids_)]
+                self.centroids_, self.average_min_distance = m_step(
+                    stats, n_samples
+                )
 
-            distance = float(self.average_min_distance)
+            distance = self.average_min_distance
 
             logger.debug(
                 f"Average minimal squared Euclidean distance = {distance}"
@@ -263,22 +342,11 @@ class KMeansMachine(BaseEstimator):
                         "Reached convergence threshold. Training stopped."
                     )
                     break
+
         else:
             logger.info(
                 "Reached maximum step. Training stopped without convergence."
             )
-        self.compute()
-        return self
-
-    def partial_fit(self, X, y=None):
-        """Applies one e-m step of k-means on the data."""
-        if self.centroids_ is None:
-            logger.debug("First call to 'partial_fit'. Initializing...")
-            self.initialize(data=X)
-
-        self.e_step(data=X)
-        self.m_step(data=X)
-
         return self
 
     def transform(self, X):
@@ -294,7 +362,7 @@ class KMeansMachine(BaseEstimator):
         distances: ndarray of shape (n_clusters, n_samples)
             For each mean, for each point, the squared Euclidian distance between them.
         """
-        return self.get_centroids_distance(X)
+        return get_centroids_distance(X, self.centroids_)
 
     def predict(self, X):
         """Returns the labels of the closest cluster centroid to the data.
@@ -309,9 +377,6 @@ class KMeansMachine(BaseEstimator):
         indices: ndarray of shape (n_samples)
             The indices of the closest cluster for each data point.
         """
-        return self.get_closest_centroid_index(X)
-
-    def compute(self, *args, **kwargs):
-        """Computes delayed arrays if needed."""
-        for name in ("centroids_",):
-            setattr(self, name, np.asarray(getattr(self, name)))
+        return get_closest_centroid_index(
+            get_centroids_distance(X, self.centroids_)
+        )
diff --git a/bob/learn/em/test/test_gmm.py b/bob/learn/em/test/test_gmm.py
index 7eda1c63df3a7827ce7631eb6a0c78ea0971f5f7..3ad780a32be84b20d27b2cb14f79001309568e7d 100644
--- a/bob/learn/em/test/test_gmm.py
+++ b/bob/learn/em/test/test_gmm.py
@@ -16,11 +16,14 @@ from copy import deepcopy
 import dask.array as da
 import numpy as np
 
+from dask.distributed import Client
 from h5py import File as HDF5File
 from pkg_resources import resource_filename
 
 from bob.learn.em import GMMMachine, GMMStats, KMeansMachine
 
+from .test_kmeans import to_dask_array, to_numpy
+
 
 def load_array(filename):
     with HDF5File(filename, "r") as f:
@@ -464,13 +467,22 @@ def test_gmm_kmeans_parallel_init():
     data = np.array(
         [[1.5, 1], [1, 1.5], [-1, 0.5], [-1.5, 0], [2, 2], [2.5, 2.5]]
     )
-    machine = machine.fit(data)
-    expected_means = np.array([[1.25, 1.25], [-1.25, 0.25], [2.25, 2.25]])
-    expected_variances = np.array(
-        [[1 / 16, 1 / 16], [1 / 16, 1 / 16], [1 / 16, 1 / 16]]
-    )
-    np.testing.assert_almost_equal(machine.means, expected_means, decimal=3)
-    np.testing.assert_almost_equal(machine.variances, expected_variances)
+    with Client() as client, client.as_current():
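+        # Fit on plain numpy data and on a chunked dask array, the latter
+        # scheduled through the local distributed client.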
+        for transform in (to_numpy, to_dask_array):
+            data = transform(data)
+            machine = machine.fit(data)
+            expected_means = np.array(
+                [[1.25, 1.25], [-1.25, 0.25], [2.25, 2.25]]
+            )
+            expected_variances = np.array(
+                [[1 / 16, 1 / 16], [1 / 16, 1 / 16], [1 / 16, 1 / 16]]
+            )
+            np.testing.assert_almost_equal(
+                machine.means, expected_means, decimal=3
+            )
+            np.testing.assert_almost_equal(
+                machine.variances, expected_variances
+            )
 
 
 def test_likelihood():
diff --git a/bob/learn/em/test/test_kmeans.py b/bob/learn/em/test/test_kmeans.py
index df066c89c945c18c2a860e20f9c441d2236e873c..1400b469e909973e3a7911c7e7439a55b7d87a47 100644
--- a/bob/learn/em/test/test_kmeans.py
+++ b/bob/learn/em/test/test_kmeans.py
@@ -14,8 +14,9 @@ import copy
 
 import dask.array as da
 import numpy as np
+import scipy.spatial.distance
 
-from bob.learn.em import KMeansMachine
+from bob.learn.em import KMeansMachine, k_means
 
 
 def to_numpy(*args):
@@ -30,7 +31,10 @@ def to_numpy(*args):
 def to_dask_array(*args):
     result = []
     for x in args:
-        result.append(da.from_array(np.array(x)))
+        x = np.asarray(x)
+        chunks = list(x.shape)
+        chunks[0] //= 2
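+        # Split along the sample axis so the dask code path is exercised with
+        # more than one chunk.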
+        result.append(da.from_array(x, chunks=chunks))
     if len(result) == 1:
         return result[0]
     return result
@@ -54,23 +58,12 @@ def test_KMeansMachine():
         np.testing.assert_equal(km.transform(test_val)[0], np.array([1]))
         np.testing.assert_equal(km.transform(test_val)[1], np.array([6]))
 
-        (index, dist) = km.get_closest_centroid(test_val)
-        assert index == 0
-        np.testing.assert_equal(dist, np.array([[1.0]]))
-
-        (indices, dists) = km.get_closest_centroid(test_arr)
-        np.testing.assert_equal(indices, np.array([0, 1]))
-        np.testing.assert_equal(dists, np.array([[1, 8], [6, 1]]))
-
         index = km.predict(test_val)
         assert index == 0
 
         indices = km.predict(test_arr)
         np.testing.assert_equal(indices, np.array([0, 1]))
 
-        np.testing.assert_equal(km.get_min_distance(test_val), np.array([1]))
-        np.testing.assert_equal(km.get_min_distance(test_arr), np.array([1, 1]))
-
         # Check __eq__ and is_similar_to
         km2 = KMeansMachine(2)
         assert km != km2
@@ -124,25 +117,6 @@ def test_kmeans_fit():
         machine.fit(data)
 
 
-def test_kmeans_fit_partial():
-    np.random.seed(0)
-    data1 = np.random.normal(loc=1, size=(2000, 3))
-    data2 = np.random.normal(loc=-1, size=(2000, 3))
-    data = np.concatenate([data1, data2], axis=0)
-
-    for transform in (to_numpy, to_dask_array):
-        data = transform(data)
-        machine = KMeansMachine(2, random_state=0)
-        for _ in range(20):
-            machine.partial_fit(data)
-        centroids = machine.centroids_[np.argsort(machine.centroids_[:, 0])]
-        expected = [
-            [-1.07173464, -1.06200356, -1.00724920],
-            [0.99479125, 0.99665564, 0.97689017],
-        ]
-        np.testing.assert_almost_equal(centroids, expected, decimal=7)
-
-
 def test_kmeans_fit_init_pp():
     np.random.seed(0)
     data1 = np.random.normal(loc=1, size=(2000, 3))
@@ -201,3 +175,18 @@ def test_kmeans_parameters():
             [0.99479125, 0.99665564, 0.97689017],
         ]
         np.testing.assert_almost_equal(centroids, expected, decimal=7)
+
+
+def test_get_centroids_distance():
+    np.random.seed(0)
+    n_features = 60
+    n_samples = 240_000
+    n_clusters = 256
+    data = np.random.normal(loc=1, size=(n_samples, n_features))
+    means = np.random.normal(loc=-1, size=(n_clusters, n_features))
+    oracle = scipy.spatial.distance.cdist(means, data, metric="sqeuclidean")
+    for transform in (to_numpy,):
+        data, means = transform(data, means)
+        dist = k_means.get_centroids_distance(data, means)
+        np.testing.assert_allclose(dist, oracle)
+        assert type(data) is type(dist), (type(data), type(dist))