From ad7bfb1892cce15fd7a23c309dfe02462c28b91c Mon Sep 17 00:00:00 2001
From: Amir MOHAMMADI <amir.mohammadi@idiap.ch>
Date: Wed, 23 Mar 2022 12:13:19 +0100
Subject: [PATCH] improve variance estimation speed in kmeans, convert data to
 proper arrays

---
 bob/learn/em/gmm.py    | 10 ++++++++--
 bob/learn/em/kmeans.py | 20 +++++++++++++++-----
 2 files changed, 23 insertions(+), 7 deletions(-)

diff --git a/bob/learn/em/gmm.py b/bob/learn/em/gmm.py
index 78ab69f..ac850dd 100644
--- a/bob/learn/em/gmm.py
+++ b/bob/learn/em/gmm.py
@@ -672,6 +672,9 @@ class GMMMachine(BaseEstimator):
         )
         kmeans_machine = kmeans_machine.fit(data)
 
+        logger.debug(
+            "Estimating the variance and weights of each gaussian from kmeans."
+        )
         (
             variances,
             weights,
@@ -680,6 +683,7 @@ class GMMMachine(BaseEstimator):
         # Set the GMM machine's gaussians with the results of k-means
         self.means = copy.deepcopy(kmeans_machine.centroids_)
         self.variances, self.weights = dask.compute(variances, weights)
+        logger.debug("Done.")
 
     def log_weighted_likelihood(
         self,
@@ -833,7 +837,7 @@ class GMMMachine(BaseEstimator):
 
     def fit(self, X, y=None):
         """Trains the GMM on data until convergence or maximum step is reached."""
-        input_is_dask = check_and_persist_dask_input(X)
+        input_is_dask, X = check_and_persist_dask_input(X)
 
         if self._means is None:
             self.initialize_gaussians(X)
@@ -912,7 +916,9 @@ class GMMMachine(BaseEstimator):
             (average_output_previous - average_output) / average_output_previous
         )
 
-        logger.debug(f"convergence val = {convergence_value}")
+        logger.debug(
+            f"convergence val = {convergence_value} and threshold = {self.convergence_threshold}"
+        )
 
         # Terminates if converged (and likelihood computation is set)
         if (
diff --git a/bob/learn/em/kmeans.py b/bob/learn/em/kmeans.py
index 329cc7c..b5b6094 100644
--- a/bob/learn/em/kmeans.py
+++ b/bob/learn/em/kmeans.py
@@ -126,7 +126,16 @@ def check_and_persist_dask_input(data):
     if isinstance(data, da.Array):
         data: da.Array = data.persist()
         input_is_dask = True
-    return input_is_dask
+        # if there is a dask distributed client, rebalance data
+        try:
+            client = dask.distributed.Client.current()
+            client.rebalance()
+        except ValueError:
+            pass
+
+    else:
+        data = np.asarray(data)
+    return input_is_dask, data
 
 
 def array_to_delayed_list(data, input_is_dask):
@@ -241,7 +250,8 @@ class KMeansMachine(BaseEstimator):
         weights: ndarray of shape (n_clusters, )
             Weight (proportion of quantity of data point) of each cluster.
         """
-        n_clusters = self.n_clusters
+        _, data = check_and_persist_dask_input(data)
+        n_clusters, n_features = self.n_clusters, data.shape[1]
         dist = get_centroids_distance(data, self.centroids_)
         closest_centroid_indices = get_closest_centroid_index(dist)
         weights_count = np.bincount(
@@ -250,10 +260,10 @@ class KMeansMachine(BaseEstimator):
         weights = weights_count / weights_count.sum()
 
         # Accumulate
-        means_sum = np.zeros((n_clusters, data.shape[1]))
+        means_sum = np.zeros((n_clusters, n_features), like=data)
+        variances_sum = np.zeros((n_clusters, n_features), like=data)
         for i in range(n_clusters):
             means_sum[i] = np.sum(data[closest_centroid_indices == i], axis=0)
-        variances_sum = np.zeros((n_clusters, data.shape[1]))
         for i in range(n_clusters):
             variances_sum[i] = np.sum(
                 data[closest_centroid_indices == i] ** 2, axis=0
@@ -282,7 +292,7 @@ class KMeansMachine(BaseEstimator):
 
     def fit(self, X, y=None):
         """Fits this machine on data samples."""
-        input_is_dask = check_and_persist_dask_input(X)
+        input_is_dask, X = check_and_persist_dask_input(X)
 
         logger.debug("Initializing trainer.")
         self.initialize(data=X)
--
GitLab