diff --git a/bob/bio/base/pipelines/vanilla_biometrics/__init__.py b/bob/bio/base/pipelines/vanilla_biometrics/__init__.py index 58a34e1639781ff5f65506ec9fcf26ff45643a78..6388efbd69915aa5b48a1a4d6d2f778403c766e3 100644 --- a/bob/bio/base/pipelines/vanilla_biometrics/__init__.py +++ b/bob/bio/base/pipelines/vanilla_biometrics/__init__.py @@ -8,6 +8,7 @@ from .wrappers import ( dask_vanilla_biometrics, checkpoint_vanilla_biometrics, dask_get_partition_size, + is_checkpointed ) from .zt_norm import ZTNormPipeline, ZTNormDaskWrapper, ZTNormCheckpointWrapper diff --git a/bob/bio/base/script/vanilla_biometrics.py b/bob/bio/base/script/vanilla_biometrics.py index e2d5341e7e599691f53a772bea12b65bdff6066e..b71009ccc0cd5b9c724f89d5aee165b7eee73d9c 100644 --- a/bob/bio/base/script/vanilla_biometrics.py +++ b/bob/bio/base/script/vanilla_biometrics.py @@ -26,7 +26,8 @@ from bob.bio.base.pipelines.vanilla_biometrics import ( dask_get_partition_size, FourColumnsScoreWriter, CSVScoreWriter, - BioAlgorithmLegacy + BioAlgorithmLegacy, + is_checkpointed ) from dask.delayed import Delayed import pkg_resources @@ -214,19 +215,11 @@ def vanilla_biometrics( pipeline.score_writer = FourColumnsScoreWriter(os.path.join(output, "./tmp")) # Check if it's already checkpointed - if checkpoint and ( - not isinstance_nested( - pipeline, - "biometric_algorithm", - BioAlgorithmCheckpointWrapper, - ) - and not isinstance_nested( - pipeline, "biometric_algorithm", BioAlgorithmLegacy - ) - ): + if checkpoint and not is_checkpointed(pipeline): pipeline = checkpoint_vanilla_biometrics(pipeline, output) background_model_samples = database.background_model_samples() + for group in groups: score_file_name = os.path.join(output, f"scores-{group}") @@ -236,6 +229,7 @@ def vanilla_biometrics( if dask_client is not None and not isinstance_nested( pipeline.biometric_algorithm, "biometric_algorithm", BioAlgorithmDaskWrapper ): + n_objects = max( len(background_model_samples), len(biometric_references), len(probes) ) @@ -258,6 +252,7 @@ def vanilla_biometrics( allow_scoring_with_all_biometric_references=allow_scoring_with_all_biometric_references, ) + post_processed_scores = post_process_scores(pipeline, result, score_file_name) _ = compute_scores(post_processed_scores, dask_client) diff --git a/bob/bio/base/script/vanilla_biometrics_ztnorm.py b/bob/bio/base/script/vanilla_biometrics_ztnorm.py index 4150e2323ea936942886968c0c6840a2e5c34f33..0cc11b397aff47792847dbc6eabc7bf256ed14ee 100644 --- a/bob/bio/base/script/vanilla_biometrics_ztnorm.py +++ b/bob/bio/base/script/vanilla_biometrics_ztnorm.py @@ -30,6 +30,7 @@ from bob.bio.base.pipelines.vanilla_biometrics import ( FourColumnsScoreWriter, CSVScoreWriter, BioAlgorithmLegacy, + is_checkpointed ) from dask.delayed import Delayed from bob.bio.base.utils import get_resource_filename @@ -219,12 +220,7 @@ def vanilla_biometrics_ztnorm( pipeline.score_writer = FourColumnsScoreWriter(os.path.join(output, "./tmp")) # Check if it's already checkpointed - if checkpoint and ( - not isinstance_nested( - pipeline, "biometric_algorithm", BioAlgorithmCheckpointWrapper, - ) - and not isinstance_nested(pipeline, "biometric_algorithm", BioAlgorithmLegacy) - ): + if checkpoint and not is_checkpointed(pipeline): pipeline = checkpoint_vanilla_biometrics(pipeline, output) # Patching the pipeline in case of ZNorm and checkpointing it