Skip to content
Snippets Groups Projects
Commit 8773ff98 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Improved the checkpointing

parent 7fc43e9e
No related branches found
No related tags found
1 merge request!180[dask] Preparing bob.bio.base for dask pipelines
Pipeline #40607 failed
......@@ -26,6 +26,7 @@ from bob.bio.base.pipelines.vanilla_biometrics import (
dask_get_partition_size,
FourColumnsScoreWriter,
CSVScoreWriter,
BioAlgorithmLegacy
)
from dask.delayed import Delayed
import pkg_resources
......@@ -213,10 +214,15 @@ def vanilla_biometrics(
pipeline.score_writer = FourColumnsScoreWriter(os.path.join(output, "./tmp"))
# Check if it's already checkpointed
if checkpoint and not isinstance_nested(
pipeline.biometric_algorithm,
"biometric_algorithm",
BioAlgorithmCheckpointWrapper,
if checkpoint and (
not isinstance_nested(
pipeline,
"biometric_algorithm",
BioAlgorithmCheckpointWrapper,
)
or not isinstance_nested(
pipeline, "biometric_algorithm", BioAlgorithmLegacy
)
):
pipeline = checkpoint_vanilla_biometrics(pipeline, output)
......
......@@ -29,12 +29,17 @@ from bob.bio.base.pipelines.vanilla_biometrics import (
dask_get_partition_size,
FourColumnsScoreWriter,
CSVScoreWriter,
BioAlgorithmLegacy,
)
from dask.delayed import Delayed
from bob.bio.base.utils import get_resource_filename
from bob.extension.config import load as chain_load
from bob.pipelines.utils import isinstance_nested
from .vanilla_biometrics import compute_scores, post_process_scores, load_database_pipeline
from .vanilla_biometrics import (
compute_scores,
post_process_scores,
load_database_pipeline,
)
import copy
logger = logging.getLogger(__name__)
......@@ -67,11 +72,14 @@ EPILOG = """\b
entry_point_group="bob.pipelines.config", cls=ConfigCommand, epilog=EPILOG,
)
@click.option(
"--pipeline", "-p", required=True, help="Vanilla biometrics pipeline composed of a scikit-learn Pipeline and a BioAlgorithm",
"--pipeline",
"-p",
required=True,
help="Vanilla biometrics pipeline composed of a scikit-learn Pipeline and a BioAlgorithm",
)
@click.option(
"--database",
"-d",
"-d",
help="Biometric Database connector (class that implements the methods: `background_model_samples`, `references` and `probes`)",
)
@click.option(
......@@ -211,10 +219,11 @@ def vanilla_biometrics_ztnorm(
pipeline.score_writer = FourColumnsScoreWriter(os.path.join(output, "./tmp"))
# Check if it's already checkpointed
if checkpoint and not isinstance_nested(
pipeline.biometric_algorithm,
"biometric_algorithm",
BioAlgorithmCheckpointWrapper,
if checkpoint and (
not isinstance_nested(
pipeline, "biometric_algorithm", BioAlgorithmCheckpointWrapper,
)
and not isinstance_nested(pipeline, "biometric_algorithm", BioAlgorithmLegacy)
):
pipeline = checkpoint_vanilla_biometrics(pipeline, output)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment