diff --git a/bob/learn/em/gmm.py b/bob/learn/em/gmm.py index 2d5b0b9340e69dda2f9f585ebc9c16c83d006d64..ae7bc395a65c04d174b4d826522a8e46f9e11433 100644 --- a/bob/learn/em/gmm.py +++ b/bob/learn/em/gmm.py @@ -857,10 +857,24 @@ class GMMMachine(BaseEstimator): def transform(self, X, **kwargs): """Returns the statistics for `X`.""" - return e_step( - data=X, - machine=self, - ) + input_is_dask, X = check_and_persist_dask_input(X) + + if input_is_dask: + stats = [ + dask.delayed(e_step)( + data=xx, + machine=self, + ) + for xx in X + ] + stats = functools.reduce(operator.iadd, stats) + stats = stats.compute() + else: + stats = e_step( + data=X, + machine=self, + ) + return stats def _more_tags(self): return {