diff --git a/bob/bio/base/script/vanilla_biometrics.py b/bob/bio/base/script/vanilla_biometrics.py index 4e7e5b01a651429650341def09c08d0edbe7e4d3..e2d5341e7e599691f53a772bea12b65bdff6066e 100644 --- a/bob/bio/base/script/vanilla_biometrics.py +++ b/bob/bio/base/script/vanilla_biometrics.py @@ -220,7 +220,7 @@ def vanilla_biometrics( "biometric_algorithm", BioAlgorithmCheckpointWrapper, ) - or not isinstance_nested( + and not isinstance_nested( pipeline, "biometric_algorithm", BioAlgorithmLegacy ) ): @@ -236,8 +236,8 @@ def vanilla_biometrics( if dask_client is not None and not isinstance_nested( pipeline.biometric_algorithm, "biometric_algorithm", BioAlgorithmDaskWrapper ): - n_objects = ( - len(background_model_samples) + len(biometric_references) + len(probes) + n_objects = max( + len(background_model_samples), len(biometric_references), len(probes) ) pipeline = dask_vanilla_biometrics( pipeline, diff --git a/bob/bio/base/script/vanilla_biometrics_ztnorm.py b/bob/bio/base/script/vanilla_biometrics_ztnorm.py index bd849a512993179d56234e90030a292abc82258b..4150e2323ea936942886968c0c6840a2e5c34f33 100644 --- a/bob/bio/base/script/vanilla_biometrics_ztnorm.py +++ b/bob/bio/base/script/vanilla_biometrics_ztnorm.py @@ -229,9 +229,10 @@ def vanilla_biometrics_ztnorm( # Patching the pipeline in case of ZNorm and checkpointing it pipeline = ZTNormPipeline(pipeline) - pipeline.ztnorm_solver = ZTNormCheckpointWrapper( - pipeline.ztnorm_solver, os.path.join(output, "normed-scores") - ) + if checkpoint: + pipeline.ztnorm_solver = ZTNormCheckpointWrapper( + pipeline.ztnorm_solver, os.path.join(output, "normed-scores") + ) background_model_samples = database.background_model_samples() zprobes = database.zprobes(proportion=ztnorm_cohort_proportion) @@ -246,8 +247,8 @@ def vanilla_biometrics_ztnorm( if dask_client is not None and not isinstance_nested( pipeline.biometric_algorithm, "biometric_algorithm", BioAlgorithmDaskWrapper ): - n_objects = ( - len(background_model_samples) + len(biometric_references) + len(probes) + n_objects = max( + len(background_model_samples), len(biometric_references), len(probes) ) pipeline = dask_vanilla_biometrics( pipeline,