From 8773ff98caab11c649ccbe661c7f4fd6a8b0d2bc Mon Sep 17 00:00:00 2001 From: Tiago Freitas Pereira <tiagofrepereira@gmail.com> Date: Tue, 23 Jun 2020 17:23:45 +0200 Subject: [PATCH] Improved the checkpointing --- bob/bio/base/script/vanilla_biometrics.py | 14 +++++++---- .../base/script/vanilla_biometrics_ztnorm.py | 23 +++++++++++++------ 2 files changed, 26 insertions(+), 11 deletions(-) diff --git a/bob/bio/base/script/vanilla_biometrics.py b/bob/bio/base/script/vanilla_biometrics.py index 88054034..4e7e5b01 100644 --- a/bob/bio/base/script/vanilla_biometrics.py +++ b/bob/bio/base/script/vanilla_biometrics.py @@ -26,6 +26,7 @@ from bob.bio.base.pipelines.vanilla_biometrics import ( dask_get_partition_size, FourColumnsScoreWriter, CSVScoreWriter, + BioAlgorithmLegacy ) from dask.delayed import Delayed import pkg_resources @@ -213,10 +214,15 @@ def vanilla_biometrics( pipeline.score_writer = FourColumnsScoreWriter(os.path.join(output, "./tmp")) # Check if it's already checkpointed - if checkpoint and not isinstance_nested( - pipeline.biometric_algorithm, - "biometric_algorithm", - BioAlgorithmCheckpointWrapper, + if checkpoint and ( + not isinstance_nested( + pipeline, + "biometric_algorithm", + BioAlgorithmCheckpointWrapper, + ) + or not isinstance_nested( + pipeline, "biometric_algorithm", BioAlgorithmLegacy + ) ): pipeline = checkpoint_vanilla_biometrics(pipeline, output) diff --git a/bob/bio/base/script/vanilla_biometrics_ztnorm.py b/bob/bio/base/script/vanilla_biometrics_ztnorm.py index ce00590a..bd849a51 100644 --- a/bob/bio/base/script/vanilla_biometrics_ztnorm.py +++ b/bob/bio/base/script/vanilla_biometrics_ztnorm.py @@ -29,12 +29,17 @@ from bob.bio.base.pipelines.vanilla_biometrics import ( dask_get_partition_size, FourColumnsScoreWriter, CSVScoreWriter, + BioAlgorithmLegacy, ) from dask.delayed import Delayed from bob.bio.base.utils import get_resource_filename from bob.extension.config import load as chain_load from bob.pipelines.utils import isinstance_nested -from .vanilla_biometrics import compute_scores, post_process_scores, load_database_pipeline +from .vanilla_biometrics import ( + compute_scores, + post_process_scores, + load_database_pipeline, +) import copy logger = logging.getLogger(__name__) @@ -67,11 +72,14 @@ EPILOG = """\b entry_point_group="bob.pipelines.config", cls=ConfigCommand, epilog=EPILOG, ) @click.option( - "--pipeline", "-p", required=True, help="Vanilla biometrics pipeline composed of a scikit-learn Pipeline and a BioAlgorithm", + "--pipeline", + "-p", + required=True, + help="Vanilla biometrics pipeline composed of a scikit-learn Pipeline and a BioAlgorithm", ) @click.option( "--database", - "-d", + "-d", help="Biometric Database connector (class that implements the methods: `background_model_samples`, `references` and `probes`)", ) @click.option( @@ -211,10 +219,11 @@ def vanilla_biometrics_ztnorm( pipeline.score_writer = FourColumnsScoreWriter(os.path.join(output, "./tmp")) # Check if it's already checkpointed - if checkpoint and not isinstance_nested( - pipeline.biometric_algorithm, - "biometric_algorithm", - BioAlgorithmCheckpointWrapper, + if checkpoint and ( + not isinstance_nested( + pipeline, "biometric_algorithm", BioAlgorithmCheckpointWrapper, + ) + and not isinstance_nested(pipeline, "biometric_algorithm", BioAlgorithmLegacy) ): pipeline = checkpoint_vanilla_biometrics(pipeline, output) -- GitLab