From ef1bcd4529ff37e0073faa63dc8961f6756cc573 Mon Sep 17 00:00:00 2001
From: Amir MOHAMMADI <amir.mohammadi@idiap.ch>
Date: Tue, 26 Apr 2022 19:57:48 +0200
Subject: [PATCH] rename seed to random_state

---
 bob/learn/em/factor_analysis.py           | 39 ++++++++++----------
 bob/learn/em/gmm.py                       |  7 +++-
 bob/learn/em/test/test_factor_analysis.py |  2 +-
 3 files changed, 27 insertions(+), 21 deletions(-)

diff --git a/bob/learn/em/factor_analysis.py b/bob/learn/em/factor_analysis.py
index ea66294..779068b 100644
--- a/bob/learn/em/factor_analysis.py
+++ b/bob/learn/em/factor_analysis.py
@@ -112,8 +112,8 @@ class FactorAnalysisBase(BaseEstimator):
     relevance_factor: float
         Factor analysis relevance factor
 
-    seed: int
-        Seed for the random number generator
+    random_state: int
+        Random state (seed) for the random number generator
 
     ubm: :py:class:`bob.learn.em.GMMMachine`
         A trained UBM (Universal Background Model) or a parametrized
@@ -128,7 +128,7 @@
         r_V=None,
         relevance_factor=4.0,
         em_iterations=10,
-        seed=0,
+        random_state=0,
         ubm=None,
         ubm_kwargs=None,
         **kwargs,
@@ -137,7 +137,7 @@
         self.ubm = ubm
         self.ubm_kwargs = ubm_kwargs
         self.em_iterations = em_iterations
-        self.seed = seed
+        self.random_state = random_state
 
         # axis 1 dimensions of U and V
         self.r_U = r_U
@@ -231,7 +231,8 @@
 
         if self.ubm is None:
             logger.info("FA: Creating a new GMMMachine and training it.")
-            self.ubm = GMMMachine(**self.ubm_kwargs)
+            gmm_class = self.ubm_kwargs.pop("gmm_class", GMMMachine)
+            self.ubm: GMMMachine = gmm_class(**self.ubm_kwargs)
             self.ubm.fit(X)
 
         if self.ubm._means is None:
@@ -280,8 +281,8 @@
             D: (n_gaussians*feature_dimension) represents the client offset vector
 
         """
-        if self.seed is not None:
-            np.random.seed(self.seed)
+        if self.random_state is not None:
+            np.random.seed(self.random_state)
 
         U_shape = (self.supervector_dimension, self.r_U)
 
@@ -1258,7 +1259,7 @@
 
         """
 
-        return self.score_using_stats(model, self.ubm.transform(data))
+        return self.score_using_stats(model, self.ubm.acc_stats(data))
 
     def fit(self, X, y):
 
@@ -1309,8 +1310,8 @@ class ISVMachine(FactorAnalysisBase):
     relevance_factor: float
         Factor analysis relevance factor
 
-    seed: int
-        Seed for the random number generator
+    random_state: int
+        Random state (seed) for the random number generator
 
     ubm: :py:class:`bob.learn.em.GMMMachine` or None
         A trained UBM (Universal Background Model). If None, the UBM is trained with
@@ -1324,7 +1325,7 @@ class ISVMachine(FactorAnalysisBase):
         r_U,
         em_iterations=10,
         relevance_factor=4.0,
-        seed=0,
+        random_state=0,
         ubm=None,
         ubm_kwargs=None,
         **kwargs,
@@ -1333,8 +1334,8 @@
             r_U=r_U,
             relevance_factor=relevance_factor,
             em_iterations=em_iterations,
-            seed=seed,
+            random_state=random_state,
             ubm=ubm,
             ubm_kwargs=ubm_kwargs,
             **kwargs,
         )
@@ -1447,7 +1448,7 @@
         return self
 
     def transform(self, X):
-        ubm_projected_X = self.ubm.transform(X)
+        ubm_projected_X = self.ubm.acc_stats(X)
         return self.estimate_ux(ubm_projected_X)
 
     def enroll_using_stats(self, X, iterations=1):
@@ -1519,7 +1520,7 @@
             z
 
         """
-        return self.enroll_using_stats([self.ubm.transform(X)], iterations)
+        return self.enroll_using_stats([self.ubm.acc_stats(X)], iterations)
 
     def score_using_stats(self, latent_z, data):
         """
@@ -1584,8 +1585,8 @@ class JFAMachine(FactorAnalysisBase):
     relevance_factor: float
         Factor analysis relevance factor
 
-    seed: int
-        Seed for the random number generator
+    random_state: int
+        Random state (seed) for the random number generator
 
     """
 
@@ -1595,7 +1596,7 @@
         r_V,
         em_iterations=10,
         relevance_factor=4.0,
-        seed=0,
+        random_state=0,
         ubm=None,
         ubm_kwargs=None,
         **kwargs,
@@ -1606,7 +1607,7 @@
             r_V=r_V,
             relevance_factor=relevance_factor,
             em_iterations=em_iterations,
-            seed=seed,
+            random_state=random_state,
             ubm_kwargs=ubm_kwargs,
             **kwargs,
         )
@@ -2059,7 +2060,7 @@
             z
 
         """
-        return self.enroll_using_stats([self.ubm.transform(X)], iterations)
+        return self.enroll_using_stats([self.ubm.acc_stats(X)], iterations)
 
     def fit_using_stats(self, X, y):
         """
diff --git a/bob/learn/em/gmm.py b/bob/learn/em/gmm.py
index 25159f1..988ef3a 100644
--- a/bob/learn/em/gmm.py
+++ b/bob/learn/em/gmm.py
@@ -850,10 +850,15 @@ class GMMMachine(BaseEstimator):
         )
         return self
 
-    def transform(self, X):
+    def acc_stats(self, X):
         """Returns the statistics for `X`."""
+        # kept separate from transform() because subclasses sometimes override transform()
         return e_step(data=X, machine=self)
 
+    def transform(self, X):
+        """Returns the statistics for `X`. Alias of ``acc_stats``."""
+        return self.acc_stats(X)
+
     def stats_per_sample(self, X):
         return [e_step(data=xx, machine=self) for xx in X]
 
diff --git a/bob/learn/em/test/test_factor_analysis.py b/bob/learn/em/test/test_factor_analysis.py
index d2399b2..596a6a9 100644
--- a/bob/learn/em/test/test_factor_analysis.py
+++ b/bob/learn/em/test/test_factor_analysis.py
@@ -557,7 +557,7 @@ def test_ISV_JFA_fit():
         relevance_factor=4,
         em_iterations=50,
         ubm_kwargs=ubm_kwargs,
-        seed=10,
+        random_state=10,
     )
 
     if machine_type == "isv":
--
GitLab
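
Note for reviewers: this patch renames the hyper-parameter to match scikit-learn's ``random_state`` convention and keeps ``GMMMachine.transform`` as a thin alias of the new ``acc_stats``. A minimal usage sketch follows; the toy data, the top-level imports, and the ``n_gaussians`` value are illustrative assumptions, while ``ISVMachine``, ``random_state``, ``ubm_kwargs`` (including the new ``gmm_class`` key), ``fit``, and ``transform`` are the API touched by the patch:

    import numpy as np
    from bob.learn.em import GMMMachine, ISVMachine  # assumed top-level exports

    # Toy training data: 6 samples, 3 features, 2 speakers (illustrative only).
    rng = np.random.default_rng(0)
    X = rng.normal(size=(6, 3))
    y = np.array([0, 0, 0, 1, 1, 1])

    isv = ISVMachine(
        r_U=2,
        em_iterations=10,
        relevance_factor=4.0,
        random_state=10,  # renamed from ``seed`` by this patch
        # With ubm=None, fit() trains a UBM from ubm_kwargs; the new
        # "gmm_class" key lets callers swap in a GMMMachine subclass.
        ubm_kwargs=dict(n_gaussians=2, gmm_class=GMMMachine),
    )
    isv.fit(X, y)

    # transform() accumulates UBM statistics for X (ubm.acc_stats after this
    # patch) and returns the estimated Ux latent projection.
    ux = isv.transform(X)

Keeping ``transform`` as an alias of ``acc_stats`` preserves the scikit-learn transformer contract for pipelines, while subclasses that override ``transform`` can still reach the raw statistics through ``acc_stats``.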