From 40cb9aeafc93fccdcc8c3347bc6c645ee1e9e01f Mon Sep 17 00:00:00 2001
From: Tiago Freitas Pereira <tiagofrepereira@gmail.com>
Date: Fri, 1 May 2020 15:48:34 +0200
Subject: [PATCH] Dasking estimators

---
 bob/bio/base/wrappers.py | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/bob/bio/base/wrappers.py b/bob/bio/base/wrappers.py
index 37499e3c..cee89200 100644
--- a/bob/bio/base/wrappers.py
+++ b/bob/bio/base/wrappers.py
@@ -18,6 +18,7 @@ def wrap_transform_bob(
     dir_name,
     fit_extra_arguments=(("y", "subject"),),
     transform_extra_arguments=None,
+    dask_it=False
 ):
     """
     Wraps either :any:`bob.bio.base.preprocessor.Preprocessor`, :any:`bob.bio.base.extractor.Extractor`
@@ -40,12 +41,14 @@ def wrap_transform_bob(
     transform_extra_arguments: [tuple]
         Same behavior as in Check :any:`bob.pipelines.wrappers.transform_extra_arguments`
 
+    dask_it: bool
+        If True, the transformer will be a dask graph
 
     """
 
     if isinstance(bob_object, Preprocessor):
         preprocessor_transformer = PreprocessorTransformer(bob_object)
-        return wrap_preprocessor(
+        transformer = wrap_preprocessor(
             preprocessor_transformer,
             features_dir=os.path.join(dir_name, "preprocessor"),
             transform_extra_arguments=transform_extra_arguments,
@@ -53,7 +56,7 @@ def wrap_transform_bob(
     elif isinstance(bob_object, Extractor):
         extractor_transformer = ExtractorTransformer(bob_object)
         path = os.path.join(dir_name, "extractor")
-        return wrap_extractor(
+        transformer =  wrap_extractor(
             extractor_transformer,
             features_dir=path,
             model_path=os.path.join(path, "extractor.pkl"),
@@ -65,7 +68,7 @@ def wrap_transform_bob(
         algorithm_transformer = AlgorithmTransformer(
             bob_object, projector_file=os.path.join(path, "Projector.hdf5")
         )
-        return wrap_algorithm(
+        transformer = wrap_algorithm(
             algorithm_transformer,
             features_dir=path,
             model_path=os.path.join(path, "algorithm.pkl"),
@@ -77,6 +80,11 @@ def wrap_transform_bob(
             "`bob_object` should be an instance of `Preprocessor`, `Extractor` and `Algorithm`"
         )
 
+    if dask_it:
+        transformer = mario.wrap(["dask"], transformer)
+
+    return transformer
+
 
 def wrap_preprocessor(
     preprocessor_transformer, features_dir=None, transform_extra_arguments=None,
-- 
GitLab