diff --git a/bob/learn/em/factor_analysis.py b/bob/learn/em/factor_analysis.py index f9b39cef2a344e552a190dff26929706a5c04b5e..4d69d7cb4dba85ff67469967a3265dfc8981599c 100644 --- a/bob/learn/em/factor_analysis.py +++ b/bob/learn/em/factor_analysis.py @@ -1320,6 +1320,38 @@ class FactorAnalysisBase(BaseEstimator): """ return self.enroll([self.ubm.acc_stats(X)]) + def _prepare_dask_input(self, X, y): + """Perpare the input for the fit method""" + logger.info( + "Rechunking bag of stats to delayed list of stats per class. If your worker runs " + "out of memory in this training step, you have to use workers with more memory." + ) + # optimize the graph of X and y and also persist X before computing y + X, y = dask.optimize(X, y) + X = X.persist() + y = np.array(dask.compute(y)[0]) + n_classes = len(set(y)) + + # X is a list Stats in a dask bag chunked randomly. We want X to be a + # list of dask delayed objects where each delayed object is a list of + # stats per class + def _len(stats): + return len(stats) + + lengths = X.map_partitions(_len).compute() + delayeds = X.to_delayed() + X, i = [[] for _ in range(n_classes)], 0 + for length_, delayed_stats_list in zip(lengths, delayeds): + delayed_stats_list._length = length_ + for delayed_stat in delayed_stats_list: + class_id = y[i] + X[class_id].append(delayed_stat) + i += 1 + X = [dask.delayed(list)(stats).persist() for stats in X] + y = [y[y == class_id] for class_id in range(n_classes)] + + return X, y + class ISVMachine(FactorAnalysisBase): """ @@ -1443,6 +1475,9 @@ class ISVMachine(FactorAnalysisBase): Returns self. """ + if isinstance(X, dask.bag.Bag): + X, y = self._prepare_dask_input(X, y) + ( input_is_dask, n_classes, @@ -2082,6 +2117,8 @@ class JFAMachine(FactorAnalysisBase): Returns self. """ + if isinstance(X, dask.bag.Bag): + X, y = self._prepare_dask_input(X, y) ( input_is_dask,