diff --git a/bob/learn/em/factor_analysis.py b/bob/learn/em/factor_analysis.py
index ea66294e90fb9a93f35c86a75a35643eba2385f8..779068ba07708df1afdcf51e6c867d3b181fecbf 100644
--- a/bob/learn/em/factor_analysis.py
+++ b/bob/learn/em/factor_analysis.py
@@ -112,8 +112,8 @@ class FactorAnalysisBase(BaseEstimator):
     relevance_factor: float
         Factor analysis relevance factor
 
-    seed: int
-        Seed for the random number generator
+    random_state: int
+        Random state for the random number generator
 
     ubm: :py:class:`bob.learn.em.GMMMachine`
         A trained UBM (Universal Background Model) or a parametrized
@@ -128,7 +128,7 @@
         r_V=None,
         relevance_factor=4.0,
         em_iterations=10,
-        seed=0,
+        random_state=0,
         ubm=None,
         ubm_kwargs=None,
         **kwargs,
@@ -137,7 +137,7 @@
         self.ubm = ubm
         self.ubm_kwargs = ubm_kwargs
         self.em_iterations = em_iterations
-        self.seed = seed
+        self.random_state = random_state
 
         # axis 1 dimensions of U and V
         self.r_U = r_U
@@ -231,7 +231,8 @@
 
         if self.ubm is None:
             logger.info("FA: Creating a new GMMMachine and training it.")
-            self.ubm = GMMMachine(**self.ubm_kwargs)
+            gmm_class = self.ubm_kwargs.pop("gmm_class", GMMMachine)
+            self.ubm: GMMMachine = gmm_class(**self.ubm_kwargs)
             self.ubm.fit(X)
 
         if self.ubm._means is None:
@@ -280,8 +281,8 @@
             D: (n_gaussians*feature_dimension) represents the client offset vector
 
         """
-        if self.seed is not None:
-            np.random.seed(self.seed)
+        if self.random_state is not None:
+            np.random.seed(self.random_state)
 
         U_shape = (self.supervector_dimension, self.r_U)
 
@@ -1258,7 +1259,7 @@
 
         """
 
-        return self.score_using_stats(model, self.ubm.transform(data))
+        return self.score_using_stats(model, self.ubm.acc_stats(data))
 
     def fit(self, X, y):
 
@@ -1275,12 +1276,17 @@
                 class_indices = y == class_id
                 X_new.append(X[class_indices])
                 y_new.append(y[class_indices])
             X, y = X_new, y_new
 
             stats = [
                 dask.delayed(self.ubm.stats_per_sample)(xx).persist()
                 for xx in X
             ]
+            # try:
+            #     client = dask.distributed.Client.current()
+            #     stats = client.scatter(stats)
+            # except ValueError:
+            #     pass
         else:
             stats = self.ubm.stats_per_sample(X)
 
@@ -1309,8 +1315,8 @@ class ISVMachine(FactorAnalysisBase):
     relevance_factor: float
         Factor analysis relevance factor
 
-    seed: int
-        Seed for the random number generator
+    random_state: int
+        Random state for the random number generator
 
     ubm: :py:class:`bob.learn.em.GMMMachine` or None
         A trained UBM (Universal Background Model). If None, the UBM is trained with
@@ -1324,7 +1330,7 @@
         r_U,
         em_iterations=10,
         relevance_factor=4.0,
-        seed=0,
+        random_state=0,
         ubm=None,
         ubm_kwargs=None,
         **kwargs,
@@ -1333,7 +1339,7 @@
             r_U=r_U,
             relevance_factor=relevance_factor,
             em_iterations=em_iterations,
-            seed=seed,
+            random_state=random_state,
             ubm=ubm,
             ubm_kwargs=ubm_kwargs,
             **kwargs,
@@ -1447,7 +1453,7 @@
         return self
 
     def transform(self, X):
-        ubm_projected_X = self.ubm.transform(X)
+        ubm_projected_X = self.ubm.acc_stats(X)
         return self.estimate_ux(ubm_projected_X)
 
     def enroll_using_stats(self, X, iterations=1):
@@ -1519,7 +1525,7 @@
             z
 
         """
-        return self.enroll_using_stats([self.ubm.transform(X)], iterations)
+        return self.enroll_using_stats([self.ubm.acc_stats(X)], iterations)
 
     def score_using_stats(self, latent_z, data):
         """
@@ -1584,8 +1590,8 @@ class JFAMachine(FactorAnalysisBase):
     relevance_factor: float
         Factor analysis relevance factor
 
-    seed: int
-        Seed for the random number generator
+    random_state: int
+        Random state for the random number generator
 
     """
 
@@ -1595,7 +1601,7 @@
         r_V,
         em_iterations=10,
         relevance_factor=4.0,
-        seed=0,
+        random_state=0,
         ubm=None,
         ubm_kwargs=None,
         **kwargs,
@@ -1606,7 +1612,7 @@
             r_V=r_V,
             relevance_factor=relevance_factor,
             em_iterations=em_iterations,
-            seed=seed,
+            random_state=random_state,
             ubm_kwargs=ubm_kwargs,
             **kwargs,
         )
@@ -2059,7 +2065,7 @@
             z
 
         """
-        return self.enroll_using_stats([self.ubm.transform(X)], iterations)
+        return self.enroll_using_stats([self.ubm.acc_stats(X)], iterations)
 
     def fit_using_stats(self, X, y):
         """
diff --git a/bob/learn/em/gmm.py b/bob/learn/em/gmm.py
index 25159f17aaa95f68f1795c6306f9ed57a9181b71..988ef3a5b512bb6b4075e6886816f96f22b03edb 100644
--- a/bob/learn/em/gmm.py
+++ b/bob/learn/em/gmm.py
@@ -850,10 +850,15 @@ class GMMMachine(BaseEstimator):
         )
         return self
 
-    def transform(self, X):
+    def acc_stats(self, X):
         """Returns the statistics for `X`."""
+        # kept separate from `transform`, which subclasses may override
         return e_step(data=X, machine=self)
 
+    def transform(self, X):
+        """Returns the statistics for `X`."""
+        return self.acc_stats(X)
+
     def stats_per_sample(self, X):
         return [e_step(data=xx, machine=self) for xx in X]
 
diff --git a/bob/learn/em/test/test_factor_analysis.py b/bob/learn/em/test/test_factor_analysis.py
index d2399b26bfaa4f0711c67bea7c34340a8921b5a7..596a6a9fe2b75fa52b062e6286d4589e5ccff36d 100644
--- a/bob/learn/em/test/test_factor_analysis.py
+++ b/bob/learn/em/test/test_factor_analysis.py
@@ -557,7 +557,7 @@ def test_ISV_JFA_fit():
         relevance_factor=4,
         em_iterations=50,
         ubm_kwargs=ubm_kwargs,
-        seed=10,
+        random_state=10,
     )
 
     if machine_type == "isv":
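Usage note (not part of the patch): a minimal sketch of the three API changes above, namely the `seed` -> `random_state` rename, the `gmm_class` hook popped from `ubm_kwargs`, and the new `acc_stats` entry point that `GMMMachine.transform` now delegates to. The data, shapes, and hyperparameters are illustrative only; the snippet assumes `GMMMachine` and `ISVMachine` are importable from `bob.learn.em`, as in the test file.

import numpy as np

from bob.learn.em import GMMMachine, ISVMachine

# Toy data: 20 two-dimensional samples from two classes (illustrative only).
rng = np.random.RandomState(0)
X = rng.normal(size=(20, 2))
y = np.repeat([0, 1], 10)

# `random_state` replaces the old `seed` argument.  `gmm_class` is popped
# from `ubm_kwargs` by FactorAnalysisBase and lets callers substitute a
# GMMMachine subclass when the UBM is trained internally.
isv = ISVMachine(
    r_U=2,
    em_iterations=2,
    random_state=10,
    ubm_kwargs=dict(n_gaussians=2, gmm_class=GMMMachine),
)
isv.fit(X, y)

# `acc_stats` is now the explicit way to accumulate GMM sufficient
# statistics; `GMMMachine.transform` simply delegates to it.
stats = isv.ubm.acc_stats(X)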