Fixed mario.wrap([dask]). It was not possible to set the npartitions kwarg

parent 1a164338
Pipeline #39597 passed with stage
in 8 minutes and 49 seconds
......@@ -288,7 +288,7 @@ def test_checkpoint_fit_transform_pipeline():
transformer = ("1", _build_transformer(d, 1))
pipeline = Pipeline([fitter, transformer])
if dask_enabled:
pipeline = mario.wrap(["dask"], pipeline, fit_tag=[(1, "GPU")])
pipeline = mario.wrap(["dask"], pipeline, fit_tag=[(1, "GPU")], npartitions=1)
pipeline =
tags = mario.dask_tags(pipeline)
......@@ -463,13 +463,15 @@ def wrap(bases, estimator=None, **kwargs):
trans, leftover = _wrap(trans, **kwargs)
estimator.steps[idx] = (name, trans)
# Using the leftovers as new kwargs to be consumed further
kwargs = leftover
# if being wrapped with DaskWrapper, add ToDaskBag to the steps
if DaskWrapper in bases:
valid_params = ToDaskBag._get_param_names()
params = {k: kwargs.pop(k) for k in valid_params if k in kwargs}
dask_bag = ToDaskBag(**params)
estimator.steps.insert(0, ("ToDaskBag", dask_bag))
leftover = kwargs
estimator, leftover = _wrap(estimator, **kwargs)
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment