diff --git a/bob/learn/em/k_means.py b/bob/learn/em/k_means.py
index 5d7c95e16f9d81af971ce61fb1d0a4cdc8731903..52f025676e481b51714a6c5c3aca62a145e6ab4b 100644
--- a/bob/learn/em/k_means.py
+++ b/bob/learn/em/k_means.py
@@ -8,10 +8,9 @@ from typing import Union
 
 import dask
 import dask.array as da
-import dask.bag
 import dask.delayed
-import distributed
 import numpy as np
+import scipy.spatial.distance
 
 from dask_ml.cluster.k_means import k_init
 from sklearn.base import BaseEstimator
@@ -26,15 +25,21 @@ def get_centroids_distance(x: np.ndarray, means: np.ndarray) -> np.ndarray:
 
     Parameters
     ----------
-    x: ndarray of shape (n_features,) or (n_samples, n_features)
-        One data point, or a series of data points.
+    x: ndarray of shape (n_samples, n_features)
+        A series of data points.
+    means: ndarray of shape (n_clusters, n_features)
+        The centroids.
 
     Returns
     -------
-    distances: ndarray of shape (n_clusters,) or (n_clusters, n_samples)
+    distances: ndarray of shape (n_clusters, n_samples)
         For each cluster, the squared Euclidian distance (or distances) to x.
     """
-    return np.sum((means[:, None] - x[None, :]) ** 2, axis=-1)
+    x = np.atleast_2d(x)
+    if isinstance(x, da.Array):
+        return np.sum((means[:, None] - x[None, :]) ** 2, axis=-1)
+    else:
+        return scipy.spatial.distance.cdist(means, x, metric="sqeuclidean")
 
 
 def get_closest_centroid_index(centroids_dist: np.ndarray) -> np.ndarray:
@@ -115,6 +120,23 @@ def m_step(stats, n_samples):
     return means, average_min_distance
 
 
+def check_and_persist_dask_input(data):
+    # check if input is a dask array. If so, persist and rebalance data
+    input_is_dask = False
+    if isinstance(data, da.Array):
+        data: da.Array = data.persist()
+        input_is_dask = True
+    return input_is_dask
+
+
+def array_to_delayed_list(data, input_is_dask):
+    # If input is a dask array, convert to delayed chunks
+    if input_is_dask:
+        data = data.to_delayed().ravel().tolist()
+        logger.debug(f"Got {len(data)} chunks.")
+    return data
+
+
 class KMeansMachine(BaseEstimator):
     """Stores the k-means clusters parameters (centroid of each cluster).
 
@@ -265,27 +287,14 @@ class KMeansMachine(BaseEstimator):
 
     def fit(self, X, y=None):
         """Fits this machine on data samples."""
-        # check if input is a dask array. If so, persist and rebalance data
-        client, input_is_dask = None, False
-        if isinstance(X, da.Array):
-            X: da.Array = X.persist()
-            input_is_dask = True
-
-            try:
-                client = distributed.Client.current()
-                client.rebalance()
-            except ValueError:
-                pass
+        input_is_dask = check_and_persist_dask_input(X)
 
         logger.debug("Initializing trainer.")
         self.initialize(data=X)
 
         n_samples = len(X)
 
-        # If input is a dask array, convert to delayed chunks
-        if input_is_dask:
-            X = X.to_delayed()
-            logger.debug(f"Got {len(X)} chunks.")
+        X = array_to_delayed_list(X, input_is_dask)
 
         logger.info("Training k-means.")
         distance = np.inf
@@ -301,8 +310,7 @@ class KMeansMachine(BaseEstimator):
             # compute the e-m steps
             if input_is_dask:
                 stats = [
-                    dask.delayed(e_step)(xx, means=self.centroids_)
-                    for xx in X.ravel().tolist()
+                    dask.delayed(e_step)(xx, means=self.centroids_) for xx in X
                 ]
                 self.centroids_, self.average_min_distance = dask.compute(
                     dask.delayed(m_step)(stats, n_samples)
@@ -313,17 +321,6 @@ class KMeansMachine(BaseEstimator):
                     stats, n_samples
                 )
 
-            # scatter centroids to all workers for efficiency
-            if client is not None:
-                logger.debug("Broadcasting centroids to all workers.")
-                future = client.scatter(self.centroids_, broadcast=True)
-                self.centroids_ = da.from_delayed(
-                    future,
-                    shape=self.centroids_.shape,
-                    dtype=self.centroids_.dtype,
-                )
-                client.rebalance()
-
             distance = self.average_min_distance
 
             logger.debug(
@@ -345,6 +342,7 @@ class KMeansMachine(BaseEstimator):
                     "Reached convergence threshold. Training stopped."
                 )
                 break
+
         else:
             logger.info(
                 "Reached maximum step. Training stopped without convergence."
diff --git a/bob/learn/em/test/test_kmeans.py b/bob/learn/em/test/test_kmeans.py
index 539d20a950c132087fde9ef60573717fd44ed272..1400b469e909973e3a7911c7e7439a55b7d87a47 100644
--- a/bob/learn/em/test/test_kmeans.py
+++ b/bob/learn/em/test/test_kmeans.py
@@ -14,8 +14,9 @@ import copy
 
 import dask.array as da
 import numpy as np
+import scipy.spatial.distance
 
-from bob.learn.em import KMeansMachine
+from bob.learn.em import KMeansMachine, k_means
 
 
 def to_numpy(*args):
@@ -174,3 +175,18 @@ def test_kmeans_parameters():
         [0.99479125, 0.99665564, 0.97689017],
     ]
     np.testing.assert_almost_equal(centroids, expected, decimal=7)
+
+
+def test_get_centroids_distance():
+    np.random.seed(0)
+    n_features = 60
+    n_samples = 240_000
+    n_clusters = 256
+    data = np.random.normal(loc=1, size=(n_samples, n_features))
+    means = np.random.normal(loc=-1, size=(n_clusters, n_features))
+    oracle = scipy.spatial.distance.cdist(means, data, metric="sqeuclidean")
+    for transform in (to_numpy,):
+        data, means = transform(data, means)
+        dist = k_means.get_centroids_distance(data, means)
+        np.testing.assert_allclose(dist, oracle)
+        assert type(data) is type(dist), (type(data), type(dist))