From 3d5a0df6cd52d7debbbb23729d3461ec2b1fdec5 Mon Sep 17 00:00:00 2001
From: Amir MOHAMMADI <amir.mohammadi@idiap.ch>
Date: Mon, 29 Nov 2021 15:29:49 +0100
Subject: [PATCH] persist the state of GMM between training iterations

Report the progress of training using computed values. Replace
variance_thresholds with a single number instead of an array, although
arrays are still supported. Do not initialize means, variances, and
weights when another attribute is being set.
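
A minimal sketch of the persist pattern applied in GMMMachine.fit (toy
shapes, illustrative only; requires dask):

    import dask.array as da
    import numpy as np

    weights = da.from_array(np.full(3, 1 / 3), chunks=3)
    if isinstance(weights, da.Array):
        # Materialize the current value once so the next e_step/m_step
        # reuses it instead of re-running the whole task graph.
        weights = weights.persist()

Scalar and array variance thresholds both remain valid because
np.maximum broadcasts a scalar threshold against the variances.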

---
 bob/learn/em/cluster/k_means.py |   4 +-
 bob/learn/em/mixture/gmm.py     | 110 +++++++++++++++++---------------
 2 files changed, 61 insertions(+), 53 deletions(-)

diff --git a/bob/learn/em/cluster/k_means.py b/bob/learn/em/cluster/k_means.py
index 98418b4..706413e 100644
--- a/bob/learn/em/cluster/k_means.py
+++ b/bob/learn/em/cluster/k_means.py
@@ -170,9 +170,9 @@ class KMeansMachine(BaseEstimator):
             trainer.e_step(machine=self, data=X)
             trainer.m_step(machine=self, data=X)
 
-            distance = trainer.compute_likelihood(self)
+            distance = float(trainer.compute_likelihood(self))
 
-            # logger.info(f"Average squared Euclidean distance = {distance.compute()}")
+            logger.info(f"Average squared Euclidean distance = {distance}")
 
             if step > 0:
                 convergence_value = abs(
diff --git a/bob/learn/em/mixture/gmm.py b/bob/learn/em/mixture/gmm.py
index 5c8c0c0..d17c7ae 100644
--- a/bob/learn/em/mixture/gmm.py
+++ b/bob/learn/em/mixture/gmm.py
@@ -192,7 +192,7 @@ class GMMStats:
 
     def compute(self):
         for name in ("log_likelihood", "n", "sum_px", "sum_pxx"):
-            setattr(self, name, np.array(getattr(self, name)))
+            setattr(self, name, np.asarray(getattr(self, name)))
 
 
 class GMMMachine(BaseEstimator):
@@ -281,6 +281,7 @@ class GMMMachine(BaseEstimator):
         mean_var_update_threshold: float = EPSILON,
         alpha: float = 0.5,
         relevance_factor: Union[None, float] = 4,
+        variance_thresholds: float = EPSILON,
     ):
         """
         Parameters
@@ -318,6 +319,9 @@ class GMMMachine(BaseEstimator):
         relevance_factor:
             Factor for the computation of alpha with Reynolds adaptation. (Used when
             `trainer == "map"`)
+        variance_thresholds:
+            The variance flooring thresholds, i.e. the minimum allowed value
+            of variance in each dimension. The variance will be set to this
+            value if an attempt is made to set it to a smaller value.
         """
 
         self.n_gaussians = n_gaussians
@@ -342,7 +346,7 @@ class GMMMachine(BaseEstimator):
         self.mean_var_update_threshold = mean_var_update_threshold
         self._means = None
         self._variances = None
-        self._variance_thresholds = None
+            self._variance_thresholds = variance_thresholds
         self._g_norms = None
 
         if self.ubm is not None:
@@ -357,7 +361,7 @@ class GMMMachine(BaseEstimator):
         if weights is not None:
             self.weights = weights
         self.alpha = alpha
-        self.relevance_factor = relevance_factor 
+        self.relevance_factor = relevance_factor
 
     @property
     def weights(self):
@@ -366,7 +370,7 @@ class GMMMachine(BaseEstimator):
 
     @weights.setter
     def weights(self, weights: "np.ndarray[('n_gaussians',), float]"):
-        self._weights = np.array(weights)
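+        # Keep weights as-is: casting with np.array() would eagerly
+        # materialize dask arrays; compute() handles the final conversion.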
+        self._weights = weights
         self._log_weights = np.log(self._weights)
 
     @property
@@ -378,11 +382,6 @@ class GMMMachine(BaseEstimator):
 
     @means.setter
     def means(self, means: "np.ndarray[('n_gaussians', 'n_features'), float]"):
-        if self._means is None:
-            if self._variances is None:
-                self._variances = np.ones_like(means, dtype=float)
-            if self._variance_thresholds is None:
-                self._variance_thresholds = np.full_like(means, fill_value=EPSILON, dtype=float)
         self._means = means
 
     @property
@@ -394,11 +393,6 @@ class GMMMachine(BaseEstimator):
 
     @variances.setter
     def variances(self, variances: "np.ndarray[('n_gaussians', 'n_features'), float]"):
-        if self._variances is None:
-            if self._means is None:
-                self._means = np.zeros_like(variances, dtype=float)
-            if self._variance_thresholds is None:
-                self._variance_thresholds = np.full_like(variances, fill_value=EPSILON, dtype=float)
         self._variances = np.maximum(self.variance_thresholds, variances)
         # Recompute g_norm for each gaussian [array of shape (n_gaussians,)]
         n_log_2pi = self.variances.shape[-1] * np.log(2 * np.pi)
@@ -416,10 +410,6 @@ class GMMMachine(BaseEstimator):
         self,
         threshold: "Union[float, np.ndarray[('n_gaussians', 'n_features'), float]]",
     ):
-        if not hasattr(threshold, "ndim"):
-            threshold = np.full_like(self.means, fill_value=threshold, dtype=float)
-        elif threshold.ndim == 1:
-            threshold = threshold[None,:].repeat(self.n_gaussians, axis=0)
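+        # A scalar threshold is kept as-is; np.maximum broadcasts it against
+        # self.variances, so per-dimension threshold arrays keep working too.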
         self._variance_thresholds = threshold
         self.variances = np.maximum(threshold, self.variances)
 
@@ -432,7 +422,6 @@ class GMMMachine(BaseEstimator):
             self._g_norms = n_log_2pi + np.log(self._variances).sum(axis=-1)
         return self._g_norms
 
-
     @property
     def log_weights(self):
         """Retrieve the logarithm of the weights."""
@@ -486,12 +475,16 @@ class GMMMachine(BaseEstimator):
                 gaussian_group = hdf5[f"m_gaussians{i}"]
                 g_means.append(gaussian_group["m_mean"][()])
                 g_variances.append(gaussian_group["m_variance"][()])
-                g_variance_thresholds.append(gaussian_group["m_variance_thresholds"][()])
+                g_variance_thresholds.append(
+                    gaussian_group["m_variance_thresholds"][()]
+                )
             weights = hdf5["m_weights"][()].reshape(n_gaussians)
             self = cls(n_gaussians=n_gaussians, ubm=ubm, weights=weights)
-            self.means = np.array(g_means).reshape(n_gaussians,-1)
-            self.variances = np.array(g_variances).reshape(n_gaussians,-1)
-            self.variance_thresholds = np.array(g_variance_thresholds).reshape(n_gaussians,-1)
+            self.means = np.array(g_means).reshape(n_gaussians, -1)
+            self.variances = np.array(g_variances).reshape(n_gaussians, -1)
+            self.variance_thresholds = np.array(g_variance_thresholds).reshape(
+                n_gaussians, -1
+            )
         return self
 
     def save(self, hdf5):
@@ -526,7 +519,12 @@ class GMMMachine(BaseEstimator):
         return (
             np.allclose(self.means, other.means, rtol=rtol, atol=atol)
             and np.allclose(self.variances, other.variances, rtol=rtol, atol=atol)
-            and np.allclose(self.variance_thresholds, other.variance_thresholds, rtol=rtol, atol=atol)
+            and np.allclose(
+                self.variance_thresholds,
+                other.variance_thresholds,
+                rtol=rtol,
+                atol=atol,
+            )
             and np.allclose(self.weights, other.weights, rtol=rtol, atol=atol)
         )
 
@@ -536,34 +534,30 @@ class GMMMachine(BaseEstimator):
         """Populates gaussians parameters with either k-means or the UBM values."""
         if self.trainer == "map":
             self.means = copy.deepcopy(self.ubm.means)
-            self.variances_ = copy.deepcopy(self.ubm.variances)
+            self.variances = copy.deepcopy(self.ubm.variances)
             self.variance_thresholds = copy.deepcopy(self.ubm.variance_thresholds)
             self.weights = copy.deepcopy(self.ubm.weights)
-            self._g_norms = copy.deepcopy(self.ubm.g_norms)
         else:
-            if self._means is None:
-                logger.debug("GMM means was never set. Initializing with k-means.")
-                if data is None:
-                    raise ValueError("Data is required when training with k-means.")
-                logger.info("Initializing GMM with k-means.")
-                kmeans_trainer = self.k_means_trainer or KMeansTrainer(
-                    random_state=self.random_state,
-                )
-                kmeans_machine = KMeansMachine(self.n_gaussians).fit(
-                    data, trainer=kmeans_trainer
-                )
+            logger.debug("GMM means was never set. Initializing with k-means.")
+            if data is None:
+                raise ValueError("Data is required when training with k-means.")
+            logger.info("Initializing GMM with k-means.")
+            kmeans_trainer = self.k_means_trainer or KMeansTrainer(
+                random_state=self.random_state,
+            )
+            kmeans_machine = KMeansMachine(self.n_gaussians).fit(
+                data, trainer=kmeans_trainer
+            )
 
-                (
-                    variances,
-                    weights,
-                ) = kmeans_machine.get_variances_and_weights_for_each_cluster(data)
+            (
+                variances,
+                weights,
+            ) = kmeans_machine.get_variances_and_weights_for_each_cluster(data)
 
-                # Set the GMM machine's gaussians with the results of k-means
-                self.means = np.array(copy.deepcopy(kmeans_machine.centroids_))
-                self.variances = np.array(copy.deepcopy(variances))
-                self.weights = np.array(copy.deepcopy(weights))
-            else:
-                logger.debug("Initialize GMMMachine: GMM means already set.")
+            # Set the GMM machine's gaussians with the results of k-means
+            self.means = np.array(copy.deepcopy(kmeans_machine.centroids_))
+            self.variances = np.array(copy.deepcopy(variances))
+            self.weights = np.array(copy.deepcopy(weights))
 
     def log_weighted_likelihood(
         self, data: "np.ndarray[('n_samples', 'n_features'), float]"
@@ -581,7 +575,10 @@ class GMMMachine(BaseEstimator):
             The weighted log likelihood of each sample of each Gaussian.
         """
         # Compute the likelihood for each data point on each Gaussian
-        z = ((data[None, ..., :] - self.means[..., None, :]) ** 2 / self.variances[..., None, :]).sum(axis=-1)
+        z = (
+            (data[None, ..., :] - self.means[..., None, :]) ** 2
+            / self.variances[..., None, :]
+        ).sum(axis=-1)
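+        # z: squared deviation of each sample from each Gaussian mean, scaled
+        # by the per-dimension variances; shape (n_gaussians, n_samples).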
         l = -0.5 * (self.g_norms[:, None] + z)
         log_weighted_likelihood = self.log_weights[:, None] + l
         return log_weighted_likelihood
@@ -609,6 +606,7 @@ class GMMMachine(BaseEstimator):
             return np.logaddexp.reduce(
                 array, axis=axis, keepdims=keepdims, initial=-np.inf
             )
+
         if isinstance(log_weighted_likelihood, np.ndarray):
             ll_reduced = logaddexp_reduce(log_weighted_likelihood)
         else:
@@ -705,6 +703,8 @@ class GMMMachine(BaseEstimator):
         """Trains the GMM on data until convergence or maximum step is reached."""
         if self._means is None:
             self.initialize_gaussians(X)
+        else:
+            logger.debug("GMM means already set. Initialization was not run!")
 
         average_output = 0
         logger.info("Training GMM...")
@@ -726,10 +726,17 @@ class GMMMachine(BaseEstimator):
                 stats=stats,
             )
 
+            # if we're running in dask, persist weights, means, and variances so
+            # we don't recompute each step.
+            for attr in ["weights", "means", "variances"]:
+                arr = getattr(self, attr)
+                if isinstance(arr, da.Array):
+                    setattr(self, attr, arr.persist())
+
             # Note: Uses the stats from before m_step, leading to an additional m_step
             # (which is not bad because it will always converge)
-            average_output = stats.log_likelihood / stats.t
-            logger.debug(f"average output = {average_output}")
+            average_output = float(stats.log_likelihood / stats.t)
+            logger.debug(f"log likelihood = {average_output}")
 
             if step > 1:
                 convergence_value = abs(
@@ -768,7 +775,8 @@ class GMMMachine(BaseEstimator):
         }
 
     def compute(self, *args, **kwargs):
-        setattr(self, "weights", np.array(getattr(self, "weights")))
+        for name in ("weights", "means", "variances"):
+            setattr(self, name, np.asarray(getattr(self, name)))
 
 
 def ml_gmm_m_step(
-- 
GitLab