Fixed mario.wrap([dask]). It was not possible to set the npartitions kwarg

parent 1a164338
Pipeline #39597 passed with stage
in 8 minutes and 49 seconds
...@@ -288,7 +288,7 @@ def test_checkpoint_fit_transform_pipeline(): ...@@ -288,7 +288,7 @@ def test_checkpoint_fit_transform_pipeline():
transformer = ("1", _build_transformer(d, 1)) transformer = ("1", _build_transformer(d, 1))
pipeline = Pipeline([fitter, transformer]) pipeline = Pipeline([fitter, transformer])
if dask_enabled: if dask_enabled:
pipeline = mario.wrap(["dask"], pipeline, fit_tag=[(1, "GPU")]) pipeline = mario.wrap(["dask"], pipeline, fit_tag=[(1, "GPU")], npartitions=1)
pipeline = pipeline.fit(samples) pipeline = pipeline.fit(samples)
tags = mario.dask_tags(pipeline) tags = mario.dask_tags(pipeline)
......
...@@ -463,13 +463,15 @@ def wrap(bases, estimator=None, **kwargs): ...@@ -463,13 +463,15 @@ def wrap(bases, estimator=None, **kwargs):
trans, leftover = _wrap(trans, **kwargs) trans, leftover = _wrap(trans, **kwargs)
estimator.steps[idx] = (name, trans) estimator.steps[idx] = (name, trans)
# Using the leftovers as new kwargs to be consumed further
kwargs = leftover
# if being wrapped with DaskWrapper, add ToDaskBag to the steps # if being wrapped with DaskWrapper, add ToDaskBag to the steps
if DaskWrapper in bases: if DaskWrapper in bases:
valid_params = ToDaskBag._get_param_names() valid_params = ToDaskBag._get_param_names()
params = {k: kwargs.pop(k) for k in valid_params if k in kwargs} params = {k: kwargs.pop(k) for k in valid_params if k in kwargs}
dask_bag = ToDaskBag(**params) dask_bag = ToDaskBag(**params)
estimator.steps.insert(0, ("ToDaskBag", dask_bag)) estimator.steps.insert(0, ("ToDaskBag", dask_bag))
leftover = kwargs
else: else:
estimator, leftover = _wrap(estimator, **kwargs) estimator, leftover = _wrap(estimator, **kwargs)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment