From e80d044e250804a745e250315217b20b7414f34e Mon Sep 17 00:00:00 2001
From: Amir MOHAMMADI <amir.mohammadi@idiap.ch>
Date: Wed, 9 Mar 2022 16:31:04 +0100
Subject: [PATCH] [kmeans] refactor e-m to have simpler graphs for dask

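Previously the e- and m-steps were expressed as dask array operations,
which built one large task graph across iterations and required
persisting the centroids at every step so they would not be recomputed
from the start of the graph. The two steps are now module-level
functions: e_step is mapped over the data chunks with dask.delayed, and
m_step reduces the per-chunk statistics on the driver. Roughly, each
iteration of fit() performs the following map-reduce (a sketch, with
`chunks` standing for the list of delayed data chunks and `centroids`
for self.centroids_):

    # map: one e_step task per data chunk, executed in parallel
    stats = [dask.delayed(e_step)(xx, means=centroids) for xx in chunks]
    # reduce: merge the per-chunk statistics into the new centroids and
    # the average minimal distance used for the convergence check
    centroids, avg_min_dist = dask.compute(
        dask.delayed(m_step)(stats, n_samples)
    )[0]

Between iterations the new centroids are scattered to all workers, so
each e_step task depends only on its own chunk and one small broadcast
input. partial_fit and the compute() helper are removed, along with the
per-point distance helpers on KMeansMachine.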
---
 bob/learn/em/gmm.py              |   2 +-
 bob/learn/em/k_means.py          | 239 ++++++++++++++++++++-----------
 bob/learn/em/test/test_kmeans.py |  35 +----
 3 files changed, 158 insertions(+), 118 deletions(-)

diff --git a/bob/learn/em/gmm.py b/bob/learn/em/gmm.py
index c2ebe21..ba0ccb6 100644
--- a/bob/learn/em/gmm.py
+++ b/bob/learn/em/gmm.py
@@ -733,7 +733,7 @@ class GMMMachine(BaseEstimator):
             **kwargs,
         )
 
-    def fit(self, X, y=None, **kwargs):
+    def fit(self, X, y=None):
         """Trains the GMM on data until convergence or maximum step is reached."""
         if self._means is None:
             self.initialize_gaussians(X)
diff --git a/bob/learn/em/k_means.py b/bob/learn/em/k_means.py
index 3b73a61..5d7c95e 100644
--- a/bob/learn/em/k_means.py
+++ b/bob/learn/em/k_means.py
@@ -4,9 +4,13 @@
 
 import logging
 
-from typing import Tuple, Union
+from typing import Union
 
+import dask
 import dask.array as da
+import dask.bag
+import dask.delayed
+import distributed
 import numpy as np
 
 from dask_ml.cluster.k_means import k_init
@@ -15,6 +19,102 @@ from sklearn.base import BaseEstimator
 logger = logging.getLogger(__name__)
 
 
+def get_centroids_distance(x: np.ndarray, means: np.ndarray) -> np.ndarray:
+    """Returns the distance values between x and each cluster's centroid.
+
+    The returned values are squared Euclidean distances.
+
+    Parameters
+    ----------
+    x: ndarray of shape (n_features,) or (n_samples, n_features)
+        One data point, or a series of data points.
+    means: ndarray of shape (n_clusters, n_features)
+        The centroids of the clusters, one per row.
+
+    Returns
+    -------
+    distances: ndarray of shape (n_clusters,) or (n_clusters, n_samples)
+        For each cluster, the squared Euclidean distance (or distances) to x.
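+
+    Examples
+    --------
+    Two centroids at the origin and four points at all-ones, each at
+    squared distance 3.0:
+
+    >>> means = np.zeros((2, 3))
+    >>> x = np.ones((4, 3))
+    >>> get_centroids_distance(x, means)
+    array([[3., 3., 3., 3.],
+           [3., 3., 3., 3.]])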
+    """
+    return np.sum((means[:, None] - x[None, :]) ** 2, axis=-1)
+
+
+def get_closest_centroid_index(centroids_dist: np.ndarray) -> np.ndarray:
+    """Returns the index of the closest cluster mean to x."""
+    return np.argmin(centroids_dist, axis=0)
+
+
+def e_step(data, means):
+    """Computes the zero-th and first order statistics and average min distance
+    for each data point.
+
+    Parameters
+    ----------
+    data : array-like, shape (n_samples, n_features)
+        The data.
+
+    means : array-like, shape (n_clusters, n_features)
+        The cluster centers.
+
+    Returns
+    -------
+    zeroeth_order_statistics : array-like, shape (n_clusters,)
+        The zero-th order statistics (number of points assigned to each
+        cluster).
+    first_order_statistics : array-like, shape (n_clusters, n_features)
+        The first order statistics (sum of the points assigned to each
+        cluster).
+    sum_min_dist : float
+        The sum of the minimal squared distances between each point and
+        its closest centroid.
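+
+    Examples
+    --------
+    A small, hand-checkable example with two clusters in two dimensions:
+
+    >>> data = np.array([[0.0, 0.0], [1.0, 1.0], [5.0, 5.0]])
+    >>> means = np.array([[0.0, 0.0], [5.0, 5.0]])
+    >>> zeroeth, first, dist_sum = e_step(data, means)
+    >>> zeroeth
+    array([2, 1])
+    >>> first
+    array([[1., 1.],
+           [5., 5.]])
+    >>> float(dist_sum)
+    2.0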
+    """
+    n_clusters = len(means)
+    distances = get_centroids_distance(data, means)
+    closest_k_indices = get_closest_centroid_index(distances)
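+    # Number of data points in each cluster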
+    zeroeth_order_statistics = np.bincount(
+        closest_k_indices, minlength=n_clusters
+    )
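+    # Sum of data points coordinates in each cluster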
+    first_order_statistics = np.sum(
+        np.eye(n_clusters)[closest_k_indices][:, :, None] * data[:, None],
+        axis=0,
+    )
+    min_distance = np.min(distances, axis=0)
+    # Return the *sum* of the minimal distances for this chunk; ``m_step``
+    # divides the accumulated total by ``n_samples`` to get the average
+    # over all chunks.
+    sum_min_distance = min_distance.sum()
+    return (
+        zeroeth_order_statistics,
+        first_order_statistics,
+        sum_min_distance,
+    )
+
+
+def m_step(stats, n_samples):
+    """Computes the cluster centers and average minimum distance.
+
+    Parameters
+    ----------
+    stats : list
+        The results of the :any:`e_step` function applied to each chunk of
+        data.
+    n_samples : int
+        The total number of samples.
+
+    Returns
+    -------
+    means : array-like, shape (n_clusters, n_features)
+        The cluster centers.
+    avg_min_dist : float
+        The average minimum distance.
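+
+    Examples
+    --------
+    With the statistics of a single chunk of three points (as returned
+    by :any:`e_step`):
+
+    >>> stats = [(np.array([2, 1]), np.array([[1.0, 1.0], [5.0, 5.0]]), 2.0)]
+    >>> means, avg = m_step(stats, n_samples=3)
+    >>> means
+    array([[0.5, 0.5],
+           [5. , 5. ]])
+    >>> round(float(avg), 3)
+    0.667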
+    """
+    (
+        zeroeth_order_statistics,
+        first_order_statistics,
+        average_min_distance,
+    ) = (0, 0, 0)
+    for zeroeth_, first_, sum_ in stats:
+        zeroeth_order_statistics += zeroeth_
+        first_order_statistics += first_
+        average_min_distance += sum_
+    # Each chunk contributes a sum of minimal distances, so dividing by the
+    # total number of samples yields the average over the whole dataset.
+    average_min_distance /= n_samples
+
+    means = first_order_statistics / zeroeth_order_statistics[:, None]
+    return means, average_min_distance
+
+
 class KMeansMachine(BaseEstimator):
     """Stores the k-means clusters parameters (centroid of each cluster).
 
@@ -84,43 +184,6 @@ class KMeansMachine(BaseEstimator):
     def means(self, value: np.ndarray):
         self.centroids_ = value
 
-    def get_centroids_distance(self, x: np.ndarray) -> np.ndarray:
-        """Returns the distance values between x and each cluster's centroid.
-
-        The returned values are squared Euclidean distances.
-
-        Parameters
-        ----------
-        x: ndarray of shape (n_features,) or (n_samples, n_features)
-            One data point, or a series of data points.
-
-        Returns
-        -------
-        distances: ndarray of shape (n_clusters,) or (n_clusters, n_samples)
-            For each cluster, the squared Euclidian distance (or distances) to x.
-        """
-        return np.sum((self.centroids_[:, None] - x[None, :]) ** 2, axis=-1)
-
-    def get_closest_centroid(self, x: np.ndarray) -> Tuple[int, float]:
-        """Returns the closest mean's index and squared Euclidian distance to x."""
-        dists = self.get_centroids_distance(x)
-        min_id = np.argmin(dists, axis=0)
-        min_dist = dists[min_id]
-        return min_id, min_dist
-
-    def get_closest_centroid_index(self, x: np.ndarray) -> np.ndarray:
-        """Returns the index of the closest cluster mean to x."""
-        return np.argmin(self.get_centroids_distance(x), axis=0)
-
-    def get_min_distance(self, x: np.ndarray) -> np.ndarray:
-        """Returns the smallest distance between that point and the clusters centroids.
-
-        For each point in x, the minimum distance to each cluster's mean is returned.
-
-        The returned values are squared Euclidean distances.
-        """
-        return np.min(self.get_centroids_distance(x), axis=0)
-
     def __eq__(self, obj) -> bool:
         return self.is_similar_to(obj, r_epsilon=0, a_epsilon=0)
 
@@ -157,7 +220,8 @@ class KMeansMachine(BaseEstimator):
                 Weight (proportion of quantity of data point) of each cluster.
         """
         n_cluster = self.n_clusters
-        closest_centroid_indices = self.get_closest_centroid_index(data)
+        dist = get_centroids_distance(data, self.centroids_)
+        closest_centroid_indices = get_closest_centroid_index(dist)
         weights_count = np.bincount(
             closest_centroid_indices, minlength=n_cluster
         )
@@ -198,30 +262,31 @@ class KMeansMachine(BaseEstimator):
             oversampling_factor=self.oversampling_factor,
         )
 
-    def e_step(self, data: np.ndarray):
-        closest_k_indices = self.get_closest_centroid_index(data)
-        # Number of data points in each cluster
-        self.zeroeth_order_statistics = np.bincount(
-            closest_k_indices, minlength=self.n_clusters
-        )
-        # Sum of data points coordinates in each cluster
-        self.first_order_statistics = np.sum(
-            np.eye(self.n_clusters)[closest_k_indices][:, :, None]
-            * data[:, None],
-            axis=0,
-        )
-        self.average_min_distance = self.get_min_distance(data).mean()
-
-    def m_step(self, data: np.ndarray):
-        self.centroids_ = (
-            self.first_order_statistics / self.zeroeth_order_statistics[:, None]
-        )
-
     def fit(self, X, y=None):
         """Fits this machine on data samples."""
+
+        # check if input is a dask array. If so, persist and rebalance data
+        client, input_is_dask = None, False
+        if isinstance(X, da.Array):
+            X: da.Array = X.persist()
+            input_is_dask = True
+
+            try:
+                client = distributed.Client.current()
+                client.rebalance()
+            except ValueError:
+                pass
+
         logger.debug("Initializing trainer.")
         self.initialize(data=X)
 
+        n_samples = len(X)
+
+        # If input is a dask array, convert to delayed chunks
+        if input_is_dask:
+            X = X.to_delayed()
+            logger.debug(f"Got {X.size} chunks.")
+
         logger.info("Training k-means.")
         distance = np.inf
         step = 0
@@ -232,17 +297,34 @@ class KMeansMachine(BaseEstimator):
                 + (f"/{self.max_iter:3d}" if self.max_iter else "")
             )
             distance_previous = distance
-            self.e_step(data=X)
-            self.m_step(data=X)
 
-            # If we're running in dask, persist the centroids so we don't recompute them
-            # from the start of the graph at every step.
-            for attr in ("centroids_",):
-                arr = getattr(self, attr)
-                if isinstance(arr, da.Array):
-                    setattr(self, attr, arr.persist())
+            # compute the e-m steps
+            if input_is_dask:
+                stats = [
+                    dask.delayed(e_step)(xx, means=self.centroids_)
+                    for xx in X.ravel().tolist()
+                ]
+                self.centroids_, self.average_min_distance = dask.compute(
+                    dask.delayed(m_step)(stats, n_samples)
+                )[0]
+            else:
+                stats = [e_step(X, means=self.centroids_)]
+                self.centroids_, self.average_min_distance = m_step(
+                    stats, n_samples
+                )
+
+            # scatter centroids to all workers for efficiency
+            if client is not None:
+                logger.debug("Broadcasting centroids to all workers.")
+                future = client.scatter(self.centroids_, broadcast=True)
+                self.centroids_ = da.from_delayed(
+                    future,
+                    shape=self.centroids_.shape,
+                    dtype=self.centroids_.dtype,
+                )
+                client.rebalance()
 
-            distance = float(self.average_min_distance)
+            distance = self.average_min_distance
 
             logger.debug(
                 f"Average minimal squared Euclidean distance = {distance}"
@@ -267,18 +349,6 @@ class KMeansMachine(BaseEstimator):
             logger.info(
                 "Reached maximum step. Training stopped without convergence."
             )
-        self.compute()
-        return self
-
-    def partial_fit(self, X, y=None):
-        """Applies one e-m step of k-means on the data."""
-        if self.centroids_ is None:
-            logger.debug("First call to 'partial_fit'. Initializing...")
-            self.initialize(data=X)
-
-        self.e_step(data=X)
-        self.m_step(data=X)
-
         return self
 
     def transform(self, X):
@@ -294,7 +364,7 @@ class KMeansMachine(BaseEstimator):
         distances: ndarray of shape (n_clusters, n_samples)
             For each mean, for each point, the squared Euclidian distance between them.
         """
-        return self.get_centroids_distance(X)
+        return get_centroids_distance(X, self.centroids_)
 
     def predict(self, X):
         """Returns the labels of the closest cluster centroid to the data.
@@ -309,9 +379,6 @@ class KMeansMachine(BaseEstimator):
         indices: ndarray of shape (n_samples)
             The indices of the closest cluster for each data point.
         """
-        return self.get_closest_centroid_index(X)
-
-    def compute(self, *args, **kwargs):
-        """Computes delayed arrays if needed."""
-        for name in ("centroids_",):
-            setattr(self, name, np.asarray(getattr(self, name)))
+        return get_closest_centroid_index(
+            get_centroids_distance(X, self.centroids_)
+        )
diff --git a/bob/learn/em/test/test_kmeans.py b/bob/learn/em/test/test_kmeans.py
index df066c8..539d20a 100644
--- a/bob/learn/em/test/test_kmeans.py
+++ b/bob/learn/em/test/test_kmeans.py
@@ -30,7 +30,10 @@ def to_numpy(*args):
 def to_dask_array(*args):
     result = []
     for x in args:
-        result.append(da.from_array(np.array(x)))
+        x = np.asarray(x)
+        chunks = list(x.shape)
+        chunks[0] //= 2
+        result.append(da.from_array(x, chunks=chunks))
     if len(result) == 1:
         return result[0]
     return result
@@ -54,23 +57,12 @@ def test_KMeansMachine():
         np.testing.assert_equal(km.transform(test_val)[0], np.array([1]))
         np.testing.assert_equal(km.transform(test_val)[1], np.array([6]))
 
-        (index, dist) = km.get_closest_centroid(test_val)
-        assert index == 0
-        np.testing.assert_equal(dist, np.array([[1.0]]))
-
-        (indices, dists) = km.get_closest_centroid(test_arr)
-        np.testing.assert_equal(indices, np.array([0, 1]))
-        np.testing.assert_equal(dists, np.array([[1, 8], [6, 1]]))
-
         index = km.predict(test_val)
         assert index == 0
 
         indices = km.predict(test_arr)
         np.testing.assert_equal(indices, np.array([0, 1]))
 
-        np.testing.assert_equal(km.get_min_distance(test_val), np.array([1]))
-        np.testing.assert_equal(km.get_min_distance(test_arr), np.array([1, 1]))
-
         # Check __eq__ and is_similar_to
         km2 = KMeansMachine(2)
         assert km != km2
@@ -124,25 +116,6 @@ def test_kmeans_fit():
         machine.fit(data)
 
 
-def test_kmeans_fit_partial():
-    np.random.seed(0)
-    data1 = np.random.normal(loc=1, size=(2000, 3))
-    data2 = np.random.normal(loc=-1, size=(2000, 3))
-    data = np.concatenate([data1, data2], axis=0)
-
-    for transform in (to_numpy, to_dask_array):
-        data = transform(data)
-        machine = KMeansMachine(2, random_state=0)
-        for _ in range(20):
-            machine.partial_fit(data)
-        centroids = machine.centroids_[np.argsort(machine.centroids_[:, 0])]
-        expected = [
-            [-1.07173464, -1.06200356, -1.00724920],
-            [0.99479125, 0.99665564, 0.97689017],
-        ]
-        np.testing.assert_almost_equal(centroids, expected, decimal=7)
-
-
 def test_kmeans_fit_init_pp():
     np.random.seed(0)
     data1 = np.random.normal(loc=1, size=(2000, 3))
-- 
GitLab