diff --git a/bob/learn/em/gmm.py b/bob/learn/em/gmm.py index ce742f8225228258461e7b0b3fef5a285855fac1..c2ebe217fc5884a47e5be9c4b837b4f2ef12d411 100644 --- a/bob/learn/em/gmm.py +++ b/bob/learn/em/gmm.py @@ -212,7 +212,9 @@ class GMMStats: return (self.n_gaussians, self.n_features) def compute(self): - for name in ("log_likelihood", "n", "sum_px", "sum_pxx"): + for name in ("log_likelihood", "t"): + setattr(self, name, float(getattr(self, name))) + for name in ("n", "sum_px", "sum_pxx"): setattr(self, name, np.asarray(getattr(self, name))) @@ -272,9 +274,8 @@ class GMMMachine(BaseEstimator): update_variances: bool = False, update_weights: bool = False, mean_var_update_threshold: float = EPSILON, - alpha: float = 0.5, - relevance_factor: Union[None, float] = 4, - variance_thresholds: float = EPSILON, + map_alpha: float = 0.5, + map_relevance_factor: Union[None, float] = 4, ): """ Parameters @@ -304,17 +305,14 @@ class GMMMachine(BaseEstimator): Update the Gaussians variances at every m step. update_weights Update the GMM weights at every m step. - mean_var_update_threshold: + mean_var_update_threshold Threshold value used when updating the means and variances. - alpha: + map_alpha Ratio for MAP adaptation. Used when `trainer == "map"` and `relevance_factor is None`) - relevance_factor: + map_relevance_factor Factor for the computation of alpha with Reynolds adaptation. (Used when `trainer == "map"`) - variance_thresholds: - The variance flooring thresholds, i.e. the minimum allowed value of variance in each dimension. - The variance will be set to this value if an attempt is made to set it to a smaller value. """ self.n_gaussians = n_gaussians @@ -361,8 +359,8 @@ class GMMMachine(BaseEstimator): ) if weights is not None: self.weights = weights - self.alpha = alpha - self.relevance_factor = relevance_factor + self.map_alpha = map_alpha + self.map_relevance_factor = map_relevance_factor @property def weights(self): @@ -729,9 +727,9 @@ class GMMMachine(BaseEstimator): update_variances=self.update_variances, update_weights=self.update_weights, mean_var_update_threshold=self.mean_var_update_threshold, - reynolds_adaptation=self.relevance_factor is not None, - alpha=self.alpha, - relevance_factor=self.relevance_factor, + reynolds_adaptation=self.map_relevance_factor is not None, + alpha=self.map_alpha, + relevance_factor=self.map_relevance_factor, **kwargs, ) @@ -839,14 +837,14 @@ def ml_gmm_m_step( """Updates a gmm machine parameter according to the e-step statistics.""" logger.debug("ML GMM Trainer m-step") + # Threshold the low n to prevent divide by zero + thresholded_n = np.clip(statistics.n, mean_var_update_threshold, None) + # Update weights if requested # (Equation 9.26 of Bishop, "Pattern recognition and machine learning", 2006) if update_weights: logger.debug("Update weights.") - machine.weights = statistics.n / statistics.t - - # Threshold the low n to prevent divide by zero - thresholded_n = np.clip(statistics.n, mean_var_update_threshold, None) + machine.weights = thresholded_n / statistics.t # Update GMM parameters using the sufficient statistics (m_ss): diff --git a/bob/learn/em/k_means.py b/bob/learn/em/k_means.py index 830b4e4da6cb613a7b8b846281e6558b72402100..3b73a610fd1beee0fbbd2f1a6844818d4965caa7 100644 --- a/bob/learn/em/k_means.py +++ b/bob/learn/em/k_means.py @@ -42,7 +42,8 @@ class KMeansMachine(BaseEstimator): convergence_threshold: float = 1e-5, max_iter: int = 20, random_state: Union[int, np.random.RandomState] = 0, - init_max_iter: Union[int, None] = None, + init_max_iter: Union[int, None] = 5, + oversampling_factor: float = 2, ) -> None: """ Parameters @@ -68,6 +69,7 @@ class KMeansMachine(BaseEstimator): self.max_iter = max_iter self.random_state = random_state self.init_max_iter = init_max_iter + self.oversampling_factor = oversampling_factor self.average_min_distance = np.inf self.zeroeth_order_statistics = None self.first_order_statistics = None @@ -193,6 +195,7 @@ class KMeansMachine(BaseEstimator): init=self.init_method, random_state=self.random_state, max_iter=self.init_max_iter, + oversampling_factor=self.oversampling_factor, ) def e_step(self, data: np.ndarray): @@ -226,7 +229,7 @@ class KMeansMachine(BaseEstimator): step += 1 logger.info( f"Iteration {step:3d}" - + (f"/{self.max_iter}" if self.max_iter else "") + + (f"/{self.max_iter:3d}" if self.max_iter else "") ) distance_previous = distance self.e_step(data=X) @@ -241,7 +244,7 @@ class KMeansMachine(BaseEstimator): distance = float(self.average_min_distance) - logger.info( + logger.debug( f"Average minimal squared Euclidean distance = {distance}" ) @@ -249,7 +252,7 @@ class KMeansMachine(BaseEstimator): convergence_value = abs( (distance_previous - distance) / distance_previous ) - logger.info(f"Convergence value = {convergence_value}") + logger.debug(f"Convergence value = {convergence_value}") # Terminates if converged (and threshold is set) if ( diff --git a/bob/learn/em/test/test_gmm.py b/bob/learn/em/test/test_gmm.py index 51a8bed00cc29adb3278be5f9dbdb2ff26fdd8a6..7eda1c63df3a7827ce7631eb6a0c78ea0971f5f7 100644 --- a/bob/learn/em/test/test_gmm.py +++ b/bob/learn/em/test/test_gmm.py @@ -202,6 +202,7 @@ def test_GMMMachine(): with tempfile.NamedTemporaryFile(suffix=".hdf5") as f: filename = f.name gmm.save(HDF5File(filename, "w")) + # Using from_hdf5 gmm1 = GMMMachine.from_hdf5(HDF5File(filename, "r")) assert type(gmm1.n_gaussians) is np.int64 assert type(gmm1.update_means) is np.bool_ @@ -210,6 +211,17 @@ def test_GMMMachine(): assert type(gmm1.trainer) is str assert gmm1.ubm is None assert gmm == gmm1 + # Using load + gmm1 = GMMMachine(n_gaussians=gmm.n_gaussians) + gmm1.load(HDF5File(filename, "r")) + assert type(gmm1.n_gaussians) is np.int64 + assert type(gmm1.update_means) is np.bool_ + assert type(gmm1.update_variances) is np.bool_ + assert type(gmm1.update_weights) is np.bool_ + assert type(gmm1.trainer) is str + assert gmm1.ubm is None + assert gmm == gmm1 + with tempfile.NamedTemporaryFile(suffix=".hdf5") as f: filename = f.name gmm.save(filename) @@ -923,7 +935,7 @@ def test_gmm_MAP_3(): update_variances=False, update_weights=False, mean_var_update_threshold=accuracy, - relevance_factor=None, + map_relevance_factor=None, ) gmm.variance_thresholds = threshold @@ -1071,7 +1083,7 @@ def test_gmm_MAP_dask(): update_variances=False, update_weights=False, mean_var_update_threshold=accuracy, - relevance_factor=None, + map_relevance_factor=None, ) gmm.variance_thresholds = threshold