Skip to content
Snippets Groups Projects
Commit 40cb9aea authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Dasking estimators

parent 735ad313
No related branches found
No related tags found
2 merge requests!185Wrappers and aggregators,!180[dask] Preparing bob.bio.base for dask pipelines
...@@ -18,6 +18,7 @@ def wrap_transform_bob( ...@@ -18,6 +18,7 @@ def wrap_transform_bob(
dir_name, dir_name,
fit_extra_arguments=(("y", "subject"),), fit_extra_arguments=(("y", "subject"),),
transform_extra_arguments=None, transform_extra_arguments=None,
dask_it=False
): ):
""" """
Wraps either :any:`bob.bio.base.preprocessor.Preprocessor`, :any:`bob.bio.base.extractor.Extractor` Wraps either :any:`bob.bio.base.preprocessor.Preprocessor`, :any:`bob.bio.base.extractor.Extractor`
...@@ -40,12 +41,14 @@ def wrap_transform_bob( ...@@ -40,12 +41,14 @@ def wrap_transform_bob(
transform_extra_arguments: [tuple] transform_extra_arguments: [tuple]
Same behavior as in Check :any:`bob.pipelines.wrappers.transform_extra_arguments` Same behavior as in Check :any:`bob.pipelines.wrappers.transform_extra_arguments`
dask_it: bool
If True, the transformer will be a dask graph
""" """
if isinstance(bob_object, Preprocessor): if isinstance(bob_object, Preprocessor):
preprocessor_transformer = PreprocessorTransformer(bob_object) preprocessor_transformer = PreprocessorTransformer(bob_object)
return wrap_preprocessor( transformer = wrap_preprocessor(
preprocessor_transformer, preprocessor_transformer,
features_dir=os.path.join(dir_name, "preprocessor"), features_dir=os.path.join(dir_name, "preprocessor"),
transform_extra_arguments=transform_extra_arguments, transform_extra_arguments=transform_extra_arguments,
...@@ -53,7 +56,7 @@ def wrap_transform_bob( ...@@ -53,7 +56,7 @@ def wrap_transform_bob(
elif isinstance(bob_object, Extractor): elif isinstance(bob_object, Extractor):
extractor_transformer = ExtractorTransformer(bob_object) extractor_transformer = ExtractorTransformer(bob_object)
path = os.path.join(dir_name, "extractor") path = os.path.join(dir_name, "extractor")
return wrap_extractor( transformer = wrap_extractor(
extractor_transformer, extractor_transformer,
features_dir=path, features_dir=path,
model_path=os.path.join(path, "extractor.pkl"), model_path=os.path.join(path, "extractor.pkl"),
...@@ -65,7 +68,7 @@ def wrap_transform_bob( ...@@ -65,7 +68,7 @@ def wrap_transform_bob(
algorithm_transformer = AlgorithmTransformer( algorithm_transformer = AlgorithmTransformer(
bob_object, projector_file=os.path.join(path, "Projector.hdf5") bob_object, projector_file=os.path.join(path, "Projector.hdf5")
) )
return wrap_algorithm( transformer = wrap_algorithm(
algorithm_transformer, algorithm_transformer,
features_dir=path, features_dir=path,
model_path=os.path.join(path, "algorithm.pkl"), model_path=os.path.join(path, "algorithm.pkl"),
...@@ -77,6 +80,11 @@ def wrap_transform_bob( ...@@ -77,6 +80,11 @@ def wrap_transform_bob(
"`bob_object` should be an instance of `Preprocessor`, `Extractor` and `Algorithm`" "`bob_object` should be an instance of `Preprocessor`, `Extractor` and `Algorithm`"
) )
if dask_it:
transformer = mario.wrap(["dask"], transformer)
return transformer
def wrap_preprocessor( def wrap_preprocessor(
preprocessor_transformer, features_dir=None, transform_extra_arguments=None, preprocessor_transformer, features_dir=None, transform_extra_arguments=None,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment