diff --git a/bob/bio/base/pipelines/vanilla_biometrics/wrappers.py b/bob/bio/base/pipelines/vanilla_biometrics/wrappers.py index f4c472862a2a9358098b0db18cd189bf565c9187..b577cea0b079791ad086cd01849a39a2c26233ad 100644 --- a/bob/bio/base/pipelines/vanilla_biometrics/wrappers.py +++ b/bob/bio/base/pipelines/vanilla_biometrics/wrappers.py @@ -240,7 +240,9 @@ def dask_vanilla_biometrics(pipeline, npartitions=None, partition_size=None): pipeline.vanilla_biometrics_pipeline = dask_vanilla_biometrics( pipeline.vanilla_biometrics_pipeline, npartitions ) - pipeline.biometric_algorithm = pipeline.vanilla_biometrics_pipeline.biometric_algorithm + pipeline.biometric_algorithm = ( + pipeline.vanilla_biometrics_pipeline.biometric_algorithm + ) pipeline.transformer = pipeline.vanilla_biometrics_pipeline.transformer pipeline.ztnorm_solver = ZTNormDaskWrapper(pipeline.ztnorm_solver) @@ -267,6 +269,7 @@ def dask_vanilla_biometrics(pipeline, npartitions=None, partition_size=None): return pipeline + def dask_get_partition_size(cluster, n_objects): """ Heuristics that gives you a number for dask.partition_size. 
@@ -287,10 +290,10 @@ def dask_get_partition_size(cluster, n_objects): return None max_jobs = cluster.sge_job_spec["default"]["max_jobs"] - return n_objects//max_jobs if n_objects>max_jobs else 1 + return n_objects // max_jobs if n_objects > max_jobs else 1 -def checkpoint_vanilla_biometrics(pipeline, base_dir): +def checkpoint_vanilla_biometrics(pipeline, base_dir, biometric_algorithm_dir=None): """ Given a :any:`VanillaBiometrics`, wraps :any:`VanillaBiometrics.transformer` and :any:`VanillaBiometrics.biometric_algorithm` to be checkpointed @@ -302,7 +305,14 @@ def checkpoint_vanilla_biometrics(pipeline, base_dir): Vanilla Biometrics based pipeline to be checkpointed base_dir: str - Path to store biometric references and scores + Path to store the transformed input data and, unless `biometric_algorithm_dir` is set, the biometric references and scores + + biometric_algorithm_dir: str + If set, the biometric references and scores are checkpointed to this path. + If not set (``None``), `base_dir` is used instead. + This is useful when the transformed data and the biometric references and scores + should be stored in different paths. + """ @@ -340,11 +350,15 @@ def checkpoint_vanilla_biometrics(pipeline, base_dir): sk_pipeline.steps[i] = (name, wraped_estimator) + bio_ref_scores_dir = ( + base_dir if biometric_algorithm_dir is None else biometric_algorithm_dir + ) + if isinstance(pipeline.biometric_algorithm, BioAlgorithmLegacy): - pipeline.biometric_algorithm.base_dir = base_dir + pipeline.biometric_algorithm.base_dir = bio_ref_scores_dir else: pipeline.biometric_algorithm = BioAlgorithmCheckpointWrapper( - pipeline.biometric_algorithm, base_dir=base_dir + pipeline.biometric_algorithm, base_dir=bio_ref_scores_dir ) return pipeline