Skip to content
Snippets Groups Projects

[dask][sge] Multiqueue updates

Merged Tiago de Freitas Pereira requested to merge multiqueue into master
3 files
+ 12
8
Compare changes
  • Side-by-side
  • Inline
Files
3
@@ -25,7 +25,7 @@ class MyFitTranformer(TransformerMixin, BaseEstimator):
def __init__(self):
self._fit_model = None
def transform(self, X):
def transform(self, X, metadata=None):
# Transform `X`
return [x @ self._fit_model for x in X]
@@ -40,7 +40,7 @@ X = numpy.zeros((2, 2))
X_as_sample = [Sample(X, key=str(i), metadata=1) for i in range(10)]
# Building an arbitrary pipeline
model_path = "~/dask_tmp"
model_path = "./dask_tmp"
os.makedirs(model_path, exist_ok=True)
pipeline = make_pipeline(MyTransformer(), MyFitTranformer())
@@ -48,13 +48,15 @@ pipeline = make_pipeline(MyTransformer(), MyFitTranformer())
pipeline = bob.pipelines.wrap(
["sample", "checkpoint", "dask"],
pipeline,
model_path=model_path,
model_path=os.path.join(model_path, "model.pickle"),
features_dir=model_path,
transform_extra_arguments=(("metadata", "metadata"),),
)
# Create a dask graph from a pipeline
# Run the task graph in the local computer in a single tread
X_transformed = pipeline.fit_transform(X_as_sample).compute(scheduler="single-threaded")
import shutil
shutil.rmtree(model_path)
Loading