Fixed multiqueue

parent 3caab253
@@ -107,6 +107,34 @@ def get_max_jobs(queue_dict):
)
def get_resource_requirements(pipeline):
"""
Get the resource requirements to execute a graph.
This is useful when it is necessary to get the dictionary mapping the dask task keys to
specific resource restrictions.
Check https://distributed.dask.org/en/latest/resources.html#resources-with-collections for more information.
Parameters
----------
pipeline: :any:`sklearn.pipeline.Pipeline`
A :any:`sklearn.pipeline.Pipeline` wrapped with :any:`bob.pipelines.DaskWrapper`
Example
-------
>>> from bob.pipelines.distributed.sge import get_resource_requirements # doctest: +SKIP
>>> cluster = SGEMultipleQueuesCluster(sge_job_spec=Q_1DAY_GPU_SPEC) # doctest: +SKIP
>>> client = Client(cluster) # doctest: +SKIP
>>> resources = get_resource_requirements(pipeline) # doctest: +SKIP
>>> my_delayed_task.compute(scheduler=client, resources=resources) # doctest: +SKIP
"""
resources = dict()
for s in pipeline:
if hasattr(s, "resource_tags"):
resources.update(s.resource_tags)
return resources
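For context, a hedged sketch of the shape this returns and how it is consumed; the "q_gpu" tag, the task key, and the local cluster are invented stand-ins for real SGE resources:

import dask
from dask.distributed import Client, LocalCluster

# A local worker advertising one fake "q_gpu" resource unit.
cluster = LocalCluster(n_workers=1, resources={"q_gpu": 1})
client = Client(cluster)

task = dask.delayed(sum)([1, 2, 3])
# get_resource_requirements returns entries of this shape:
# tuple of dask task keys -> {resource_name: amount}
resources = {(task.key,): {"q_gpu": 1}}
print(task.compute(scheduler=client, resources=resources))  # prints 6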
class SGEMultipleQueuesCluster(JobQueueCluster):
"""Launch Dask jobs in the SGE cluster allowing the request of multiple
queues.
@@ -265,30 +293,7 @@ class SGEMultipleQueuesCluster(JobQueueCluster):
# Here the goal is to wait (wait_count * interval) seconds before scaling
# down, since it is very expensive to get jobs on the SGE grid
self.adapt(minimum=min_jobs, maximum=max_jobs, wait_count=5, interval=120)
def get_sge_resources(self):
"""
Get the available resources from `SGEMultipleQueuesCluster.sge_job_spec`.
This is useful when it is necessary to set the resources available for
the `.compute` method.
Check https://distributed.dask.org/en/latest/resources.html#resources-with-collections for more information.
Example
-------
>>> cluster = SGEMultipleQueuesCluster(sge_job_spec=Q_1DAY_GPU_SPEC) # doctest: +SKIP
>>> client = Client(cluster) # doctest: +SKIP
>>> resources = cluster.get_sge_resources() # doctest: +SKIP
>>> my_delayed_task.compute(scheduler=client, resources=resources) # doctest: +SKIP
"""
resources = [
list(self.sge_job_spec[k]["resources"].items())[0]
for k in self.sge_job_spec
if self.sge_job_spec[k]["resources"] != ""
]
return dict(resources)
self.adapt(minimum=min_jobs, maximum=max_jobs, wait_count=5, interval=10)
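As context for the adapt change: in dask's adaptive scaling, `interval` is how often the scheduler re-evaluates the cluster and `wait_count` is how many consecutive scale-down recommendations must accumulate before a worker is retired, so idle SGE jobs are now released after roughly 50 s instead of 10 minutes. A sketch with the values spelled out (assuming the `cluster`, `min_jobs`, and `max_jobs` from the surrounding code):

wait_count = 5  # consecutive "scale down" votes required
interval = 10  # seconds between adaptive checks
# An idle worker survives ~wait_count * interval = 50 s before its
# SGE job is released.
cluster.adapt(
    minimum=min_jobs, maximum=max_jobs, wait_count=wait_count, interval=interval
)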
def _get_worker_spec_options(self, job_spec):
"""Craft a dask worker_spec to be used in the qsub command."""
@@ -475,7 +480,7 @@ class SchedulerResourceRestriction(Scheduler):
allowed_failures=100
if rc.get("bob.pipelines.sge.allowed_failures") is None
else rc.get("bob.pipelines.sge.allowed_failures"),
synchronize_worker_interval="20s",
synchronize_worker_interval="10s",
*args,
**kwargs,
)
@@ -10,7 +10,7 @@ QUEUE_DEFAULT = {
"io_big": False,
"resource_spec": "",
"max_jobs": 96,
"resources": "",
"resources": {"default": 1},
},
"q_1week": {
"queue": "q_1week",
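The `resources` field switches from an empty string to dask's dict format (resource name -> units advertised per worker). A hypothetical spec in the new format; the queue names mirror the defaults above, but every value is illustrative:

ILLUSTRATIVE_QUEUE_SPEC = {
    "default": {
        "queue": "q_1day",
        "memory": "4GB",
        "io_big": False,
        "resource_spec": "",
        "max_jobs": 96,
        "resources": {"default": 1},  # each worker offers one "default" unit
    },
    "q_gpu": {
        "queue": "q_gpu",
        "memory": "12GB",
        "io_big": False,
        "resource_spec": "",
        "resources": {"q_gpu": 1},  # each GPU worker offers one "q_gpu" unit
    },
}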
@@ -304,9 +304,7 @@ def test_checkpoint_fit_transform_pipeline():
transformer = ("1", _build_transformer(d, 1))
pipeline = Pipeline([fitter, transformer])
if dask_enabled:
pipeline = mario.wrap(
["dask"], pipeline, fit_tag=[(1, "GPU")], npartitions=1
)
pipeline = mario.wrap(["dask"], pipeline, fit_tag="GPU", npartitions=1)
pipeline = pipeline.fit(samples)
tags = mario.dask_tags(pipeline)
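With the dict-based format, the tag is passed as a plain string and expanded to `{tag: 1}` by the wrapper, replacing the old list-of-tuples form `[(1, "GPU")]`. The returned tags then look roughly like this (the key hash is invented):

# tags maps tuples of dask task keys to resource restrictions, e.g.:
# {("fit-3f2a9c", "finalize-fit-3f2a9c"): {"GPU": 1}}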
@@ -138,8 +138,7 @@ class SampleWrapper(BaseWrapper, TransformerMixin):
if isinstance(samples[0], SampleSet):
return [
SampleSet(
self._samples_transform(sset.samples, method_name),
parent=sset,
self._samples_transform(sset.samples, method_name), parent=sset,
)
for sset in samples
]
@@ -416,11 +415,7 @@ class DaskWrapper(BaseWrapper, TransformerMixin):
"""
def __init__(
self,
estimator,
fit_tag=None,
transform_tag=None,
**kwargs,
self, estimator, fit_tag=None, transform_tag=None, **kwargs,
):
super().__init__(**kwargs)
self.estimator = estimator
@@ -430,7 +425,7 @@ class DaskWrapper(BaseWrapper, TransformerMixin):
self.transform_tag = transform_tag
def _make_dask_resource_tag(self, tag):
return [(1, tag)]
return {tag: 1}
def _dask_transform(self, X, method_name):
graph_name = f"{_frmt(self)}.{method_name}"
@@ -442,10 +437,10 @@ class DaskWrapper(BaseWrapper, TransformerMixin):
# rename the function so it shows up with a readable name in dask graphs
_transf.__name__ = graph_name
map_partitions = X.map_partitions(_transf, self._dask_state)
if self.transform_tag is not None:
self.resource_tags[map_partitions] = self._make_dask_resource_tag(
self.transform_tag
)
if self.transform_tag:
self.resource_tags[
tuple(map_partitions.dask.keys())
] = self._make_dask_resource_tag(self.transform_tag)
return map_partitions
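Keying `resource_tags` by the tuple of all the partition tasks' keys restricts every partition of the collection, following dask's resources-with-collections convention. A minimal local sketch of the same mechanism (the "GPU" name and the local cluster are stand-ins, not real hardware checks):

import dask.bag as db
from dask.distributed import Client, LocalCluster

cluster = LocalCluster(n_workers=1, resources={"GPU": 1})  # fake resource
client = Client(cluster)

bag = db.from_sequence(range(4), npartitions=2).map(lambda x: x * 2)
# Same shape as resource_tags[tuple(map_partitions.dask.keys())] = {"GPU": 1}
resources = {tuple(bag.__dask_keys__()): {"GPU": 1}}
print(bag.compute(scheduler=client, resources=resources))  # [0, 2, 4, 6]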
@@ -483,15 +478,16 @@ class DaskWrapper(BaseWrapper, TransformerMixin):
# rename the function so it shows up with a readable name in dask graphs
_fit.__name__ = f"{_frmt(self)}.fit"
self._dask_state = delayed(_fit)(
X,
y,
)
self._dask_state = delayed(_fit)(X, y)
if self.fit_tag is not None:
self.resource_tags[self._dask_state] = self._make_dask_resource_tag(
self.fit_tag
)
from dask import core
# If you do `delayed(_fit)(X, y)`, two tasks are generated:
# `finalize-TASK` and `TASK`. With this, we make sure
# that both are annotated
self.resource_tags[
tuple([f"{k}{str(self._dask_state.key)}" for k in ["", "finalize-"]])
] = self._make_dask_resource_tag(self.fit_tag)
return self
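Because `delayed(_fit)(X, y)` is materialized through both its own key and a `finalize-` key, annotating only the fit key could let the finalize task land on an unrestricted worker. The stored entry therefore looks roughly like this (the key hash is invented):

# resource_tags entry for a step fitted with fit_tag="q_gpu":
# {("fit-9c4e81", "finalize-fit-9c4e81"): {"q_gpu": 1}}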
@@ -5,7 +5,10 @@ from sklearn.base import BaseEstimator
from sklearn.base import TransformerMixin
from sklearn.pipeline import make_pipeline
from bob.pipelines.distributed.sge import SGEMultipleQueuesCluster
from bob.pipelines.distributed.sge import (
SGEMultipleQueuesCluster,
get_resource_requirements,
)
from bob.pipelines.sample import Sample
import bob.pipelines
@@ -60,12 +63,12 @@ pipeline = bob.pipelines.wrap(
# Creating my cluster obj.
cluster = SGEMultipleQueuesCluster()
client = Client(cluster) # Creating the scheduler
resources = get_resource_requirements(pipeline)
# Run the task graph on the SGE cluster
# NOTE THAT the resource restrictions are passed to .compute
X_transformed = pipeline.fit_transform(X_as_sample).compute(
scheduler=client, resources=cluster.get_sge_resources()
scheduler=client, resources=resources
)
import shutil