From 40cb9aeafc93fccdcc8c3347bc6c645ee1e9e01f Mon Sep 17 00:00:00 2001 From: Tiago Freitas Pereira <tiagofrepereira@gmail.com> Date: Fri, 1 May 2020 15:48:34 +0200 Subject: [PATCH] Dasking estimators --- bob/bio/base/wrappers.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/bob/bio/base/wrappers.py b/bob/bio/base/wrappers.py index 37499e3c..cee89200 100644 --- a/bob/bio/base/wrappers.py +++ b/bob/bio/base/wrappers.py @@ -18,6 +18,7 @@ def wrap_transform_bob( dir_name, fit_extra_arguments=(("y", "subject"),), transform_extra_arguments=None, + dask_it=False ): """ Wraps either :any:`bob.bio.base.preprocessor.Preprocessor`, :any:`bob.bio.base.extractor.Extractor` @@ -40,12 +41,14 @@ def wrap_transform_bob( transform_extra_arguments: [tuple] Same behavior as in Check :any:`bob.pipelines.wrappers.transform_extra_arguments` + dask_it: bool + If True, the transformer will be a dask graph """ if isinstance(bob_object, Preprocessor): preprocessor_transformer = PreprocessorTransformer(bob_object) - return wrap_preprocessor( + transformer = wrap_preprocessor( preprocessor_transformer, features_dir=os.path.join(dir_name, "preprocessor"), transform_extra_arguments=transform_extra_arguments, @@ -53,7 +56,7 @@ def wrap_transform_bob( elif isinstance(bob_object, Extractor): extractor_transformer = ExtractorTransformer(bob_object) path = os.path.join(dir_name, "extractor") - return wrap_extractor( + transformer = wrap_extractor( extractor_transformer, features_dir=path, model_path=os.path.join(path, "extractor.pkl"), @@ -65,7 +68,7 @@ def wrap_transform_bob( algorithm_transformer = AlgorithmTransformer( bob_object, projector_file=os.path.join(path, "Projector.hdf5") ) - return wrap_algorithm( + transformer = wrap_algorithm( algorithm_transformer, features_dir=path, model_path=os.path.join(path, "algorithm.pkl"), @@ -77,6 +80,11 @@ def wrap_transform_bob( "`bob_object` should be an instance of `Preprocessor`, `Extractor` and `Algorithm`" ) + if dask_it: + transformer = mario.wrap(["dask"], transformer) + + return transformer + def wrap_preprocessor( preprocessor_transformer, features_dir=None, transform_extra_arguments=None, -- GitLab