From 65d1a837a063316e78cbb9870c090ed1c08c463c Mon Sep 17 00:00:00 2001
From: Amir MOHAMMADI <amir.mohammadi@idiap.ch>
Date: Tue, 15 Mar 2022 14:53:42 +0100
Subject: [PATCH] [gmm] split e-m training into two clear steps
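
Extract the expectation and maximization steps into the module-level e_step
and m_step functions. With dask input, each chunk of data is processed by its
own delayed e_step task and the resulting statistics are merged in a single
m_step, replacing the previous approach of persisting weights, means and
variances at every iteration.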

---
 bob/learn/em/gmm.py | 170 ++++++++++++++++++++++++++++++++++++++------
 1 file changed, 149 insertions(+), 21 deletions(-)

diff --git a/bob/learn/em/gmm.py b/bob/learn/em/gmm.py
index ba0ccb6..436d169 100644
--- a/bob/learn/em/gmm.py
+++ b/bob/learn/em/gmm.py
@@ -5,23 +5,123 @@
 """This module provides classes and functions for the training and usage of GMM."""
 
 import copy
+import functools
 import logging
+import operator
 
 from typing import Union
 
+import dask
 import dask.array as da
 import numpy as np
 
 from h5py import File as HDF5File
 from sklearn.base import BaseEstimator
 
-from .k_means import KMeansMachine
+from .k_means import (
+    KMeansMachine,
+    array_to_delayed_list,
+    check_and_persist_dask_input,
+)
 
 logger = logging.getLogger(__name__)
 
 EPSILON = np.finfo(float).eps
 
 
+def logaddexp_reduce(array, axis=0, keepdims=False):
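+    """Computes log(sum(exp(array))) along ``axis`` without underflow."""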
+    return np.logaddexp.reduce(
+        array, axis=axis, keepdims=keepdims, initial=-np.inf
+    )
+
+
+def e_step(data, weights, means, variances, g_norms, log_weights):
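+    """Expectation step of the e-m algorithm.
+
+    Computes and returns the sufficient statistics (``GMMStats``) of ``data``
+    given the current GMM parameters.
+    """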
+    # Ensure data is a series of samples (2D array)
+    data = np.atleast_2d(data)
+
+    n_gaussians = len(weights)
+
+    # Statistics are accumulated into a fresh GMMStats object (no previous statistics required)
+    statistics = GMMStats(n_gaussians, data.shape[-1])
+
+    # Log weighted Gaussian likelihoods [array of shape (n_gaussians, n_samples)]
+    z = np.empty_like(data, shape=(n_gaussians, len(data)))
+    for i in range(n_gaussians):
+        z[i] = np.sum((data - means[i]) ** 2 / variances[i], axis=-1)
+    ll = -0.5 * (g_norms[:, None] + z)
+    log_weighted_likelihoods = log_weights[:, None] + ll
+
+    # Log likelihood [array of shape (n_samples,)]
+    if isinstance(log_weighted_likelihoods, np.ndarray):
+        log_likelihood = logaddexp_reduce(log_weighted_likelihoods)
+    else:
+        # Sum along gaussians axis (using logAddExp to prevent underflow)
+        log_likelihood = da.reduction(
+            x=log_weighted_likelihoods,
+            chunk=logaddexp_reduce,
+            aggregate=logaddexp_reduce,
+            axis=0,
+            dtype=float,
+            keepdims=False,
+        )
+
+    # Responsibility P [array of shape (n_gaussians, n_samples)]
+    responsibility = np.exp(log_weighted_likelihoods - log_likelihood[None, :])
+
+    # Accumulate
+
+    # Total log likelihood [float]
+    statistics.log_likelihood += log_likelihood.sum()
+    # Count of samples [int]
+    statistics.t += data.shape[0]
+    # Responsibilities [array of shape (n_gaussians,)]
+    statistics.n = statistics.n + responsibility.sum(axis=-1)
+    for i in range(n_gaussians):
+        # p * x [array of shape (n_samples, n_features)]
+        px = responsibility[i, :, None] * data
+        # First order stats [array of shape (n_gaussians, n_features)]
+        statistics.sum_px[i] = statistics.sum_px[i] + np.sum(px, axis=0)
+        # Second order stats [array of shape (n_gaussians, n_features)]
+        statistics.sum_pxx[i] = statistics.sum_pxx[i] + np.sum(
+            px * data, axis=0
+        )
+
+    # Vectorized equivalent of the loop above (commented out):
+    # px = np.multiply(responsibility[:, :, None], data[None, :, :])
+    # statistics.sum_px = statistics.sum_px + px.sum(axis=1)
+    # pxx = np.multiply(px[:, :, :], data[None, :, :])
+    # statistics.sum_pxx = statistics.sum_pxx + pxx.sum(axis=1)
+
+    return statistics
+
+
+def m_step(
+    machine,
+    statistics,
+    update_means,
+    update_variances,
+    update_weights,
+    mean_var_update_threshold,
+    map_relevance_factor,
+    map_alpha,
+    trainer,
+):
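+    """Maximization step of the e-m algorithm.
+
+    Merges the per-chunk statistics, updates ``machine`` in place (ML or MAP),
+    and returns the average log likelihood per sample.
+    """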
+    m_step_func = map_gmm_m_step if trainer == "map" else ml_gmm_m_step
+    statistics = functools.reduce(operator.iadd, statistics)
+    m_step_func(
+        machine,
+        statistics=statistics,
+        update_means=update_means,
+        update_variances=update_variances,
+        update_weights=update_weights,
+        mean_var_update_threshold=mean_var_update_threshold,
+        reynolds_adaptation=map_relevance_factor is not None,
+        alpha=map_alpha,
+        relevance_factor=map_relevance_factor,
+    )
+    average_output = float(statistics.log_likelihood / statistics.t)
+    return average_output
+
+
 class GMMStats:
     """Stores accumulated statistics of a GMM.
 
@@ -403,9 +503,7 @@ class GMMMachine(BaseEstimator):
         self._variances = np.maximum(self.variance_thresholds, variances)
         # Recompute g_norm for each gaussian [array of shape (n_gaussians,)]
         n_log_2pi = self._variances.shape[-1] * np.log(2 * np.pi)
-        self._g_norms = np.array(
-            n_log_2pi + np.log(self._variances).sum(axis=-1)
-        )
+        self._g_norms = n_log_2pi + np.log(self._variances).sum(axis=-1)
 
     @property
     def variance_thresholds(self):
@@ -580,9 +678,8 @@ class GMMMachine(BaseEstimator):
             ) = kmeans_machine.get_variances_and_weights_for_each_cluster(data)
 
             # Set the GMM machine's gaussians with the results of k-means
-            self.means = np.array(copy.deepcopy(kmeans_machine.centroids_))
-            self.variances = np.array(copy.deepcopy(variances))
-            self.weights = np.array(copy.deepcopy(weights))
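+            # Compute variances and weights eagerly (they may be lazy dask results from k-means)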
+            self.means = copy.deepcopy(kmeans_machine.centroids_)
+            self.variances, self.weights = dask.compute(variances, weights)
 
     def log_weighted_likelihood(
         self,
@@ -735,6 +832,9 @@ class GMMMachine(BaseEstimator):
 
     def fit(self, X, y=None):
         """Trains the GMM on data until convergence or maximum step is reached."""
+
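+        # Check for dask input; dask arrays are persisted so chunks are not recomputed at each step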
+        input_is_dask = check_and_persist_dask_input(X)
+
         if self._means is None:
             self.initialize_gaussians(X)
         else:
@@ -746,6 +846,19 @@ class GMMMachine(BaseEstimator):
             )
             self.variances = np.ones_like(self.means)
 
+        m_step_func = functools.partial(
+            m_step,
+            update_means=self.update_means,
+            update_variances=self.update_variances,
+            update_weights=self.update_weights,
+            mean_var_update_threshold=self.mean_var_update_threshold,
+            map_relevance_factor=self.map_relevance_factor,
+            map_alpha=self.map_alpha,
+            trainer=self.trainer,
+        )
+
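+        # With dask input, turn X into a list of delayed chunks, one per e_step task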
+        X = array_to_delayed_list(X, input_is_dask)
+
         average_output = 0
         logger.info("Training GMM...")
         step = 0
@@ -761,23 +874,37 @@ class GMMMachine(BaseEstimator):
             )
 
             average_output_previous = average_output
-            stats = self.e_step(X)
-            self.m_step(
-                stats=stats,
-            )
 
-            # if we're running in dask, persist weights, means, and variances so
-            # we don't recompute each step.
-            for attr in ["weights", "means", "variances"]:
-                arr = getattr(self, attr)
-                if isinstance(arr, da.Array):
-                    setattr(self, attr, arr.persist())
+            # Compute the e-m steps; with dask input, one delayed e_step runs per chunk
+            if input_is_dask:
+                stats = [
+                    dask.delayed(e_step)(
+                        data=xx,
+                        weights=self.weights,
+                        means=self.means,
+                        variances=self.variances,
+                        g_norms=self.g_norms,
+                        log_weights=self.log_weights,
+                    )
+                    for xx in X
+                ]
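+                # Merge the per-chunk statistics and run a single delayed m_step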
+                average_output = dask.compute(
+                    dask.delayed(m_step_func)(self, stats)
+                )[0]
+            else:
+                stats = [
+                    e_step(
+                        data=X,
+                        weights=self.weights,
+                        means=self.means,
+                        variances=self.variances,
+                        g_norms=self.g_norms,
+                        log_weights=self.log_weights,
+                    )
+                ]
+                average_output = m_step_func(self, stats)
 
-            # Note: Uses the stats from before m_step, leading to an additional m_step
-            # (which is not bad because it will always converge)
-            average_output = float(stats.log_likelihood / stats.t)
             logger.debug(f"log likelihood = {average_output}")
-
             if step > 1:
                 convergence_value = abs(
                     (average_output_previous - average_output)
@@ -794,6 +921,7 @@ class GMMMachine(BaseEstimator):
                         "Reached convergence threshold. Training stopped."
                     )
                     break
+
         else:
             logger.info(
                 "Reached maximum step. Training stopped without convergence."
-- 
GitLab