Skip to content
Snippets Groups Projects
Commit aa7c4c78 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Merge branch 'dask-bag' into 'master'

Add support for fitting on dask bags in ISV/JFA

See merge request !58
parents 81e7e8a9 06b170c2
No related branches found
No related tags found
1 merge request: !58 Add support for fitting on dask bags in ISV/JFA
Pipeline #61034 passed
...@@ -1320,6 +1320,38 @@ class FactorAnalysisBase(BaseEstimator): ...@@ -1320,6 +1320,38 @@ class FactorAnalysisBase(BaseEstimator):
""" """
return self.enroll([self.ubm.acc_stats(X)]) return self.enroll([self.ubm.acc_stats(X)])
def _prepare_dask_input(self, X, y):
    """Prepare dask-bag input for the fit method.

    Regroups a dask bag of stats into one delayed list of stats per class.

    Parameters
    ----------
    X
        Dask bag of stats; its partitioning is unrelated to the classes.
    y
        Class label of each element of ``X``, in bag order. It is passed
        through ``dask.compute``. Labels are used directly as list indices,
        so they are expected to be the integers ``0 .. n_classes - 1``.

    Returns
    -------
    X : list
        One persisted dask delayed object per class, each wrapping the list
        of stats that belong to that class.
    y : list
        One numpy array per class holding the labels equal to that class.
    """
    logger.info(
        "Rechunking bag of stats to delayed list of stats per class. "
        "If your worker runs out of memory in this training step, "
        "you have to use workers with more memory."
    )

    # Optimize both graphs together, and persist the bag *before* the labels
    # are computed.
    bag, lazy_labels = dask.optimize(X, y)
    bag = bag.persist()
    (labels,) = dask.compute(lazy_labels)
    labels = np.array(labels)
    n_classes = len(set(labels))

    # The bag holds the stats in arbitrarily sized partitions. Fetch the size
    # of every partition, then walk the partitions as delayed objects.
    partition_sizes = bag.map_partitions(lambda stats: len(stats)).compute()
    partitions = bag.to_delayed()

    def _delayed_stats():
        """Yield every stat of every partition, in bag order, as a delayed."""
        for size, partition in zip(partition_sizes, partitions):
            # Tell the delayed partition how many items it holds so that it
            # can be iterated item by item (dask refuses to iterate a
            # Delayed object of unknown length).
            partition._length = size
            yield from partition

    # Distribute the delayed stats over one bucket per class; the label at a
    # given position belongs to the stat at the same position in the bag.
    stats_per_class = [[] for _ in range(n_classes)]
    for position, delayed_stat in enumerate(_delayed_stats()):
        stats_per_class[labels[position]].append(delayed_stat)

    grouped_stats = []
    grouped_labels = []
    for class_id, class_stats in enumerate(stats_per_class):
        # Each class becomes a single persisted delayed list of its stats.
        grouped_stats.append(dask.delayed(list)(class_stats).persist())
        grouped_labels.append(labels[labels == class_id])
    return grouped_stats, grouped_labels
class ISVMachine(FactorAnalysisBase): class ISVMachine(FactorAnalysisBase):
""" """
...@@ -1443,6 +1475,9 @@ class ISVMachine(FactorAnalysisBase): ...@@ -1443,6 +1475,9 @@ class ISVMachine(FactorAnalysisBase):
Returns self. Returns self.
""" """
if isinstance(X, dask.bag.Bag):
X, y = self._prepare_dask_input(X, y)
( (
input_is_dask, input_is_dask,
n_classes, n_classes,
...@@ -2082,6 +2117,8 @@ class JFAMachine(FactorAnalysisBase): ...@@ -2082,6 +2117,8 @@ class JFAMachine(FactorAnalysisBase):
Returns self. Returns self.
""" """
if isinstance(X, dask.bag.Bag):
X, y = self._prepare_dask_input(X, y)
( (
input_is_dask, input_is_dask,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment