diff --git a/bob/bio/base/script/vanilla_biometrics.py b/bob/bio/base/script/vanilla_biometrics.py index 880540342bc78f9d8e1a02b2a85168d83b251d7d..4e7e5b01a651429650341def09c08d0edbe7e4d3 100644 --- a/bob/bio/base/script/vanilla_biometrics.py +++ b/bob/bio/base/script/vanilla_biometrics.py @@ -26,6 +26,7 @@ from bob.bio.base.pipelines.vanilla_biometrics import ( dask_get_partition_size, FourColumnsScoreWriter, CSVScoreWriter, + BioAlgorithmLegacy ) from dask.delayed import Delayed import pkg_resources @@ -213,10 +214,15 @@ def vanilla_biometrics( pipeline.score_writer = FourColumnsScoreWriter(os.path.join(output, "./tmp")) # Check if it's already checkpointed - if checkpoint and not isinstance_nested( - pipeline.biometric_algorithm, - "biometric_algorithm", - BioAlgorithmCheckpointWrapper, + if checkpoint and ( + not isinstance_nested( + pipeline, + "biometric_algorithm", + BioAlgorithmCheckpointWrapper, + ) + or not isinstance_nested( + pipeline, "biometric_algorithm", BioAlgorithmLegacy + ) ): pipeline = checkpoint_vanilla_biometrics(pipeline, output) diff --git a/bob/bio/base/script/vanilla_biometrics_ztnorm.py b/bob/bio/base/script/vanilla_biometrics_ztnorm.py index ce00590a332d33b0e75b65b14b0cb76843891353..bd849a512993179d56234e90030a292abc82258b 100644 --- a/bob/bio/base/script/vanilla_biometrics_ztnorm.py +++ b/bob/bio/base/script/vanilla_biometrics_ztnorm.py @@ -29,12 +29,17 @@ from bob.bio.base.pipelines.vanilla_biometrics import ( dask_get_partition_size, FourColumnsScoreWriter, CSVScoreWriter, + BioAlgorithmLegacy, ) from dask.delayed import Delayed from bob.bio.base.utils import get_resource_filename from bob.extension.config import load as chain_load from bob.pipelines.utils import isinstance_nested -from .vanilla_biometrics import compute_scores, post_process_scores, load_database_pipeline +from .vanilla_biometrics import ( + compute_scores, + post_process_scores, + load_database_pipeline, +) import copy logger = logging.getLogger(__name__) @@ -67,11 +72,14 @@ EPILOG = """\b entry_point_group="bob.pipelines.config", cls=ConfigCommand, epilog=EPILOG, ) @click.option( - "--pipeline", "-p", required=True, help="Vanilla biometrics pipeline composed of a scikit-learn Pipeline and a BioAlgorithm", + "--pipeline", + "-p", + required=True, + help="Vanilla biometrics pipeline composed of a scikit-learn Pipeline and a BioAlgorithm", ) @click.option( "--database", - "-d", + "-d", help="Biometric Database connector (class that implements the methods: `background_model_samples`, `references` and `probes`)", ) @click.option( @@ -211,10 +219,11 @@ def vanilla_biometrics_ztnorm( pipeline.score_writer = FourColumnsScoreWriter(os.path.join(output, "./tmp")) # Check if it's already checkpointed - if checkpoint and not isinstance_nested( - pipeline.biometric_algorithm, - "biometric_algorithm", - BioAlgorithmCheckpointWrapper, + if checkpoint and ( + not isinstance_nested( + pipeline, "biometric_algorithm", BioAlgorithmCheckpointWrapper, + ) + and not isinstance_nested(pipeline, "biometric_algorithm", BioAlgorithmLegacy) ): pipeline = checkpoint_vanilla_biometrics(pipeline, output)