From 0f295cbb2cdd9146e9fe76b235ec7665c70d09fd Mon Sep 17 00:00:00 2001 From: Tiago Freitas Pereira <tiagofrepereira@gmail.com> Date: Fri, 1 May 2020 10:52:12 +0200 Subject: [PATCH] Fixed mario.wrap([dask]). It was not possible to set the npartitions kwarg --- bob/pipelines/tests/test_wrappers.py | 2 +- bob/pipelines/wrappers.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/bob/pipelines/tests/test_wrappers.py b/bob/pipelines/tests/test_wrappers.py index 8acda5d..0d2ffab 100644 --- a/bob/pipelines/tests/test_wrappers.py +++ b/bob/pipelines/tests/test_wrappers.py @@ -288,7 +288,7 @@ def test_checkpoint_fit_transform_pipeline(): transformer = ("1", _build_transformer(d, 1)) pipeline = Pipeline([fitter, transformer]) if dask_enabled: - pipeline = mario.wrap(["dask"], pipeline, fit_tag=[(1, "GPU")]) + pipeline = mario.wrap(["dask"], pipeline, fit_tag=[(1, "GPU")], npartitions=1) pipeline = pipeline.fit(samples) tags = mario.dask_tags(pipeline) diff --git a/bob/pipelines/wrappers.py b/bob/pipelines/wrappers.py index f551af4..24c1578 100644 --- a/bob/pipelines/wrappers.py +++ b/bob/pipelines/wrappers.py @@ -463,13 +463,15 @@ def wrap(bases, estimator=None, **kwargs): trans, leftover = _wrap(trans, **kwargs) estimator.steps[idx] = (name, trans) + # Using the leftovers as new kwargs to be consumed further + kwargs = leftover # if being wrapped with DaskWrapper, add ToDaskBag to the steps if DaskWrapper in bases: valid_params = ToDaskBag._get_param_names() params = {k: kwargs.pop(k) for k in valid_params if k in kwargs} dask_bag = ToDaskBag(**params) estimator.steps.insert(0, ("ToDaskBag", dask_bag)) - + leftover = kwargs else: estimator, leftover = _wrap(estimator, **kwargs) -- GitLab