From 45b7f66c73fbc3db6089d87a6a243c18479e58b4 Mon Sep 17 00:00:00 2001 From: Tiago Freitas Pereira <tiagofrepereira@gmail.com> Date: Mon, 27 Jul 2020 13:29:05 +0200 Subject: [PATCH] Created the function is_checkpointed --- .../pipelines/vanilla_biometrics/__init__.py | 1 + bob/bio/base/script/vanilla_biometrics.py | 17 ++++++----------- .../base/script/vanilla_biometrics_ztnorm.py | 8 ++------ 3 files changed, 9 insertions(+), 17 deletions(-) diff --git a/bob/bio/base/pipelines/vanilla_biometrics/__init__.py b/bob/bio/base/pipelines/vanilla_biometrics/__init__.py index 58a34e16..6388efbd 100644 --- a/bob/bio/base/pipelines/vanilla_biometrics/__init__.py +++ b/bob/bio/base/pipelines/vanilla_biometrics/__init__.py @@ -8,6 +8,7 @@ from .wrappers import ( dask_vanilla_biometrics, checkpoint_vanilla_biometrics, dask_get_partition_size, + is_checkpointed ) from .zt_norm import ZTNormPipeline, ZTNormDaskWrapper, ZTNormCheckpointWrapper diff --git a/bob/bio/base/script/vanilla_biometrics.py b/bob/bio/base/script/vanilla_biometrics.py index e2d5341e..b71009cc 100644 --- a/bob/bio/base/script/vanilla_biometrics.py +++ b/bob/bio/base/script/vanilla_biometrics.py @@ -26,7 +26,8 @@ from bob.bio.base.pipelines.vanilla_biometrics import ( dask_get_partition_size, FourColumnsScoreWriter, CSVScoreWriter, - BioAlgorithmLegacy + BioAlgorithmLegacy, + is_checkpointed ) from dask.delayed import Delayed import pkg_resources @@ -214,19 +215,11 @@ def vanilla_biometrics( pipeline.score_writer = FourColumnsScoreWriter(os.path.join(output, "./tmp")) # Check if it's already checkpointed - if checkpoint and ( - not isinstance_nested( - pipeline, - "biometric_algorithm", - BioAlgorithmCheckpointWrapper, - ) - and not isinstance_nested( - pipeline, "biometric_algorithm", BioAlgorithmLegacy - ) - ): + if checkpoint and not is_checkpointed(pipeline): pipeline = checkpoint_vanilla_biometrics(pipeline, output) background_model_samples = database.background_model_samples() + for group in groups: score_file_name = os.path.join(output, f"scores-{group}") @@ -236,6 +229,7 @@ def vanilla_biometrics( if dask_client is not None and not isinstance_nested( pipeline.biometric_algorithm, "biometric_algorithm", BioAlgorithmDaskWrapper ): + n_objects = max( len(background_model_samples), len(biometric_references), len(probes) ) @@ -258,6 +252,7 @@ def vanilla_biometrics( allow_scoring_with_all_biometric_references=allow_scoring_with_all_biometric_references, ) + post_processed_scores = post_process_scores(pipeline, result, score_file_name) _ = compute_scores(post_processed_scores, dask_client) diff --git a/bob/bio/base/script/vanilla_biometrics_ztnorm.py b/bob/bio/base/script/vanilla_biometrics_ztnorm.py index 4150e232..0cc11b39 100644 --- a/bob/bio/base/script/vanilla_biometrics_ztnorm.py +++ b/bob/bio/base/script/vanilla_biometrics_ztnorm.py @@ -30,6 +30,7 @@ from bob.bio.base.pipelines.vanilla_biometrics import ( FourColumnsScoreWriter, CSVScoreWriter, BioAlgorithmLegacy, + is_checkpointed ) from dask.delayed import Delayed from bob.bio.base.utils import get_resource_filename @@ -219,12 +220,7 @@ def vanilla_biometrics_ztnorm( pipeline.score_writer = FourColumnsScoreWriter(os.path.join(output, "./tmp")) # Check if it's already checkpointed - if checkpoint and ( - not isinstance_nested( - pipeline, "biometric_algorithm", BioAlgorithmCheckpointWrapper, - ) - and not isinstance_nested(pipeline, "biometric_algorithm", BioAlgorithmLegacy) - ): + if checkpoint and not is_checkpointed(pipeline): pipeline = checkpoint_vanilla_biometrics(pipeline, output) # Patching the pipeline in case of ZNorm and checkpointing it -- GitLab