From ebc2cb4c0e70f579ef5f9c50419b7177c92197eb Mon Sep 17 00:00:00 2001 From: Tiago Freitas Pereira <tiagofrepereira@gmail.com> Date: Fri, 1 May 2020 19:37:41 +0200 Subject: [PATCH] Patched with new VanillaBiometricsPipeline --- bob/bio/base/script/vanilla_biometrics.py | 77 ++++++++++++----------- 1 file changed, 41 insertions(+), 36 deletions(-) diff --git a/bob/bio/base/script/vanilla_biometrics.py b/bob/bio/base/script/vanilla_biometrics.py index ddcd841d..b5d76595 100644 --- a/bob/bio/base/script/vanilla_biometrics.py +++ b/bob/bio/base/script/vanilla_biometrics.py @@ -14,6 +14,14 @@ from bob.extension.scripts.click_helper import ( ) import logging +import os +import itertools +import dask.bag +from bob.bio.base.pipelines.vanilla_biometrics import ( + VanillaBiometricsPipeline, + BioAlgorithmCheckpointWrapper, +) + logger = logging.getLogger(__name__) @@ -96,9 +104,7 @@ TODO: Work out this help help="Name of output directory", ) @verbosity_option(cls=ResourceOption) -def vanilla_biometrics( - pipeline, database, dask_client, groups, output, **kwargs -): +def vanilla_biometrics(pipeline, database, dask_client, groups, output, **kwargs): """Runs the simplest biometrics pipeline. Such pipeline consists into three sub-pipelines. @@ -143,43 +149,42 @@ def vanilla_biometrics( """ - from bob.bio.base.pipelines.vanilla_biometrics.pipeline import VanillaBiometrics - import dask.bag - import itertools - import os - from bob.pipelines.sample import Sample, DelayedSample - if not os.path.exists(output): - os.makedirs(output, exist_ok=True) + os.makedirs(output, exist_ok=True) for group in groups: - with open(os.path.join(output, f"scores-{group}"), "w") as f: - biometric_references = database.references(group=group) - - logger.info(f"Running vanilla biometrics for group {group}") - - allow_scoring_with_all_biometric_references = ( - database.allow_scoring_with_all_biometric_references - if hasattr(database, "allow_scoring_with_all_biometric_references") - else False - ) - - result = pipeline(database.background_model_samples(), - biometric_references, - database.probes(group=group), - allow_scoring_with_all_biometric_references=allow_scoring_with_all_biometric_references - ) - - if isinstance(result, dask.bag.core.Bag): - if dask_client is not None: - result = result.compute(scheduler=dask_client) - else: - logger.warning( - "`dask_client` not set. Your pipeline will run locally" - ) - result = result.compute(scheduler="single-threaded") - + score_file_name = os.path.join(output, f"scores-{group}.txt") + biometric_references = database.references(group=group) + + logger.info(f"Running vanilla biometrics for group {group}") + + allow_scoring_with_all_biometric_references = ( + database.allow_scoring_with_all_biometric_references + if hasattr(database, "allow_scoring_with_all_biometric_references") + else False + ) + + result = pipeline( + database.background_model_samples(), + biometric_references, + database.probes(group=group), + allow_scoring_with_all_biometric_references=allow_scoring_with_all_biometric_references, + ) + + if isinstance(result, dask.bag.core.Bag): + if dask_client is not None: + result = result.compute(scheduler=dask_client) + else: + logger.warning( + "`dask_client` not set. Your pipeline will run locally" + ) + result = result.compute(scheduler="single-threaded") + + # Check if there's a score writer hooked in + if isinstance(pipeline.biometric_algorithm, BioAlgorithmCheckpointWrapper): + pipeline.biometric_algorithm.score_writer.concatenate_write_scores(result, score_file_name) + else: # Flatting out the list result = itertools.chain(*result) for probe in result: -- GitLab