Skip to content
Snippets Groups Projects
Commit 0f295cbb authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Fixed mario.wrap([dask]). It was not possible to set the npartitions kwarg

parent 1a164338
Branches
Tags
1 merge request!29Fixed mario.wrap([dask]). It was not possible to set the npartitions kwarg
Pipeline #39597 passed
......@@ -288,7 +288,7 @@ def test_checkpoint_fit_transform_pipeline():
transformer = ("1", _build_transformer(d, 1))
pipeline = Pipeline([fitter, transformer])
if dask_enabled:
pipeline = mario.wrap(["dask"], pipeline, fit_tag=[(1, "GPU")])
pipeline = mario.wrap(["dask"], pipeline, fit_tag=[(1, "GPU")], npartitions=1)
pipeline = pipeline.fit(samples)
tags = mario.dask_tags(pipeline)
......
......@@ -463,13 +463,15 @@ def wrap(bases, estimator=None, **kwargs):
trans, leftover = _wrap(trans, **kwargs)
estimator.steps[idx] = (name, trans)
# Using the leftovers as new kwargs to be consumed further
kwargs = leftover
# if being wrapped with DaskWrapper, add ToDaskBag to the steps
if DaskWrapper in bases:
valid_params = ToDaskBag._get_param_names()
params = {k: kwargs.pop(k) for k in valid_params if k in kwargs}
dask_bag = ToDaskBag(**params)
estimator.steps.insert(0, ("ToDaskBag", dask_bag))
leftover = kwargs
else:
estimator, leftover = _wrap(estimator, **kwargs)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment