diff --git a/bob/learn/em/gmm.py b/bob/learn/em/gmm.py
index ddf985321c3bffbe298218e23512d962a5d2d98a..762939981b58a90379ceac1d7b0e4073ba8d3f5d 100644
--- a/bob/learn/em/gmm.py
+++ b/bob/learn/em/gmm.py
@@ -35,21 +35,15 @@ def logaddexp_reduce(array, axis=0, keepdims=False):
     )
 
 
-def log_weighted_likelihood(data, means, variances, g_norms, log_weights):
+def log_weighted_likelihood(data, machine):
     """Returns the weighted log likelihood for each Gaussian for a set of data.
 
     Parameters
     ----------
     data
         Data to compute the log likelihood on.
-    means
-        Means of the Gaussians.
-    variances
-        Variances of the Gaussians.
-    g_norms
-        Normalization factors of the Gaussians.
-    log_weights
-        Log weights of the Gaussians.
+    machine
+        The GMM machine.
 
     Returns
     -------
@@ -57,12 +51,14 @@ def log_weighted_likelihood(data, means, variances, g_norms, log_weights):
         The weighted log likelihood of each sample of each Gaussian.
     """
     # Compute the likelihood for each data point on each Gaussian
-    n_gaussians, n_samples = len(means), len(data)
+    n_gaussians, n_samples = len(machine.means), len(data)
     z = np.empty(shape=(n_gaussians, n_samples), like=data)
     for i in range(n_gaussians):
-        z[i] = np.sum((data - means[i]) ** 2 / variances[i], axis=-1)
-    ll = -0.5 * (g_norms[:, None] + z)
-    log_weighted_likelihoods = log_weights[:, None] + ll
+        z[i] = np.sum(
+            (data - machine.means[i]) ** 2 / machine.variances[i], axis=-1
+        )
+    ll = -0.5 * (machine.g_norms[:, None] + z)
+    log_weighted_likelihoods = machine.log_weights[:, None] + ll
     return log_weighted_likelihoods
 
 
@@ -82,13 +78,15 @@ def reduce_loglikelihood(log_weighted_likelihoods):
     return log_likelihood
 
 
-def log_likelihood(data, means, variances, g_norms, log_weights):
+def log_likelihood(data, machine):
     """Returns the current log likelihood for a set of data in this Machine.
 
     Parameters
     ----------
     data
         Data to compute the log likelihood on.
+    machine
+        The GMM machine.
 
     Returns
     -------
@@ -100,29 +98,26 @@ def log_likelihood(data, means, variances, g_norms, log_weights):
     # All likelihoods [array of shape (n_gaussians, n_samples)]
     log_weighted_likelihoods = log_weighted_likelihood(
         data=data,
-        means=means,
-        variances=variances,
-        g_norms=g_norms,
-        log_weights=log_weights,
+        machine=machine,
     )
     # Likelihoods of each sample on this machine. [array of shape (n_samples,)]
     ll_reduced = reduce_loglikelihood(log_weighted_likelihoods)
     return ll_reduced
 
 
-def e_step(data, weights, means, variances, g_norms, log_weights):
+def e_step(data, machine):
     """Expectation step of the e-m algorithm."""
     # Ensure data is a series of samples (2D array)
     data = np.atleast_2d(data)
 
-    n_gaussians = len(weights)
+    n_gaussians = len(machine.weights)
 
     # Allow the absence of previous statistics
     statistics = GMMStats(n_gaussians, data.shape[-1])
 
     # Log weighted Gaussian likelihoods [array of shape (n_gaussians,n_samples)]
     log_weighted_likelihoods = log_weighted_likelihood(
-        data, means, variances, g_norms, log_weights
+        data=data, machine=machine
     )
 
     # Log likelihood [array of shape (n_samples,)]
@@ -153,29 +148,22 @@ def e_step(data, weights, means, variances, g_norms, log_weights):
 
 
 def m_step(
-    machine,
     statistics,
-    update_means,
-    update_variances,
-    update_weights,
-    mean_var_update_threshold,
-    map_relevance_factor,
-    map_alpha,
-    trainer,
+    machine,
 ):
     """Maximization step of the e-m algorithm."""
-    m_step_func = map_gmm_m_step if trainer == "map" else ml_gmm_m_step
+    m_step_func = map_gmm_m_step if machine.trainer == "map" else ml_gmm_m_step
     statistics = functools.reduce(operator.iadd, statistics)
     m_step_func(
         machine=machine,
         statistics=statistics,
-        update_means=update_means,
-        update_variances=update_variances,
-        update_weights=update_weights,
-        mean_var_update_threshold=mean_var_update_threshold,
-        reynolds_adaptation=map_relevance_factor is not None,
-        alpha=map_alpha,
-        relevance_factor=map_relevance_factor,
+        update_means=machine.update_means,
+        update_variances=machine.update_variances,
+        update_weights=machine.update_weights,
+        mean_var_update_threshold=machine.mean_var_update_threshold,
+        reynolds_adaptation=machine.map_relevance_factor is not None,
+        alpha=machine.map_alpha,
+        relevance_factor=machine.map_relevance_factor,
     )
     average_output = float(statistics.log_likelihood / statistics.t)
     return machine, average_output
@@ -762,10 +750,7 @@ class GMMMachine(BaseEstimator):
         """
         return log_weighted_likelihood(
             data=data,
-            means=self.means,
-            variances=self.variances,
-            g_norms=self.g_norms,
-            log_weights=self.log_weights,
+            machine=self,
         )
 
     def log_likelihood(
@@ -786,10 +771,7 @@ class GMMMachine(BaseEstimator):
         """
         return log_likelihood(
             data=data,
-            means=self.means,
-            variances=self.variances,
-            g_norms=self.g_norms,
-            log_weights=self.log_weights,
+            machine=self,
        )
 
     def fit(self, X, y=None):
@@ -808,17 +790,6 @@ class GMMMachine(BaseEstimator):
             )
             self.variances = np.ones_like(self.means)
 
-        m_step_func = functools.partial(
-            m_step,
-            update_means=self.update_means,
-            update_variances=self.update_variances,
-            update_weights=self.update_weights,
-            mean_var_update_threshold=self.mean_var_update_threshold,
-            map_relevance_factor=self.map_relevance_factor,
-            map_alpha=self.map_alpha,
-            trainer=self.trainer,
-        )
-
         X = array_to_delayed_list(X, input_is_dask)
 
         average_output = 0
@@ -842,16 +813,12 @@ class GMMMachine(BaseEstimator):
                 stats = [
                     dask.delayed(e_step)(
                         data=xx,
-                        weights=self.weights,
-                        means=self.means,
-                        variances=self.variances,
-                        g_norms=self.g_norms,
-                        log_weights=self.log_weights,
+                        machine=self,
                     )
                     for xx in X
                 ]
                 new_machine, average_output = dask.compute(
-                    dask.delayed(m_step_func)(self, stats)
+                    dask.delayed(m_step)(stats, self)
                 )[0]
                 for attr in ["weights", "means", "variances"]:
                     setattr(self, attr, getattr(new_machine, attr))
@@ -859,14 +826,10 @@ class GMMMachine(BaseEstimator):
                 stats = [
                     e_step(
                         data=X,
-                        weights=self.weights,
-                        means=self.means,
-                        variances=self.variances,
-                        g_norms=self.g_norms,
-                        log_weights=self.log_weights,
+                        machine=self,
                     )
                 ]
-                _, average_output = m_step_func(self, stats)
+                _, average_output = m_step(stats, self)
 
             logger.debug(f"log likelihood = {average_output}")
             if step > 1:
@@ -897,12 +860,8 @@ class GMMMachine(BaseEstimator):
     def transform(self, X, **kwargs):
         """Returns the statistics for `X`."""
         return e_step(
-            X,
-            self.weights,
-            self.means,
-            self.variances,
-            self.g_norms,
-            self.log_weights,
+            data=X,
+            machine=self,
         )
 
     def _more_tags(self):
diff --git a/bob/learn/em/test/test_gmm.py b/bob/learn/em/test/test_gmm.py
index 65743ce3b95ff9bfa6a5712800a023942316c22f..6e4b5d0aea309140a081fbf91a2b92b9bc732935 100644
--- a/bob/learn/em/test/test_gmm.py
+++ b/bob/learn/em/test/test_gmm.py
@@ -286,11 +286,7 @@ def test_GMMMachine_stats():
 
     stats = gmm_module.e_step(
         arrayset,
-        gmm.weights,
-        gmm.means,
-        gmm.variances,
-        gmm.g_norms,
-        gmm.log_weights,
+        gmm,
     )
 
     stats_ref = GMMStats(n_gaussians=2, n_features=2)
@@ -365,11 +361,7 @@ def test_GMMStats_operations():
     # Populate the GMMStats
     stats = gmm_module.e_step(
         data,
-        machine.weights,
-        machine.means,
-        machine.variances,
-        machine.g_norms,
-        machine.log_weights,
+        machine,
     )
 
     # Check shapes
@@ -618,22 +610,11 @@ def test_ml_em():
 
     stats = gmm_module.e_step(
         data,
-        machine.weights,
-        machine.means,
-        machine.variances,
-        machine.g_norms,
-        machine.log_weights,
+        machine,
    )
     gmm_module.m_step(
-        machine,
         [stats],
-        machine.update_means,
-        machine.update_variances,
-        machine.update_weights,
-        machine.mean_var_update_threshold,
-        machine.map_relevance_factor,
-        machine.map_alpha,
-        machine.trainer,
+        machine,
     )
 
     expected_means = np.array([[1.5, 1.5, 2.0], [7.0, 8.0, 8.0]])
@@ -674,22 +655,11 @@ def test_map_em():
 
     stats = gmm_module.e_step(
        post_data,
-        machine.weights,
-        machine.means,
-        machine.variances,
-        machine.g_norms,
-        machine.log_weights,
+        machine,
     )
     gmm_module.m_step(
-        machine,
         [stats],
-        machine.update_means,
-        machine.update_variances,
-        machine.update_weights,
-        machine.mean_var_update_threshold,
-        machine.map_relevance_factor,
-        machine.map_alpha,
-        machine.trainer,
+        machine,
    )
 
     expected_means = np.array(
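
For reference, a minimal usage sketch of the refactored module-level API: e_step and m_step now receive the GMMMachine itself instead of its unpacked attributes. The constructor arguments and the attribute assignments below are illustrative assumptions for a toy 2-component machine, not part of this change.

    import numpy as np

    from bob.learn.em.gmm import GMMMachine, e_step, m_step

    # Toy data and a small machine (constructor signature assumed).
    data = np.array([[1.0, 2.0, 3.0], [2.0, 1.0, 1.0], [7.0, 8.0, 9.0]])
    machine = GMMMachine(n_gaussians=2)
    machine.means = np.array([[1.0, 1.0, 1.0], [8.0, 8.0, 8.0]])
    machine.variances = np.ones_like(machine.means)

    # One EM iteration with the new signatures: e_step(data, machine) and
    # m_step(statistics, machine); m_step returns (machine, average log likelihood).
    stats = e_step(data, machine)
    machine, average_ll = m_step([stats], machine)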