Skip to content
Snippets Groups Projects
Commit 3797c15f authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Implemented a mechanism to get the partition_size

parent de050b46
No related branches found
No related tags found
1 merge request!180[dask] Preparing bob.bio.base for dask pipelines
...@@ -220,7 +220,7 @@ def vanilla_biometrics( ...@@ -220,7 +220,7 @@ def vanilla_biometrics(
"biometric_algorithm", "biometric_algorithm",
BioAlgorithmCheckpointWrapper, BioAlgorithmCheckpointWrapper,
) )
or not isinstance_nested( and not isinstance_nested(
pipeline, "biometric_algorithm", BioAlgorithmLegacy pipeline, "biometric_algorithm", BioAlgorithmLegacy
) )
): ):
...@@ -236,8 +236,8 @@ def vanilla_biometrics( ...@@ -236,8 +236,8 @@ def vanilla_biometrics(
if dask_client is not None and not isinstance_nested( if dask_client is not None and not isinstance_nested(
pipeline.biometric_algorithm, "biometric_algorithm", BioAlgorithmDaskWrapper pipeline.biometric_algorithm, "biometric_algorithm", BioAlgorithmDaskWrapper
): ):
n_objects = ( n_objects = max(
len(background_model_samples) + len(biometric_references) + len(probes) len(background_model_samples), len(biometric_references), len(probes)
) )
pipeline = dask_vanilla_biometrics( pipeline = dask_vanilla_biometrics(
pipeline, pipeline,
......
...@@ -229,9 +229,10 @@ def vanilla_biometrics_ztnorm( ...@@ -229,9 +229,10 @@ def vanilla_biometrics_ztnorm(
# Patching the pipeline in case of ZNorm and checkpointing it # Patching the pipeline in case of ZNorm and checkpointing it
pipeline = ZTNormPipeline(pipeline) pipeline = ZTNormPipeline(pipeline)
pipeline.ztnorm_solver = ZTNormCheckpointWrapper( if checkpoint:
pipeline.ztnorm_solver, os.path.join(output, "normed-scores") pipeline.ztnorm_solver = ZTNormCheckpointWrapper(
) pipeline.ztnorm_solver, os.path.join(output, "normed-scores")
)
background_model_samples = database.background_model_samples() background_model_samples = database.background_model_samples()
zprobes = database.zprobes(proportion=ztnorm_cohort_proportion) zprobes = database.zprobes(proportion=ztnorm_cohort_proportion)
...@@ -246,8 +247,8 @@ def vanilla_biometrics_ztnorm( ...@@ -246,8 +247,8 @@ def vanilla_biometrics_ztnorm(
if dask_client is not None and not isinstance_nested( if dask_client is not None and not isinstance_nested(
pipeline.biometric_algorithm, "biometric_algorithm", BioAlgorithmDaskWrapper pipeline.biometric_algorithm, "biometric_algorithm", BioAlgorithmDaskWrapper
): ):
n_objects = ( n_objects = max(
len(background_model_samples) + len(biometric_references) + len(probes) len(background_model_samples), len(biometric_references), len(probes)
) )
pipeline = dask_vanilla_biometrics( pipeline = dask_vanilla_biometrics(
pipeline, pipeline,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment