From b1cd7bc6efa0c26b961379254e475aa008d5e024 Mon Sep 17 00:00:00 2001 From: Amir MOHAMMADI <amir.mohammadi@idiap.ch> Date: Tue, 26 Apr 2022 20:38:50 +0200 Subject: [PATCH] small fixes --- bob/learn/em/factor_analysis.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/bob/learn/em/factor_analysis.py b/bob/learn/em/factor_analysis.py index 779068b..cd9a769 100644 --- a/bob/learn/em/factor_analysis.py +++ b/bob/learn/em/factor_analysis.py @@ -230,12 +230,13 @@ class FactorAnalysisBase(BaseEstimator): """ if self.ubm is None: - logger.info("FA: Creating a new GMMMachine and training it.") + logger.info("Creating a new GMMMachine and training it.") gmm_class = self.ubm_kwargs.pop("gmm_class", GMMMachine) self.ubm: GMMMachine = gmm_class(**self.ubm_kwargs) self.ubm.fit(X) if self.ubm._means is None: + logger.info("UBM means are None, training the UBM.") self.ubm.fit(X) # Initializing the state matrix @@ -1276,21 +1277,17 @@ class FactorAnalysisBase(BaseEstimator): class_indices = y == class_id X_new.append(X[class_indices]) y_new.append(y[class_indices]) - __import__("ipdb").set_trace() X, y = X_new, y_new + del X_new, y_new stats = [ dask.delayed(self.ubm.stats_per_sample)(xx).persist() for xx in X ] - # try: - # client = dask.distributed.Client.current() - # stats = client.scatter(stats) - # except ValueError: - # pass else: stats = self.ubm.stats_per_sample(X) + del X self.fit_using_stats(stats, y) return self -- GitLab