diff --git a/bob/learn/em/k_means.py b/bob/learn/em/k_means.py index 52f025676e481b51714a6c5c3aca62a145e6ab4b..329cc7cc47798d70eb0880147e8101b3f83a1ae9 100644 --- a/bob/learn/em/k_means.py +++ b/bob/learn/em/k_means.py @@ -74,10 +74,10 @@ def e_step(data, means): zeroeth_order_statistics = np.bincount( closest_k_indices, minlength=n_clusters ) - first_order_statistics = np.sum( - np.eye(n_clusters)[closest_k_indices][:, :, None] * data[:, None], - axis=0, - ) + # Compute first_order_statistics in a memory efficient way + first_order_statistics = np.zeros((n_clusters, data.shape[1])) + for i in range(n_clusters): + first_order_statistics[i] = np.sum(data[closest_k_indices == i], axis=0) min_distance = np.min(distances, axis=0) average_min_distance = min_distance.mean() return ( @@ -241,28 +241,23 @@ class KMeansMachine(BaseEstimator): weights: ndarray of shape (n_clusters, ) Weight (proportion of quantity of data point) of each cluster. """ - n_cluster = self.n_clusters + n_clusters = self.n_clusters dist = get_centroids_distance(data, self.centroids_) closest_centroid_indices = get_closest_centroid_index(dist) weights_count = np.bincount( - closest_centroid_indices, minlength=n_cluster + closest_centroid_indices, minlength=n_clusters ) weights = weights_count / weights_count.sum() - # FIX for `too many indices for array` error if using `np.eye(n_cluster)` alone: - dask_compatible_eye = np.eye(n_cluster) * np.array(1, like=data) - # Accumulate - means_sum = np.sum( - dask_compatible_eye[closest_centroid_indices][:, :, None] - * data[:, None], - axis=0, - ) - variances_sum = np.sum( - dask_compatible_eye[closest_centroid_indices][:, :, None] - * (data[:, None] ** 2), - axis=0, - ) + means_sum = np.zeros((n_clusters, data.shape[1])) + for i in range(n_clusters): + means_sum[i] = np.sum(data[closest_centroid_indices == i], axis=0) + variances_sum = np.zeros((n_clusters, data.shape[1])) + for i in range(n_clusters): + variances_sum[i] = np.sum( + data[closest_centroid_indices == i] ** 2, axis=0 + ) # Reduce means = means_sum / weights_count[:, None]