From 65d1a837a063316e78cbb9870c090ed1c08c463c Mon Sep 17 00:00:00 2001 From: Amir MOHAMMADI <amir.mohammadi@idiap.ch> Date: Tue, 15 Mar 2022 14:53:42 +0100 Subject: [PATCH] [gmm] split e-m training into two clear steps --- bob/learn/em/gmm.py | 170 ++++++++++++++++++++++++++++++++++++++------ 1 file changed, 149 insertions(+), 21 deletions(-) diff --git a/bob/learn/em/gmm.py b/bob/learn/em/gmm.py index ba0ccb6..436d169 100644 --- a/bob/learn/em/gmm.py +++ b/bob/learn/em/gmm.py @@ -5,23 +5,123 @@ """This module provides classes and functions for the training and usage of GMM.""" import copy +import functools import logging +import operator from typing import Union +import dask import dask.array as da import numpy as np from h5py import File as HDF5File from sklearn.base import BaseEstimator -from .k_means import KMeansMachine +from .k_means import ( + KMeansMachine, + array_to_delayed_list, + check_and_persist_dask_input, +) logger = logging.getLogger(__name__) EPSILON = np.finfo(float).eps +def logaddexp_reduce(array, axis=0, keepdims=False): + return np.logaddexp.reduce( + array, axis=axis, keepdims=keepdims, initial=-np.inf + ) + + +def e_step(data, weights, means, variances, g_norms, log_weights): + # Ensure data is a series of samples (2D array) + data = np.atleast_2d(data) + + n_gaussians = len(weights) + + # Allow the absence of previous statistics + statistics = GMMStats(n_gaussians, data.shape[-1]) + + # Log weighted Gaussian likelihoods [array of shape (n_gaussians,n_samples)] + z = np.empty_like(data, shape=(n_gaussians, len(data))) + for i in range(n_gaussians): + z[i] = np.sum((data - means[i]) ** 2 / variances[i], axis=-1) + ll = -0.5 * (g_norms[:, None] + z) + log_weighted_likelihoods = log_weights[:, None] + ll + + # Log likelihood [array of shape (n_samples,)] + if isinstance(log_weighted_likelihoods, np.ndarray): + log_likelihood = logaddexp_reduce(log_weighted_likelihoods) + else: + # Sum along gaussians axis (using logAddExp to prevent underflow) + log_likelihood = da.reduction( + x=log_weighted_likelihoods, + chunk=logaddexp_reduce, + aggregate=logaddexp_reduce, + axis=0, + dtype=float, + keepdims=False, + ) + + # Responsibility P [array of shape (n_gaussians, n_samples)] + responsibility = np.exp(log_weighted_likelihoods - log_likelihood[None, :]) + + # Accumulate + + # Total likelihood [float] + statistics.log_likelihood += log_likelihood.sum() + # Count of samples [int] + statistics.t += data.shape[0] + # Responsibilities [array of shape (n_gaussians,)] + statistics.n = statistics.n + responsibility.sum(axis=-1) + for i in range(n_gaussians): + # p * x [array of shape (n_gaussians, n_samples, n_features)] + px = responsibility[i, :, None] * data + # First order stats [array of shape (n_gaussians, n_features)] + statistics.sum_px[i] = statistics.sum_px[i] + np.sum(px, axis=0) + # Second order stats [array of shape (n_gaussians, n_features)] + statistics.sum_pxx[i] = statistics.sum_pxx[i] + np.sum( + px * data, axis=0 + ) + + # px = np.multiply(responsibility[:, :, None], data[None, :, :]) + # statistics.sum_px = statistics.sum_px + px.sum(axis=1) + # pxx = np.multiply(px[:, :, :], data[None, :, :]) + # statistics.sum_pxx = statistics.sum_pxx + pxx.sum(axis=1) + + return statistics + + +def m_step( + machine, + statistics, + update_means, + update_variances, + update_weights, + mean_var_update_threshold, + map_relevance_factor, + map_alpha, + trainer, +): + m_step_func = map_gmm_m_step if trainer == "map" else ml_gmm_m_step + statistics = 
functools.reduce(operator.iadd, statistics) + m_step_func( + machine, + statistics=statistics, + update_means=update_means, + update_variances=update_variances, + update_weights=update_weights, + mean_var_update_threshold=mean_var_update_threshold, + reynolds_adaptation=map_relevance_factor is not None, + alpha=map_alpha, + relevance_factor=map_relevance_factor, + ) + average_output = float(statistics.log_likelihood / statistics.t) + return average_output + + class GMMStats: """Stores accumulated statistics of a GMM. @@ -403,9 +503,7 @@ class GMMMachine(BaseEstimator): self._variances = np.maximum(self.variance_thresholds, variances) # Recompute g_norm for each gaussian [array of shape (n_gaussians,)] n_log_2pi = self._variances.shape[-1] * np.log(2 * np.pi) - self._g_norms = np.array( - n_log_2pi + np.log(self._variances).sum(axis=-1) - ) + self._g_norms = n_log_2pi + np.log(self._variances).sum(axis=-1) @property def variance_thresholds(self): @@ -580,9 +678,8 @@ class GMMMachine(BaseEstimator): ) = kmeans_machine.get_variances_and_weights_for_each_cluster(data) # Set the GMM machine's gaussians with the results of k-means - self.means = np.array(copy.deepcopy(kmeans_machine.centroids_)) - self.variances = np.array(copy.deepcopy(variances)) - self.weights = np.array(copy.deepcopy(weights)) + self.means = copy.deepcopy(kmeans_machine.centroids_) + self.variances, self.weights = dask.compute(variances, weights) def log_weighted_likelihood( self, @@ -735,6 +832,9 @@ class GMMMachine(BaseEstimator): def fit(self, X, y=None): """Trains the GMM on data until convergence or maximum step is reached.""" + + input_is_dask = check_and_persist_dask_input(X) + if self._means is None: self.initialize_gaussians(X) else: @@ -746,6 +846,19 @@ class GMMMachine(BaseEstimator): ) self.variances = np.ones_like(self.means) + m_step_func = functools.partial( + m_step, + update_means=self.update_means, + update_variances=self.update_variances, + update_weights=self.update_weights, + mean_var_update_threshold=self.mean_var_update_threshold, + map_relevance_factor=self.map_relevance_factor, + map_alpha=self.map_alpha, + trainer=self.trainer, + ) + + X = array_to_delayed_list(X, input_is_dask) + average_output = 0 logger.info("Training GMM...") step = 0 @@ -761,23 +874,37 @@ class GMMMachine(BaseEstimator): ) average_output_previous = average_output - stats = self.e_step(X) - self.m_step( - stats=stats, - ) - # if we're running in dask, persist weights, means, and variances so - # we don't recompute each step. 
- for attr in ["weights", "means", "variances"]: - arr = getattr(self, attr) - if isinstance(arr, da.Array): - setattr(self, attr, arr.persist()) + # compute the e-m steps + if input_is_dask: + stats = [ + dask.delayed(e_step)( + data=xx, + weights=self.weights, + means=self.means, + variances=self.variances, + g_norms=self.g_norms, + log_weights=self.log_weights, + ) + for xx in X + ] + average_output = dask.compute( + dask.delayed(m_step_func)(self, stats) + )[0] + else: + stats = [ + e_step( + data=X, + weights=self.weights, + means=self.means, + variances=self.variances, + g_norms=self.g_norms, + log_weights=self.log_weights, + ) + ] + average_output = m_step_func(self, stats) - # Note: Uses the stats from before m_step, leading to an additional m_step - # (which is not bad because it will always converge) - average_output = float(stats.log_likelihood / stats.t) logger.debug(f"log likelihood = {average_output}") - if step > 1: convergence_value = abs( (average_output_previous - average_output) @@ -794,6 +921,7 @@ class GMMMachine(BaseEstimator): "Reached convergence threshold. Training stopped." ) break + else: logger.info( "Reached maximum step. Training stopped without convergence." -- GitLab
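For context, below is a minimal usage sketch (not part of the patch) of the split e_step/m_step flow that this commit introduces, mirroring the non-dask branch of fit(). The function and attribute names are taken from the diff above; the GMMMachine constructor argument n_gaussians and the use of a plain NumPy array are assumptions about the public API, not guaranteed by this change.

import numpy as np

from bob.learn.em.gmm import GMMMachine, e_step, m_step

rng = np.random.default_rng(0)
data = rng.normal(size=(500, 3))

# Assumed constructor signature; fit() drives the e_step/m_step loop shown in the diff.
machine = GMMMachine(n_gaussians=2)
machine.fit(data)

# One extra manual pass of the two steps on the fitted machine.
stats = e_step(
    data=data,
    weights=machine.weights,
    means=machine.means,
    variances=machine.variances,
    g_norms=machine.g_norms,
    log_weights=machine.log_weights,
)
avg_log_likelihood = m_step(
    machine,
    [stats],  # m_step reduces a list of per-chunk statistics
    update_means=machine.update_means,
    update_variances=machine.update_variances,
    update_weights=machine.update_weights,
    mean_var_update_threshold=machine.mean_var_update_threshold,
    map_relevance_factor=machine.map_relevance_factor,
    map_alpha=machine.map_alpha,
    trainer=machine.trainer,
)
print(avg_log_likelihood)

With dask input, fit() instead wraps each data chunk in a delayed e_step task and reduces all resulting statistics in a single delayed m_step call, which is the main point of splitting the two steps into free functions.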