Skip to content
Snippets Groups Projects
Commit 3d5a0df6 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

persist the state of GMM between each training iteration

report the progress of training using computed values
replace variance_threshold with a number instead of an array
although arrays are still supported.
do not initialize means, variances, and weights when another
attr is being set.
parent b4921b3d
No related branches found
No related tags found
2 merge requests!42GMM implementation in Python,!40Transition to a pure python implementation
Pipeline #56623 failed
......@@ -170,9 +170,9 @@ class KMeansMachine(BaseEstimator):
trainer.e_step(machine=self, data=X)
trainer.m_step(machine=self, data=X)
distance = trainer.compute_likelihood(self)
distance = float(trainer.compute_likelihood(self))
# logger.info(f"Average squared Euclidean distance = {distance.compute()}")
logger.info(f"Average squared Euclidean distance = {distance}")
if step > 0:
convergence_value = abs(
......
......@@ -192,7 +192,7 @@ class GMMStats:
def compute(self):
for name in ("log_likelihood", "n", "sum_px", "sum_pxx"):
setattr(self, name, np.array(getattr(self, name)))
setattr(self, name, np.asarray(getattr(self, name)))
class GMMMachine(BaseEstimator):
......@@ -281,6 +281,7 @@ class GMMMachine(BaseEstimator):
mean_var_update_threshold: float = EPSILON,
alpha: float = 0.5,
relevance_factor: Union[None, float] = 4,
variance_thresholds: float = EPSILON,
):
"""
Parameters
......@@ -318,6 +319,9 @@ class GMMMachine(BaseEstimator):
relevance_factor:
Factor for the computation of alpha with Reyolds adaptation. (Used when
`trainer == "map"`)
variance_thresholds:
The variance flooring thresholds, i.e. the minimum allowed value of variance in each dimension.
The variance will be set to this value if an attempt is made to set it to a smaller value.
"""
self.n_gaussians = n_gaussians
......@@ -342,7 +346,7 @@ class GMMMachine(BaseEstimator):
self.mean_var_update_threshold = mean_var_update_threshold
self._means = None
self._variances = None
self._variance_thresholds = None
self._variance_thresholds = mean_var_update_threshold
self._g_norms = None
if self.ubm is not None:
......@@ -366,7 +370,7 @@ class GMMMachine(BaseEstimator):
@weights.setter
def weights(self, weights: "np.ndarray[('n_gaussians',), float]"):
self._weights = np.array(weights)
self._weights = weights
self._log_weights = np.log(self._weights)
@property
......@@ -378,11 +382,6 @@ class GMMMachine(BaseEstimator):
@means.setter
def means(self, means: "np.ndarray[('n_gaussians', 'n_features'), float]"):
if self._means is None:
if self._variances is None:
self._variances = np.ones_like(means, dtype=float)
if self._variance_thresholds is None:
self._variance_thresholds = np.full_like(means, fill_value=EPSILON, dtype=float)
self._means = means
@property
......@@ -394,11 +393,6 @@ class GMMMachine(BaseEstimator):
@variances.setter
def variances(self, variances: "np.ndarray[('n_gaussians', 'n_features'), float]"):
if self._variances is None:
if self._means is None:
self._means = np.zeros_like(variances, dtype=float)
if self._variance_thresholds is None:
self._variance_thresholds = np.full_like(variances, fill_value=EPSILON, dtype=float)
self._variances = np.maximum(self.variance_thresholds, variances)
# Recompute g_norm for each gaussian [array of shape (n_gaussians,)]
n_log_2pi = self.variances.shape[-1] * np.log(2 * np.pi)
......@@ -416,10 +410,6 @@ class GMMMachine(BaseEstimator):
self,
threshold: "Union[float, np.ndarray[('n_gaussians', 'n_features'), float]]",
):
if not hasattr(threshold, "ndim"):
threshold = np.full_like(self.means, fill_value=threshold, dtype=float)
elif threshold.ndim == 1:
threshold = threshold[None,:].repeat(self.n_gaussians, axis=0)
self._variance_thresholds = threshold
self.variances = np.maximum(threshold, self.variances)
......@@ -432,7 +422,6 @@ class GMMMachine(BaseEstimator):
self._g_norms = n_log_2pi + np.log(self._variances).sum(axis=-1)
return self._g_norms
@property
def log_weights(self):
"""Retrieve the logarithm of the weights."""
......@@ -486,12 +475,16 @@ class GMMMachine(BaseEstimator):
gaussian_group = hdf5[f"m_gaussians{i}"]
g_means.append(gaussian_group["m_mean"][()])
g_variances.append(gaussian_group["m_variance"][()])
g_variance_thresholds.append(gaussian_group["m_variance_thresholds"][()])
g_variance_thresholds.append(
gaussian_group["m_variance_thresholds"][()]
)
weights = hdf5["m_weights"][()].reshape(n_gaussians)
self = cls(n_gaussians=n_gaussians, ubm=ubm, weights=weights)
self.means = np.array(g_means).reshape(n_gaussians, -1)
self.variances = np.array(g_variances).reshape(n_gaussians, -1)
self.variance_thresholds = np.array(g_variance_thresholds).reshape(n_gaussians,-1)
self.variance_thresholds = np.array(g_variance_thresholds).reshape(
n_gaussians, -1
)
return self
def save(self, hdf5):
......@@ -526,7 +519,12 @@ class GMMMachine(BaseEstimator):
return (
np.allclose(self.means, other.means, rtol=rtol, atol=atol)
and np.allclose(self.variances, other.variances, rtol=rtol, atol=atol)
and np.allclose(self.variance_thresholds, other.variance_thresholds, rtol=rtol, atol=atol)
and np.allclose(
self.variance_thresholds,
other.variance_thresholds,
rtol=rtol,
atol=atol,
)
and np.allclose(self.weights, other.weights, rtol=rtol, atol=atol)
)
......@@ -536,12 +534,10 @@ class GMMMachine(BaseEstimator):
"""Populates gaussians parameters with either k-means or the UBM values."""
if self.trainer == "map":
self.means = copy.deepcopy(self.ubm.means)
self.variances_ = copy.deepcopy(self.ubm.variances)
self.variances = copy.deepcopy(self.ubm.variances)
self.variance_thresholds = copy.deepcopy(self.ubm.variance_thresholds)
self.weights = copy.deepcopy(self.ubm.weights)
self._g_norms = copy.deepcopy(self.ubm.g_norms)
else:
if self._means is None:
logger.debug("GMM means was never set. Initializing with k-means.")
if data is None:
raise ValueError("Data is required when training with k-means.")
......@@ -562,8 +558,6 @@ class GMMMachine(BaseEstimator):
self.means = np.array(copy.deepcopy(kmeans_machine.centroids_))
self.variances = np.array(copy.deepcopy(variances))
self.weights = np.array(copy.deepcopy(weights))
else:
logger.debug("Initialize GMMMachine: GMM means already set.")
def log_weighted_likelihood(
self, data: "np.ndarray[('n_samples', 'n_features'), float]"
......@@ -581,7 +575,10 @@ class GMMMachine(BaseEstimator):
The weighted log likelihood of each sample of each Gaussian.
"""
# Compute the likelihood for each data point on each Gaussian
z = ((data[None, ..., :] - self.means[..., None, :]) ** 2 / self.variances[..., None, :]).sum(axis=-1)
z = (
(data[None, ..., :] - self.means[..., None, :]) ** 2
/ self.variances[..., None, :]
).sum(axis=-1)
l = -0.5 * (self.g_norms[:, None] + z)
log_weighted_likelihood = self.log_weights[:, None] + l
return log_weighted_likelihood
......@@ -609,6 +606,7 @@ class GMMMachine(BaseEstimator):
return np.logaddexp.reduce(
array, axis=axis, keepdims=keepdims, initial=-np.inf
)
if isinstance(log_weighted_likelihood, np.ndarray):
ll_reduced = logaddexp_reduce(log_weighted_likelihood)
else:
......@@ -705,6 +703,8 @@ class GMMMachine(BaseEstimator):
"""Trains the GMM on data until convergence or maximum step is reached."""
if self._means is None:
self.initialize_gaussians(X)
else:
logger.debug("GMM means already set. Initialization was not run!")
average_output = 0
logger.info("Training GMM...")
......@@ -726,10 +726,17 @@ class GMMMachine(BaseEstimator):
stats=stats,
)
# if we're running in dask, persist weights, means, and variances so
# we don't recompute each step.
for attr in ["weights", "means", "variances"]:
arr = getattr(self, attr)
if isinstance(arr, da.Array):
setattr(self, attr, arr.persist())
# Note: Uses the stats from before m_step, leading to an additional m_step
# (which is not bad because it will always converge)
average_output = stats.log_likelihood / stats.t
logger.debug(f"average output = {average_output}")
average_output = float(stats.log_likelihood / stats.t)
logger.debug(f"log likelihood = {average_output}")
if step > 1:
convergence_value = abs(
......@@ -768,7 +775,8 @@ class GMMMachine(BaseEstimator):
}
def compute(self, *args, **kwargs):
setattr(self, "weights", np.array(getattr(self, "weights")))
for name in ("weights", "means", "variances"):
setattr(self, name, np.asarray(getattr(self, name)))
def ml_gmm_m_step(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment