Commit 3d5a0df6 authored by Amir MOHAMMADI

Persist the state of the GMM between each training iteration.

Report the progress of training using computed values.
Replace variance_thresholds with a scalar default instead of an array,
although arrays are still supported.
Do not initialize means, variances, and weights when another
attribute is being set.
parent b4921b3d
Merge requests: !42 GMM implementation in Python, !40 Transition to a pure python implementation
Pipeline #56623 failed
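The headline change applies a standard dask idiom: persist intermediate results so that each EM iteration reuses the already-computed chunks instead of replaying the whole task graph from the start. A minimal sketch of the pattern, using a hypothetical `params` dict in place of the machine's attributes:

    import dask.array as da
    import numpy as np

    # Hypothetical stand-in for the GMM parameters updated on every EM step.
    params = {"weights": da.from_array(np.full(4, 0.25), chunks=2)}

    for step in range(3):
        # ... e_step / m_step would lazily update the arrays here ...
        for name, arr in params.items():
            if isinstance(arr, da.Array):
                # persist() computes the chunks and keeps them in memory, so the
                # next iteration does not re-execute the accumulated task graph.
                params[name] = arr.persist()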
@@ -170,9 +170,9 @@ class KMeansMachine(BaseEstimator):
             trainer.e_step(machine=self, data=X)
             trainer.m_step(machine=self, data=X)
-            distance = trainer.compute_likelihood(self)
-            # logger.info(f"Average squared Euclidean distance = {distance.compute()}")
+            distance = float(trainer.compute_likelihood(self))
+            logger.info(f"Average squared Euclidean distance = {distance}")
             if step > 0:
                 convergence_value = abs(
...
@@ -192,7 +192,7 @@ class GMMStats:
     def compute(self):
         for name in ("log_likelihood", "n", "sum_px", "sum_pxx"):
-            setattr(self, name, np.array(getattr(self, name)))
+            setattr(self, name, np.asarray(getattr(self, name)))


 class GMMMachine(BaseEstimator):
@@ -281,6 +281,7 @@ class GMMMachine(BaseEstimator):
         mean_var_update_threshold: float = EPSILON,
         alpha: float = 0.5,
         relevance_factor: Union[None, float] = 4,
+        variance_thresholds: float = EPSILON,
     ):
         """
         Parameters
@@ -318,6 +319,9 @@ class GMMMachine(BaseEstimator):
         relevance_factor:
             Factor for the computation of alpha with Reynolds adaptation. (Used when
             `trainer == "map"`)
+        variance_thresholds:
+            The variance flooring thresholds, i.e. the minimum allowed value of variance in each dimension.
+            The variance will be set to this value if an attempt is made to set it to a smaller value.
         """
         self.n_gaussians = n_gaussians
@@ -342,7 +346,7 @@ class GMMMachine(BaseEstimator):
         self.mean_var_update_threshold = mean_var_update_threshold
         self._means = None
         self._variances = None
-        self._variance_thresholds = None
+        self._variance_thresholds = mean_var_update_threshold
         self._g_norms = None
         if self.ubm is not None:
@@ -357,7 +361,7 @@ class GMMMachine(BaseEstimator):
         if weights is not None:
             self.weights = weights
         self.alpha = alpha
         self.relevance_factor = relevance_factor

     @property
     def weights(self):
@@ -366,7 +370,7 @@ class GMMMachine(BaseEstimator):
     @weights.setter
     def weights(self, weights: "np.ndarray[('n_gaussians',), float]"):
-        self._weights = np.array(weights)
+        self._weights = weights
         self._log_weights = np.log(self._weights)

     @property
@@ -378,11 +382,6 @@ class GMMMachine(BaseEstimator):
     @means.setter
     def means(self, means: "np.ndarray[('n_gaussians', 'n_features'), float]"):
-        if self._means is None:
-            if self._variances is None:
-                self._variances = np.ones_like(means, dtype=float)
-            if self._variance_thresholds is None:
-                self._variance_thresholds = np.full_like(means, fill_value=EPSILON, dtype=float)
         self._means = means

     @property
@@ -394,11 +393,6 @@ class GMMMachine(BaseEstimator):
     @variances.setter
     def variances(self, variances: "np.ndarray[('n_gaussians', 'n_features'), float]"):
-        if self._variances is None:
-            if self._means is None:
-                self._means = np.zeros_like(variances, dtype=float)
-            if self._variance_thresholds is None:
-                self._variance_thresholds = np.full_like(variances, fill_value=EPSILON, dtype=float)
         self._variances = np.maximum(self.variance_thresholds, variances)
         # Recompute g_norm for each gaussian [array of shape (n_gaussians,)]
         n_log_2pi = self.variances.shape[-1] * np.log(2 * np.pi)
@@ -416,10 +410,6 @@ class GMMMachine(BaseEstimator):
         self,
         threshold: "Union[float, np.ndarray[('n_gaussians', 'n_features'), float]]",
     ):
-        if not hasattr(threshold, "ndim"):
-            threshold = np.full_like(self.means, fill_value=threshold, dtype=float)
-        elif threshold.ndim == 1:
-            threshold = threshold[None, :].repeat(self.n_gaussians, axis=0)
         self._variance_thresholds = threshold
         self.variances = np.maximum(threshold, self.variances)
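Dropping the expansion logic above is safe because `np.maximum` already broadcasts a scalar, or a per-dimension `(n_features,)` array, against the `(n_gaussians, n_features)` variances. For example:

    import numpy as np

    variances = np.array([[1e-9, 2.0], [0.5, 3.0]])  # (n_gaussians, n_features)

    # A scalar threshold broadcasts against the whole array ...
    print(np.maximum(1e-5, variances))  # floors the 1e-9 entry to 1e-5

    # ... and a per-dimension (n_features,) threshold broadcasts row-wise.
    print(np.maximum(np.array([1e-5, 1.0]), variances))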
@@ -432,7 +422,6 @@ class GMMMachine(BaseEstimator):
         self._g_norms = n_log_2pi + np.log(self._variances).sum(axis=-1)
         return self._g_norms
-
     @property
     def log_weights(self):
         """Retrieve the logarithm of the weights."""
@@ -486,12 +475,16 @@ class GMMMachine(BaseEstimator):
             gaussian_group = hdf5[f"m_gaussians{i}"]
             g_means.append(gaussian_group["m_mean"][()])
             g_variances.append(gaussian_group["m_variance"][()])
-            g_variance_thresholds.append(gaussian_group["m_variance_thresholds"][()])
+            g_variance_thresholds.append(
+                gaussian_group["m_variance_thresholds"][()]
+            )
         weights = hdf5["m_weights"][()].reshape(n_gaussians)
         self = cls(n_gaussians=n_gaussians, ubm=ubm, weights=weights)
-        self.means = np.array(g_means).reshape(n_gaussians,-1)
-        self.variances = np.array(g_variances).reshape(n_gaussians,-1)
-        self.variance_thresholds = np.array(g_variance_thresholds).reshape(n_gaussians,-1)
+        self.means = np.array(g_means).reshape(n_gaussians, -1)
+        self.variances = np.array(g_variances).reshape(n_gaussians, -1)
+        self.variance_thresholds = np.array(g_variance_thresholds).reshape(
+            n_gaussians, -1
+        )
         return self

     def save(self, hdf5):
@@ -526,7 +519,12 @@ class GMMMachine(BaseEstimator):
         return (
             np.allclose(self.means, other.means, rtol=rtol, atol=atol)
             and np.allclose(self.variances, other.variances, rtol=rtol, atol=atol)
-            and np.allclose(self.variance_thresholds, other.variance_thresholds, rtol=rtol, atol=atol)
+            and np.allclose(
+                self.variance_thresholds,
+                other.variance_thresholds,
+                rtol=rtol,
+                atol=atol,
+            )
             and np.allclose(self.weights, other.weights, rtol=rtol, atol=atol)
         )
@@ -536,34 +534,30 @@ class GMMMachine(BaseEstimator):
         """Populates gaussians parameters with either k-means or the UBM values."""
         if self.trainer == "map":
             self.means = copy.deepcopy(self.ubm.means)
-            self.variances_ = copy.deepcopy(self.ubm.variances)
+            self.variances = copy.deepcopy(self.ubm.variances)
             self.variance_thresholds = copy.deepcopy(self.ubm.variance_thresholds)
             self.weights = copy.deepcopy(self.ubm.weights)
-            self._g_norms = copy.deepcopy(self.ubm.g_norms)
         else:
-            if self._means is None:
-                logger.debug("GMM means was never set. Initializing with k-means.")
-                if data is None:
-                    raise ValueError("Data is required when training with k-means.")
-                logger.info("Initializing GMM with k-means.")
-                kmeans_trainer = self.k_means_trainer or KMeansTrainer(
-                    random_state=self.random_state,
-                )
-                kmeans_machine = KMeansMachine(self.n_gaussians).fit(
-                    data, trainer=kmeans_trainer
-                )
-                (
-                    variances,
-                    weights,
-                ) = kmeans_machine.get_variances_and_weights_for_each_cluster(data)
-                # Set the GMM machine's gaussians with the results of k-means
-                self.means = np.array(copy.deepcopy(kmeans_machine.centroids_))
-                self.variances = np.array(copy.deepcopy(variances))
-                self.weights = np.array(copy.deepcopy(weights))
-            else:
-                logger.debug("Initialize GMMMachine: GMM means already set.")
+            logger.debug("GMM means was never set. Initializing with k-means.")
+            if data is None:
+                raise ValueError("Data is required when training with k-means.")
+            logger.info("Initializing GMM with k-means.")
+            kmeans_trainer = self.k_means_trainer or KMeansTrainer(
+                random_state=self.random_state,
+            )
+            kmeans_machine = KMeansMachine(self.n_gaussians).fit(
+                data, trainer=kmeans_trainer
+            )
+            (
+                variances,
+                weights,
+            ) = kmeans_machine.get_variances_and_weights_for_each_cluster(data)
+            # Set the GMM machine's gaussians with the results of k-means
+            self.means = np.array(copy.deepcopy(kmeans_machine.centroids_))
+            self.variances = np.array(copy.deepcopy(variances))
+            self.weights = np.array(copy.deepcopy(weights))

     def log_weighted_likelihood(
         self, data: "np.ndarray[('n_samples', 'n_features'), float]"
@@ -581,7 +575,10 @@ class GMMMachine(BaseEstimator):
             The weighted log likelihood of each sample of each Gaussian.
         """
         # Compute the likelihood for each data point on each Gaussian
-        z = ((data[None, ..., :] - self.means[..., None, :]) ** 2 / self.variances[..., None, :]).sum(axis=-1)
+        z = (
+            (data[None, ..., :] - self.means[..., None, :]) ** 2
+            / self.variances[..., None, :]
+        ).sum(axis=-1)
         l = -0.5 * (self.g_norms[:, None] + z)
         log_weighted_likelihood = self.log_weights[:, None] + l
         return log_weighted_likelihood
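The reformatted expression relies on numpy broadcasting to produce one variance-normalized squared distance per (gaussian, sample) pair; a shape walkthrough with toy arrays (names are illustrative):

    import numpy as np

    rng = np.random.default_rng(0)
    data = rng.random((5, 3))    # (n_samples, n_features)
    means = rng.random((2, 3))   # (n_gaussians, n_features)
    variances = np.ones((2, 3))  # (n_gaussians, n_features)

    # data[None, ..., :] is (1, 5, 3) and means[..., None, :] is (2, 1, 3),
    # so the difference broadcasts to (2, 5, 3); summing over the last axis
    # leaves shape (n_gaussians, n_samples).
    z = (
        (data[None, ..., :] - means[..., None, :]) ** 2
        / variances[..., None, :]
    ).sum(axis=-1)
    print(z.shape)  # (2, 5)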
@@ -609,6 +606,7 @@ class GMMMachine(BaseEstimator):
             return np.logaddexp.reduce(
                 array, axis=axis, keepdims=keepdims, initial=-np.inf
             )
+
         if isinstance(log_weighted_likelihood, np.ndarray):
             ll_reduced = logaddexp_reduce(log_weighted_likelihood)
         else:
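`np.logaddexp.reduce` is a numerically stable log-sum-exp; reducing over the gaussian axis marginalizes the per-component weighted likelihoods into one log likelihood per sample. For instance:

    import numpy as np

    # Two Gaussians, two samples; exp() of each column sums to the mixture likelihood.
    log_weighted_likelihood = np.log([[0.1, 0.2], [0.3, 0.4]])
    ll = np.logaddexp.reduce(log_weighted_likelihood, axis=0, initial=-np.inf)
    print(np.allclose(np.exp(ll), [0.4, 0.6]))  # True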
@@ -705,6 +703,8 @@ class GMMMachine(BaseEstimator):
         """Trains the GMM on data until convergence or maximum step is reached."""
         if self._means is None:
             self.initialize_gaussians(X)
+        else:
+            logger.debug("GMM means already set. Initialization was not run!")
         average_output = 0
         logger.info("Training GMM...")
@@ -726,10 +726,17 @@ class GMMMachine(BaseEstimator):
                 stats=stats,
             )
+
+            # if we're running in dask, persist weights, means, and variances so
+            # we don't recompute each step.
+            for attr in ["weights", "means", "variances"]:
+                arr = getattr(self, attr)
+                if isinstance(arr, da.Array):
+                    setattr(self, attr, arr.persist())
             # Note: Uses the stats from before m_step, leading to an additional m_step
             # (which is not bad because it will always converge)
-            average_output = stats.log_likelihood / stats.t
-            logger.debug(f"average output = {average_output}")
+            average_output = float(stats.log_likelihood / stats.t)
+            logger.debug(f"log likelihood = {average_output}")
             if step > 1:
                 convergence_value = abs(
@@ -768,7 +775,8 @@ class GMMMachine(BaseEstimator):
         }

     def compute(self, *args, **kwargs):
-        setattr(self, "weights", np.array(getattr(self, "weights")))
+        for name in ("weights", "means", "variances"):
+            setattr(self, name, np.asarray(getattr(self, name)))


 def ml_gmm_m_step(
...
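The new `compute` loop is the counterpart of the persist step in `fit`: `np.asarray` is a no-op on numpy inputs but materializes a dask array into concrete values, matching the `GMMStats.compute` change at the top of this diff. A small sketch, assuming dask is installed:

    import dask.array as da
    import numpy as np

    weights = da.from_array(np.full(4, 0.25), chunks=2)
    weights = np.asarray(weights)  # triggers computation, returns a numpy array
    print(type(weights))           # <class 'numpy.ndarray'>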