From 212c74672912b36c6e331e36a2f98b8f36ed1c89 Mon Sep 17 00:00:00 2001
From: Amir MOHAMMADI <amir.mohammadi@idiap.ch>
Date: Mon, 4 Apr 2022 14:40:03 +0200
Subject: [PATCH] [kmeans] parallelize kmeans weights variance estimation

---
 bob/learn/em/gmm.py    | 11 ++----
 bob/learn/em/kmeans.py | 90 +++++++++++++++++++++++++++++++-----------
 2 files changed, 70 insertions(+), 31 deletions(-)

diff --git a/bob/learn/em/gmm.py b/bob/learn/em/gmm.py
index e4d6172..07c807e 100644
--- a/bob/learn/em/gmm.py
+++ b/bob/learn/em/gmm.py
@@ -719,18 +719,15 @@ class GMMMachine(BaseEstimator):
         )
         kmeans_machine = kmeans_machine.fit(data)
 
+        # Set the GMM machine's gaussians with the results of k-means
+        self.means = copy.deepcopy(kmeans_machine.centroids_)
         logger.debug(
             "Estimating the variance and weights of each gaussian from kmeans."
         )
         (
-            variances,
-            weights,
+            self.variances,
+            self.weights,
         ) = kmeans_machine.get_variances_and_weights_for_each_cluster(data)
-
-        # Set the GMM machine's gaussians with the results of k-means
-        self.means = copy.deepcopy(kmeans_machine.centroids_)
-        # TODO: remove this compute
-        self.variances, self.weights = dask.compute(variances, weights)
         logger.debug("Done.")
 
     def log_weighted_likelihood(
diff --git a/bob/learn/em/kmeans.py b/bob/learn/em/kmeans.py
index b48fa28..c1b7153 100644
--- a/bob/learn/em/kmeans.py
+++ b/bob/learn/em/kmeans.py
@@ -46,7 +46,18 @@ def get_centroids_distance(x: np.ndarray, means: np.ndarray) -> np.ndarray:
 
 
 def get_closest_centroid_index(centroids_dist: np.ndarray) -> np.ndarray:
-    """Returns the index of the closest cluster mean to x."""
+    """Returns the index of the closest cluster mean to x.
+
+    Parameters
+    ----------
+    centroids_dist: ndarray of shape (n_clusters, n_samples)
+        The squared Euclidean distance of each sample to each cluster mean.
+
+    Returns
+    -------
+    closest_centroid_indices: ndarray of shape (n_samples,)
+        The index of the closest cluster mean for each sample.
+ """ return np.argmin(centroids_dist, axis=0) @@ -123,6 +134,43 @@ def m_step(stats, n_samples): return means, average_min_distance +def accumulate_indices_means_vars(data, means): + """Accumulates statistics needed to compute weights and variances of the clusters.""" + n_clusters, n_features = len(means), data.shape[1] + dist = get_centroids_distance(data, means) + closest_centroid_indices = get_closest_centroid_index(dist) + # the means_sum and variances_sum must be initialized with zero here since + # they get accumulated in the next function + means_sum = np.zeros((n_clusters, n_features), like=data) + variances_sum = np.zeros((n_clusters, n_features), like=data) + for i in range(n_clusters): + means_sum[i] = np.sum(data[closest_centroid_indices == i], axis=0) + for i in range(n_clusters): + variances_sum[i] = np.sum( + data[closest_centroid_indices == i] ** 2, axis=0 + ) + return closest_centroid_indices, means_sum, variances_sum + + +def reduce_indices_means_vars(stats): + """Computes weights and variances of the clusters given the statistics.""" + closest_centroid_indices = [s[0] for s in stats] + means_sum = [s[1] for s in stats] + variances_sum = [s[2] for s in stats] + + closest_centroid_indices = np.concatenate(closest_centroid_indices, axis=0) + means_sum = np.sum(means_sum, axis=0) + variances_sum = np.sum(variances_sum, axis=0) + + n_clusters = len(means_sum) + weights_count = np.bincount(closest_centroid_indices, minlength=n_clusters) + weights = weights_count / weights_count.sum() + means = means_sum / weights_count[:, None] + variances = (variances_sum / weights_count[:, None]) - (means ** 2) + + return variances, weights + + def check_and_persist_dask_input(data): # check if input is a dask array. If so, persist and rebalance data input_is_dask = False @@ -253,30 +301,24 @@ class KMeansMachine(BaseEstimator): weights: ndarray of shape (n_clusters, ) Weight (proportion of quantity of data point) of each cluster. """ - _, data = check_and_persist_dask_input(data) - - # TODO: parallelize this like e_step - # Accumulate - dist = get_centroids_distance(data, self.centroids_) - closest_centroid_indices = get_closest_centroid_index(dist) + input_is_dask, data = check_and_persist_dask_input(data) + data = array_to_delayed_list(data, input_is_dask) - means_sum, variances_sum = [], [] - for i in range(self.n_clusters): - cluster_data = data[closest_centroid_indices == i] - means_sum.append(np.sum(cluster_data, axis=0)) - variances_sum.append(np.sum(cluster_data ** 2, axis=0)) - - means_sum, variances_sum = np.vstack(means_sum), np.vstack( - variances_sum - ) - - # Reduce (similar to m_step) - weights_count = np.bincount( - closest_centroid_indices, minlength=self.n_clusters - ) - weights = weights_count / weights_count.sum() - means = means_sum / weights_count[:, None] - variances = (variances_sum / weights_count[:, None]) - (means ** 2) + if input_is_dask: + stats = [ + dask.delayed(accumulate_indices_means_vars)( + xx, means=self.centroids_ + ) + for xx in data + ] + variances, weights = dask.compute( + dask.delayed(reduce_indices_means_vars)(stats) + )[0] + else: + # Accumulate + stats = accumulate_indices_means_vars(data, self.centroids_) + # Reduce + variances, weights = reduce_indices_means_vars([stats]) return variances, weights -- GitLab