diff --git a/bob/bio/base/wrappers.py b/bob/bio/base/wrappers.py index 37499e3c320edcedf7bc71786fbe6bb91c06508f..cee892007d424f08f46519bd41b5a36b929f46cf 100644 --- a/bob/bio/base/wrappers.py +++ b/bob/bio/base/wrappers.py @@ -18,6 +18,7 @@ def wrap_transform_bob( dir_name, fit_extra_arguments=(("y", "subject"),), transform_extra_arguments=None, + dask_it=False ): """ Wraps either :any:`bob.bio.base.preprocessor.Preprocessor`, :any:`bob.bio.base.extractor.Extractor` @@ -40,12 +41,14 @@ def wrap_transform_bob( transform_extra_arguments: [tuple] Same behavior as in Check :any:`bob.pipelines.wrappers.transform_extra_arguments` + dask_it: bool + If True, the transformer will be a dask graph """ if isinstance(bob_object, Preprocessor): preprocessor_transformer = PreprocessorTransformer(bob_object) - return wrap_preprocessor( + transformer = wrap_preprocessor( preprocessor_transformer, features_dir=os.path.join(dir_name, "preprocessor"), transform_extra_arguments=transform_extra_arguments, @@ -53,7 +56,7 @@ def wrap_transform_bob( elif isinstance(bob_object, Extractor): extractor_transformer = ExtractorTransformer(bob_object) path = os.path.join(dir_name, "extractor") - return wrap_extractor( + transformer = wrap_extractor( extractor_transformer, features_dir=path, model_path=os.path.join(path, "extractor.pkl"), @@ -65,7 +68,7 @@ def wrap_transform_bob( algorithm_transformer = AlgorithmTransformer( bob_object, projector_file=os.path.join(path, "Projector.hdf5") ) - return wrap_algorithm( + transformer = wrap_algorithm( algorithm_transformer, features_dir=path, model_path=os.path.join(path, "algorithm.pkl"), @@ -77,6 +80,11 @@ def wrap_transform_bob( "`bob_object` should be an instance of `Preprocessor`, `Extractor` and `Algorithm`" ) + if dask_it: + transformer = mario.wrap(["dask"], transformer) + + return transformer + def wrap_preprocessor( preprocessor_transformer, features_dir=None, transform_extra_arguments=None,