Skip to content
Snippets Groups Projects

Add support for fitting on dask bags in ISV/JFA

Merged Amir MOHAMMADI requested to merge dask-bag into master
1 file
+ 37
0
Compare changes
  • Side-by-side
  • Inline
@@ -1320,6 +1320,38 @@ class FactorAnalysisBase(BaseEstimator):
"""
return self.enroll([self.ubm.acc_stats(X)])
def _prepare_dask_input(self, X, y):
"""Perpare the input for the fit method"""
logger.info(
"Rechunking bag of stats to delayed list of stats per class. If your worker runs "
"out of memory in this training step, you have to use workers with more memory."
)
# optimize the graph of X and y and also persist X before computing y
X, y = dask.optimize(X, y)
X = X.persist()
y = np.array(dask.compute(y)[0])
n_classes = len(set(y))
# X is a list Stats in a dask bag chunked randomly. We want X to be a
# list of dask delayed objects where each delayed object is a list of
# stats per class
def _len(stats):
return len(stats)
lengths = X.map_partitions(_len).compute()
delayeds = X.to_delayed()
X, i = [[] for _ in range(n_classes)], 0
for length_, delayed_stats_list in zip(lengths, delayeds):
delayed_stats_list._length = length_
for delayed_stat in delayed_stats_list:
class_id = y[i]
X[class_id].append(delayed_stat)
i += 1
X = [dask.delayed(list)(stats).persist() for stats in X]
y = [y[y == class_id] for class_id in range(n_classes)]
return X, y
class ISVMachine(FactorAnalysisBase):
"""
@@ -1443,6 +1475,9 @@ class ISVMachine(FactorAnalysisBase):
Returns self.
"""
if isinstance(X, dask.bag.Bag):
X, y = self._prepare_dask_input(X, y)
(
input_is_dask,
n_classes,
@@ -2082,6 +2117,8 @@ class JFAMachine(FactorAnalysisBase):
Returns self.
"""
if isinstance(X, dask.bag.Bag):
X, y = self._prepare_dask_input(X, y)
(
input_is_dask,
Loading