From 212c74672912b36c6e331e36a2f98b8f36ed1c89 Mon Sep 17 00:00:00 2001
From: Amir MOHAMMADI <amir.mohammadi@idiap.ch>
Date: Mon, 4 Apr 2022 14:40:03 +0200
Subject: [PATCH] [kmeans] parallelize kmeans weights and variances estimation

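get_variances_and_weights_for_each_cluster now follows the same map/reduce
pattern as the e_step/m_step training loop: one accumulate task per data
chunk (closest-centroid indices plus per-cluster sums and sums of squares)
and a single reduce step that derives the weights and variances from them.
On dask input the per-chunk calls are wrapped in dask.delayed; on in-memory
arrays the same two functions run eagerly. GMMMachine can then take the
k-means variances and weights directly, without an extra dask.compute.

A rough usage sketch (the KMeansMachine constructor arguments, the toy data
and the chunking are illustrative only; depending on the setup a
dask.distributed client may also be required):

    import dask.array as da
    import numpy as np

    from bob.learn.em.kmeans import KMeansMachine

    # toy data split into 4 chunks; each chunk becomes one delayed
    # accumulate task inside get_variances_and_weights_for_each_cluster
    data = da.from_array(np.random.normal(size=(1000, 3)), chunks=(250, 3))

    machine = KMeansMachine(n_clusters=2).fit(data)
    variances, weights = machine.get_variances_and_weights_for_each_cluster(data)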
---
 bob/learn/em/gmm.py    | 11 ++----
 bob/learn/em/kmeans.py | 90 +++++++++++++++++++++++++++++++-----------
 2 files changed, 70 insertions(+), 31 deletions(-)

diff --git a/bob/learn/em/gmm.py b/bob/learn/em/gmm.py
index e4d6172..07c807e 100644
--- a/bob/learn/em/gmm.py
+++ b/bob/learn/em/gmm.py
@@ -719,18 +719,15 @@ class GMMMachine(BaseEstimator):
             )
             kmeans_machine = kmeans_machine.fit(data)
 
+            # Set the GMM machine's gaussians with the results of k-means
+            self.means = copy.deepcopy(kmeans_machine.centroids_)
             logger.debug(
                 "Estimating the variance and weights of each gaussian from kmeans."
             )
             (
-                variances,
-                weights,
+                self.variances,
+                self.weights,
             ) = kmeans_machine.get_variances_and_weights_for_each_cluster(data)
-
-            # Set the GMM machine's gaussians with the results of k-means
-            self.means = copy.deepcopy(kmeans_machine.centroids_)
-            # TODO: remove this compute
-            self.variances, self.weights = dask.compute(variances, weights)
             logger.debug("Done.")
 
     def log_weighted_likelihood(
diff --git a/bob/learn/em/kmeans.py b/bob/learn/em/kmeans.py
index b48fa28..c1b7153 100644
--- a/bob/learn/em/kmeans.py
+++ b/bob/learn/em/kmeans.py
@@ -46,7 +46,18 @@ def get_centroids_distance(x: np.ndarray, means: np.ndarray) -> np.ndarray:
 
 
 def get_closest_centroid_index(centroids_dist: np.ndarray) -> np.ndarray:
-    """Returns the index of the closest cluster mean to x."""
+    """Returns the index of the closest cluster mean to x.
+
+    Parameters
+    ----------
+    centroids_dist: ndarray of shape (n_clusters, n_samples)
+        The squared Euclidean distance of each sample to each cluster mean.
+
+    Returns
+    -------
+    closest_centroid_indices: ndarray of shape (n_samples,)
+        The index of the closest cluster mean for each sample.
+    """
     return np.argmin(centroids_dist, axis=0)
 
 
@@ -123,6 +134,43 @@ def m_step(stats, n_samples):
     return means, average_min_distance
 
 
+def accumulate_indices_means_vars(data, means):
+    """Accumulates statistics needed to compute weights and variances of the clusters."""
+    n_clusters, n_features = len(means), data.shape[1]
+    dist = get_centroids_distance(data, means)
+    closest_centroid_indices = get_closest_centroid_index(dist)
+    # means_sum and variances_sum are created as zeros of the same array type
+    # as data (via ``like=``) and are summed across chunks in the reduce step
+    means_sum = np.zeros((n_clusters, n_features), like=data)
+    variances_sum = np.zeros((n_clusters, n_features), like=data)
+    for i in range(n_clusters):
+        means_sum[i] = np.sum(data[closest_centroid_indices == i], axis=0)
+    for i in range(n_clusters):
+        variances_sum[i] = np.sum(
+            data[closest_centroid_indices == i] ** 2, axis=0
+        )
+    return closest_centroid_indices, means_sum, variances_sum
+
+
+def reduce_indices_means_vars(stats):
+    """Computes weights and variances of the clusters given the statistics."""
+    closest_centroid_indices = [s[0] for s in stats]
+    means_sum = [s[1] for s in stats]
+    variances_sum = [s[2] for s in stats]
+
+    closest_centroid_indices = np.concatenate(closest_centroid_indices, axis=0)
+    means_sum = np.sum(means_sum, axis=0)
+    variances_sum = np.sum(variances_sum, axis=0)
+
+    n_clusters = len(means_sum)
+    weights_count = np.bincount(closest_centroid_indices, minlength=n_clusters)
+    weights = weights_count / weights_count.sum()
+    means = means_sum / weights_count[:, None]
+    variances = (variances_sum / weights_count[:, None]) - (means ** 2)
+
+    return variances, weights
+
+
 def check_and_persist_dask_input(data):
     # check if input is a dask array. If so, persist and rebalance data
     input_is_dask = False
@@ -253,30 +301,24 @@ class KMeansMachine(BaseEstimator):
             weights: ndarray of shape (n_clusters, )
                 Weight (proportion of quantity of data point) of each cluster.
         """
-        _, data = check_and_persist_dask_input(data)
-
-        # TODO: parallelize this like e_step
-        # Accumulate
-        dist = get_centroids_distance(data, self.centroids_)
-        closest_centroid_indices = get_closest_centroid_index(dist)
+        input_is_dask, data = check_and_persist_dask_input(data)
+        data = array_to_delayed_list(data, input_is_dask)
 
-        means_sum, variances_sum = [], []
-        for i in range(self.n_clusters):
-            cluster_data = data[closest_centroid_indices == i]
-            means_sum.append(np.sum(cluster_data, axis=0))
-            variances_sum.append(np.sum(cluster_data ** 2, axis=0))
-
-        means_sum, variances_sum = np.vstack(means_sum), np.vstack(
-            variances_sum
-        )
-
-        # Reduce (similar to m_step)
-        weights_count = np.bincount(
-            closest_centroid_indices, minlength=self.n_clusters
-        )
-        weights = weights_count / weights_count.sum()
-        means = means_sum / weights_count[:, None]
-        variances = (variances_sum / weights_count[:, None]) - (means ** 2)
+        if input_is_dask:
+            stats = [
+                dask.delayed(accumulate_indices_means_vars)(
+                    xx, means=self.centroids_
+                )
+                for xx in data
+            ]
+            variances, weights = dask.compute(
+                dask.delayed(reduce_indices_means_vars)(stats)
+            )[0]
+        else:
+            # Accumulate
+            stats = accumulate_indices_means_vars(data, self.centroids_)
+            # Reduce
+            variances, weights = reduce_indices_means_vars([stats])
 
         return variances, weights
 
-- 
GitLab