diff --git a/bob/learn/em/factor_analysis.py b/bob/learn/em/factor_analysis.py index 779068ba07708df1afdcf51e6c867d3b181fecbf..cd9a769553b5a41368aa11ece53c35045a31e1a4 100644 --- a/bob/learn/em/factor_analysis.py +++ b/bob/learn/em/factor_analysis.py @@ -230,12 +230,13 @@ class FactorAnalysisBase(BaseEstimator): """ if self.ubm is None: - logger.info("FA: Creating a new GMMMachine and training it.") + logger.info("Creating a new GMMMachine and training it.") gmm_class = self.ubm_kwargs.pop("gmm_class", GMMMachine) self.ubm: GMMMachine = gmm_class(**self.ubm_kwargs) self.ubm.fit(X) if self.ubm._means is None: + logger.info("UBM means are None, training the UBM.") self.ubm.fit(X) # Initializing the state matrix @@ -1276,21 +1277,17 @@ class FactorAnalysisBase(BaseEstimator): class_indices = y == class_id X_new.append(X[class_indices]) y_new.append(y[class_indices]) - __import__("ipdb").set_trace() X, y = X_new, y_new + del X_new, y_new stats = [ dask.delayed(self.ubm.stats_per_sample)(xx).persist() for xx in X ] - # try: - # client = dask.distributed.Client.current() - # stats = client.scatter(stats) - # except ValueError: - # pass else: stats = self.ubm.stats_per_sample(X) + del X self.fit_using_stats(stats, y) return self