Finished resource tags

parent 0ce7bce1
Pipeline #38092 passed with stage
in 8 minutes and 42 seconds
......@@ -12,29 +12,93 @@ from sklearn.pipeline import Pipeline
from dask import delayed
import dask.bag
def dask_it(o):
def dask_it(o, fit_tag=None, transform_tag=None):
"""
Mix up any :py:class:`sklearn.pipeline.Pipeline` or :py:class:`sklearn.estimator.Base with
Mix up any :py:class:`sklearn.pipeline.Pipeline` or :py:class:`sklearn.estimator.Base` with
:py:class`DaskEstimatorMixin`
Parameters
----------
o: :py:class:`sklearn.pipeline.Pipeline` or :py:class:`sklearn.estimator.Base`
Any :py:class:`sklearn.pipeline.Pipeline` or :py:class:`sklearn.estimator.Base` to be dask mixed
fit_tag: list(tuple()) or "str"
Tag the `fit` method. This is useful to tag dask tasks to run in specific workers https://distributed.dask.org/en/latest/resources.html
If `o` is :py:class:`sklearn.pipeline.Pipeline`, this parameter should contain a list of tuples
containing the pipeline.step index and the `str` tag for `fit`.
If `o` is :py:class:`sklearn.estimator.Base`, this parameter should contain just the tag for `fit`
transform_tag: list(tuple()) or "str"
Tag the `fit` method. This is useful to tag dask tasks to run in specific workers https://distributed.dask.org/en/latest/resources.html
If `o` is :py:class:`sklearn.pipeline.Pipeline`, this parameter should contain a list of tuples
containing the pipeline.step index and the `str` tag for `transform`.
If `o` is :py:class:`sklearn.estimator.Base`, this parameter should contain just the tag for `transform`
Examples
--------
Vanilla example
>>> pipeline = dask_it(pipeline) # Take some pipeline and make the methods `fit`and `transform` run over dask
>>> pipeline.fit(samples).compute()
In this example we will "mark" the fit method with a particular tag
Hence, we can set the `dask.delayed.compute` method to place some
delayeds to be executed in particular resources
>>> pipeline = dask_it(pipeline, fit_tag=[(1, "GPU")]) # Take some pipeline and make the methods `fit`and `transform` run over dask
>>> fit = pipeline.fit(samples)
>>> fit.compute(resources=pipeline.dask_tags())
Taging estimator
>>> estimator = dask_it(estimator)
>>> transf = estimator.transform(samples)
>>> transf.compute(resources=estimator.dask_tags())
"""
def _fetch_resource_tape(o):
def _fetch_resource_tape(self):
"""
Get all the resources take
"""
resource_tape = dict()
if isinstance(o, Pipeline):
for o in range(1,len(o.steps)):
resource_tape += o.resource_tape
resource_tags = dict()
if isinstance(self, Pipeline):
for i in range(1,len(self.steps)):
resource_tags.update(o[i].resource_tags)
else:
resource_tags.update(self.resource_tags)
return resource_tape
return resource_tags
if isinstance(o, Pipeline):
#Adding a daskbag in the tail of the pipeline
o.steps.insert(0, ('0', DaskBagMixin()))
# Patching dask_resources
dasked = mix_me_up(DaskEstimatorMixin, o)
#dasked.dask_resources = _fetch_resource_tape(o)
# Tagging each element in a pipeline
if isinstance(o, Pipeline):
# Tagging each element for fitting and transforming
if fit_tag is not None:
for t in fit_tag:
o.steps[t[0]][1].fit_tag = t[1]
if transform_tag is not None:
for t in transform_tag:
o.steps[t[0]][1].transform_tag = t[1]
else:
dasked.fit_tag = fit_tag
dasked.transform_tag = transform_tag
# Bounding the method
dasked.dask_tags = types.MethodType( _fetch_resource_tape, dasked )
return dasked
......@@ -287,17 +351,17 @@ class DaskEstimatorMixin:
"""
def __init__(self, fit_resource=None, transform_resource=None, **kwargs):
def __init__(self, fit_tag=None, transform_tag=None, **kwargs):
super().__init__(**kwargs)
self._dask_state = self
self.resource_tape = dict()
self.fit_resource = fit_resource
self.transform_resource = transform_resource
self.resource_tags = dict()
self.fit_tag = fit_tag
self.transform_tag = transform_tag
def fit(self, X, y=None, **fit_params):
self._dask_state = delayed(super().fit)(X, y, **fit_params)
if self.fit_resource is not None:
self.resource_tape[self._dask_state] = self.fit_resource
if self.fit_tag is not None:
self.resource_tags[self._dask_state] = self.fit_tag
return self
......@@ -306,8 +370,8 @@ class DaskEstimatorMixin:
return super(DaskEstimatorMixin, dask_state).transform(X_line)
map_partitions = X.map_partitions(_transf, self._dask_state)
if self.transform_resource is not None:
self.resource_tape[map_partitions] = self.transform_resource
if self.transform_tag is not None:
self.resource_tags[map_partitions] = self.transform_tag
return map_partitions
......
......@@ -280,9 +280,12 @@ def test_checkpoint_fit_transform_pipeline():
fitter = ("0", _build_estimator(d, 0))
transformer = ("1", _build_transformer(d, 1))
pipeline = Pipeline([fitter, transformer])
if dask_enabled:
pipeline = dask_it(pipeline)
if dask_enabled:
pipeline = dask_it(pipeline, fit_tag=[(1, "GPU")])
pipeline = pipeline.fit(samples)
tags = pipeline.dask_tags()
assert len(tags) == 1
transformed_samples = pipeline.transform(samples_transform)
transformed_samples = transformed_samples.compute(
......@@ -348,6 +351,7 @@ def test_dask_checkpoint_transform_pipeline():
samples_transform = [Sample(data, key=str(i)) for i, data in enumerate(X)]
with tempfile.TemporaryDirectory() as d:
bag_transformer = DaskBagMixin()
estimator = dask_it(_build_transformer(d, 0))
estimator = dask_it(_build_transformer(d, 0), transform_tag="CPU")
X_tr = estimator.transform(bag_transformer.transform(samples_transform))
assert len(estimator.dask_tags()) == 1
assert len(X_tr.compute(scheduler="single-threaded")) == 10
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment