Skip to content
Snippets Groups Projects
Commit 45b7f66c authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Created the function is_checkpointed

parent 4895957a
No related branches found
No related tags found
1 merge request!180[dask] Preparing bob.bio.base for dask pipelines
Pipeline #41504 failed
...@@ -8,6 +8,7 @@ from .wrappers import ( ...@@ -8,6 +8,7 @@ from .wrappers import (
dask_vanilla_biometrics, dask_vanilla_biometrics,
checkpoint_vanilla_biometrics, checkpoint_vanilla_biometrics,
dask_get_partition_size, dask_get_partition_size,
is_checkpointed
) )
from .zt_norm import ZTNormPipeline, ZTNormDaskWrapper, ZTNormCheckpointWrapper from .zt_norm import ZTNormPipeline, ZTNormDaskWrapper, ZTNormCheckpointWrapper
......
...@@ -26,7 +26,8 @@ from bob.bio.base.pipelines.vanilla_biometrics import ( ...@@ -26,7 +26,8 @@ from bob.bio.base.pipelines.vanilla_biometrics import (
dask_get_partition_size, dask_get_partition_size,
FourColumnsScoreWriter, FourColumnsScoreWriter,
CSVScoreWriter, CSVScoreWriter,
BioAlgorithmLegacy BioAlgorithmLegacy,
is_checkpointed
) )
from dask.delayed import Delayed from dask.delayed import Delayed
import pkg_resources import pkg_resources
...@@ -214,19 +215,11 @@ def vanilla_biometrics( ...@@ -214,19 +215,11 @@ def vanilla_biometrics(
pipeline.score_writer = FourColumnsScoreWriter(os.path.join(output, "./tmp")) pipeline.score_writer = FourColumnsScoreWriter(os.path.join(output, "./tmp"))
# Check if it's already checkpointed # Check if it's already checkpointed
if checkpoint and ( if checkpoint and not is_checkpointed(pipeline):
not isinstance_nested(
pipeline,
"biometric_algorithm",
BioAlgorithmCheckpointWrapper,
)
and not isinstance_nested(
pipeline, "biometric_algorithm", BioAlgorithmLegacy
)
):
pipeline = checkpoint_vanilla_biometrics(pipeline, output) pipeline = checkpoint_vanilla_biometrics(pipeline, output)
background_model_samples = database.background_model_samples() background_model_samples = database.background_model_samples()
for group in groups: for group in groups:
score_file_name = os.path.join(output, f"scores-{group}") score_file_name = os.path.join(output, f"scores-{group}")
...@@ -236,6 +229,7 @@ def vanilla_biometrics( ...@@ -236,6 +229,7 @@ def vanilla_biometrics(
if dask_client is not None and not isinstance_nested( if dask_client is not None and not isinstance_nested(
pipeline.biometric_algorithm, "biometric_algorithm", BioAlgorithmDaskWrapper pipeline.biometric_algorithm, "biometric_algorithm", BioAlgorithmDaskWrapper
): ):
n_objects = max( n_objects = max(
len(background_model_samples), len(biometric_references), len(probes) len(background_model_samples), len(biometric_references), len(probes)
) )
...@@ -258,6 +252,7 @@ def vanilla_biometrics( ...@@ -258,6 +252,7 @@ def vanilla_biometrics(
allow_scoring_with_all_biometric_references=allow_scoring_with_all_biometric_references, allow_scoring_with_all_biometric_references=allow_scoring_with_all_biometric_references,
) )
post_processed_scores = post_process_scores(pipeline, result, score_file_name) post_processed_scores = post_process_scores(pipeline, result, score_file_name)
_ = compute_scores(post_processed_scores, dask_client) _ = compute_scores(post_processed_scores, dask_client)
......
...@@ -30,6 +30,7 @@ from bob.bio.base.pipelines.vanilla_biometrics import ( ...@@ -30,6 +30,7 @@ from bob.bio.base.pipelines.vanilla_biometrics import (
FourColumnsScoreWriter, FourColumnsScoreWriter,
CSVScoreWriter, CSVScoreWriter,
BioAlgorithmLegacy, BioAlgorithmLegacy,
is_checkpointed
) )
from dask.delayed import Delayed from dask.delayed import Delayed
from bob.bio.base.utils import get_resource_filename from bob.bio.base.utils import get_resource_filename
...@@ -219,12 +220,7 @@ def vanilla_biometrics_ztnorm( ...@@ -219,12 +220,7 @@ def vanilla_biometrics_ztnorm(
pipeline.score_writer = FourColumnsScoreWriter(os.path.join(output, "./tmp")) pipeline.score_writer = FourColumnsScoreWriter(os.path.join(output, "./tmp"))
# Check if it's already checkpointed # Check if it's already checkpointed
if checkpoint and ( if checkpoint and not is_checkpointed(pipeline):
not isinstance_nested(
pipeline, "biometric_algorithm", BioAlgorithmCheckpointWrapper,
)
and not isinstance_nested(pipeline, "biometric_algorithm", BioAlgorithmLegacy)
):
pipeline = checkpoint_vanilla_biometrics(pipeline, output) pipeline = checkpoint_vanilla_biometrics(pipeline, output)
# Patching the pipeline in case of ZNorm and checkpointing it # Patching the pipeline in case of ZNorm and checkpointing it
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment