From 7fc43e9ec716ccce1a5d1f9a1af4a6bde9b3ed4c Mon Sep 17 00:00:00 2001
From: Tiago Freitas Pereira <tiagofrepereira@gmail.com>
Date: Tue, 23 Jun 2020 14:56:47 +0200
Subject: [PATCH] Possible split between bioalgorithm_dir and transform dir

---
 .../pipelines/vanilla_biometrics/wrappers.py  | 26 ++++++++++++++-----
 1 file changed, 20 insertions(+), 6 deletions(-)

diff --git a/bob/bio/base/pipelines/vanilla_biometrics/wrappers.py b/bob/bio/base/pipelines/vanilla_biometrics/wrappers.py
index f4c47286..b577cea0 100644
--- a/bob/bio/base/pipelines/vanilla_biometrics/wrappers.py
+++ b/bob/bio/base/pipelines/vanilla_biometrics/wrappers.py
@@ -240,7 +240,9 @@ def dask_vanilla_biometrics(pipeline, npartitions=None, partition_size=None):
         pipeline.vanilla_biometrics_pipeline = dask_vanilla_biometrics(
             pipeline.vanilla_biometrics_pipeline, npartitions
         )
-        pipeline.biometric_algorithm = pipeline.vanilla_biometrics_pipeline.biometric_algorithm
+        pipeline.biometric_algorithm = (
+            pipeline.vanilla_biometrics_pipeline.biometric_algorithm
+        )
         pipeline.transformer = pipeline.vanilla_biometrics_pipeline.transformer
 
         pipeline.ztnorm_solver = ZTNormDaskWrapper(pipeline.ztnorm_solver)
@@ -267,6 +269,7 @@ def dask_vanilla_biometrics(pipeline, npartitions=None, partition_size=None):
 
     return pipeline
 
+
 def dask_get_partition_size(cluster, n_objects):
     """
     Heuristics that gives you a number for dask.partition_size.
@@ -287,10 +290,10 @@ def dask_get_partition_size(cluster, n_objects):
         return None
 
     max_jobs = cluster.sge_job_spec["default"]["max_jobs"]
-    return n_objects//max_jobs if n_objects>max_jobs else 1
+    return n_objects // max_jobs if n_objects > max_jobs else 1
 
 
-def checkpoint_vanilla_biometrics(pipeline, base_dir):
+def checkpoint_vanilla_biometrics(pipeline, base_dir, biometric_algorithm_dir=None):
     """
     Given a :any:`VanillaBiometrics`, wraps :any:`VanillaBiometrics.transformer` and
     :any:`VanillaBiometrics.biometric_algorithm` to be checkpointed
@@ -302,7 +305,14 @@ def checkpoint_vanilla_biometrics(pipeline, base_dir):
        Vanilla Biometrics based pipeline to be checkpointed
 
     base_dir: str
-       Path to store biometric references and scores
+       Path to store transformed input data and possibly biometric references and scores
+
+    biometric_algorithm_dir: str
+       If set, it will checkpoint the biometric references and scores to this path.
+       If not, `base_dir` will be used.
+       This is useful when it's suitable to have the transformed data path, and biometric references and scores
+       in different paths.
+
 
     """
 
@@ -340,11 +350,15 @@ def checkpoint_vanilla_biometrics(pipeline, base_dir):
 
         sk_pipeline.steps[i] = (name, wraped_estimator)
 
+    bio_ref_scores_dir = (
+        base_dir if biometric_algorithm_dir is None else biometric_algorithm_dir
+    )
+
     if isinstance(pipeline.biometric_algorithm, BioAlgorithmLegacy):
-        pipeline.biometric_algorithm.base_dir = base_dir
+        pipeline.biometric_algorithm.base_dir = bio_ref_scores_dir
     else:
         pipeline.biometric_algorithm = BioAlgorithmCheckpointWrapper(
-            pipeline.biometric_algorithm, base_dir=base_dir
+            pipeline.biometric_algorithm, base_dir=bio_ref_scores_dir
         )
 
     return pipeline
-- 
GitLab