Skip to content
Snippets Groups Projects

improve variance estimation speed in kmeans, convert data to proper arrays

Merged Amir MOHAMMADI requested to merge improve-variance-estimation-speed into master
2 files
+ 23
- 7
Compare changes
  • Side-by-side
  • Inline
Files
2
+ 8
- 2
@@ -672,6 +672,9 @@ class GMMMachine(BaseEstimator):
@@ -672,6 +672,9 @@ class GMMMachine(BaseEstimator):
)
)
kmeans_machine = kmeans_machine.fit(data)
kmeans_machine = kmeans_machine.fit(data)
 
logger.debug(
 
"Estimating the variance and weights of each gaussian from kmeans."
 
)
(
(
variances,
variances,
weights,
weights,
@@ -680,6 +683,7 @@ class GMMMachine(BaseEstimator):
@@ -680,6 +683,7 @@ class GMMMachine(BaseEstimator):
# Set the GMM machine's gaussians with the results of k-means
# Set the GMM machine's gaussians with the results of k-means
self.means = copy.deepcopy(kmeans_machine.centroids_)
self.means = copy.deepcopy(kmeans_machine.centroids_)
self.variances, self.weights = dask.compute(variances, weights)
self.variances, self.weights = dask.compute(variances, weights)
 
logger.debug("Done.")
def log_weighted_likelihood(
def log_weighted_likelihood(
self,
self,
@@ -833,7 +837,7 @@ class GMMMachine(BaseEstimator):
@@ -833,7 +837,7 @@ class GMMMachine(BaseEstimator):
def fit(self, X, y=None):
def fit(self, X, y=None):
"""Trains the GMM on data until convergence or maximum step is reached."""
"""Trains the GMM on data until convergence or maximum step is reached."""
input_is_dask = check_and_persist_dask_input(X)
input_is_dask, X = check_and_persist_dask_input(X)
if self._means is None:
if self._means is None:
self.initialize_gaussians(X)
self.initialize_gaussians(X)
@@ -912,7 +916,9 @@ class GMMMachine(BaseEstimator):
@@ -912,7 +916,9 @@ class GMMMachine(BaseEstimator):
(average_output_previous - average_output)
(average_output_previous - average_output)
/ average_output_previous
/ average_output_previous
)
)
logger.debug(f"convergence val = {convergence_value}")
logger.debug(
 
f"convergence val = {convergence_value} and threshold = {self.convergence_threshold}"
 
)
# Terminates if converged (and likelihood computation is set)
# Terminates if converged (and likelihood computation is set)
if (
if (
Loading