diff --git a/bob/learn/em/factor_analysis.py b/bob/learn/em/factor_analysis.py index b65da8263eb6190b0a55e038ac7837cfba347ebd..ea66294e90fb9a93f35c86a75a35643eba2385f8 100644 --- a/bob/learn/em/factor_analysis.py +++ b/bob/learn/em/factor_analysis.py @@ -12,6 +12,7 @@ import numpy as np from dask.array.core import Array from dask.delayed import Delayed from sklearn.base import BaseEstimator +from sklearn.utils import check_consistent_length from sklearn.utils.multiclass import unique_labels from .gmm import GMMMachine @@ -1260,8 +1261,10 @@ class FactorAnalysisBase(BaseEstimator): return self.score_using_stats(model, self.ubm.transform(data)) def fit(self, X, y): + input_is_dask, X = check_and_persist_dask_input(X, persist=False) y = np.squeeze(np.asarray(y)) + check_consistent_length(X, y) self.initialize(X)