diff --git a/bob/pipelines/tests/test_wrappers.py b/bob/pipelines/tests/test_wrappers.py index 8acda5dd821c7133ebb83369d2cfd8e18a7c2f07..0d2ffab4b314f0a291745ee06e387e7fca947f74 100644 --- a/bob/pipelines/tests/test_wrappers.py +++ b/bob/pipelines/tests/test_wrappers.py @@ -288,7 +288,7 @@ def test_checkpoint_fit_transform_pipeline(): transformer = ("1", _build_transformer(d, 1)) pipeline = Pipeline([fitter, transformer]) if dask_enabled: - pipeline = mario.wrap(["dask"], pipeline, fit_tag=[(1, "GPU")]) + pipeline = mario.wrap(["dask"], pipeline, fit_tag=[(1, "GPU")], npartitions=1) pipeline = pipeline.fit(samples) tags = mario.dask_tags(pipeline) diff --git a/bob/pipelines/wrappers.py b/bob/pipelines/wrappers.py index f551af4c3e1816c06ea1efd380d6451308d62054..24c1578f75f8b6b76f9d77544f1d2182f6fe8ed6 100644 --- a/bob/pipelines/wrappers.py +++ b/bob/pipelines/wrappers.py @@ -463,13 +463,15 @@ def wrap(bases, estimator=None, **kwargs): trans, leftover = _wrap(trans, **kwargs) estimator.steps[idx] = (name, trans) + # Using the leftovers as new kwargs to be consumed further + kwargs = leftover # if being wrapped with DaskWrapper, add ToDaskBag to the steps if DaskWrapper in bases: valid_params = ToDaskBag._get_param_names() params = {k: kwargs.pop(k) for k in valid_params if k in kwargs} dask_bag = ToDaskBag(**params) estimator.steps.insert(0, ("ToDaskBag", dask_bag)) - + leftover = kwargs else: estimator, leftover = _wrap(estimator, **kwargs)