Commit decb7be9 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira

Merge branch 'hash_string' into 'master'

Improvements on CheckpointWrapper

See merge request !212
parents 10b5f210 f26369f2
Pipeline #45940 passed with stages in 12 minutes and 11 seconds
@@ -88,7 +88,8 @@ def execute_vanilla_biometrics(
     # Check if it's already checkpointed
     if checkpoint and not is_checkpointed(pipeline):
-        pipeline = checkpoint_vanilla_biometrics(pipeline, output)
+        hash_fn = database.hash_fn if hasattr(database, "hash_fn") else None
+        pipeline = checkpoint_vanilla_biometrics(pipeline, output, hash_fn=hash_fn)
     background_model_samples = database.background_model_samples()
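For context, the `hasattr` check above means a database opts in simply by exposing a `hash_fn` attribute. A minimal sketch, assuming a hypothetical `CustomDatabase` (the class name and the two-hex-character bucketing are illustrative, not part of this change):

import hashlib

def md5_bucket(key):
    # Map sample.key to one of 256 buckets ("00".."ff") so that no
    # single checkpoint directory accumulates every file.
    return hashlib.md5(key.encode()).hexdigest()[:2]

class CustomDatabase:
    # Databases exposing this attribute get it forwarded to
    # checkpoint_vanilla_biometrics(..., hash_fn=...) by this change.
    hash_fn = staticmethod(md5_bucket)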
@@ -105,10 +106,8 @@ def execute_vanilla_biometrics(
     n_objects = max(
         len(background_model_samples), len(biometric_references), len(probes)
     )
-    pipeline = dask_vanilla_biometrics(
-        pipeline,
-        partition_size=dask_get_partition_size(dask_client.cluster, n_objects),
-    )
+    partition_size = dask_get_partition_size(dask_client.cluster, n_objects)
+    pipeline = dask_vanilla_biometrics(pipeline, partition_size=partition_size)
     logger.info(f"Running vanilla biometrics for group {group}")
     allow_scoring_with_all_biometric_references = (
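The extracted variable makes the sizing logic easier to follow. As a quick check of the formula visible in the next hunk (`n_objects // (max_jobs * 2)` when `n_objects > max_jobs`, else 1), with made-up numbers:

def partition_size_example(n_objects, max_jobs):
    # Same formula as dask_get_partition_size: size each partition so
    # that roughly two partitions exist per job, never below 1 object.
    return n_objects // (max_jobs * 2) if n_objects > max_jobs else 1

assert partition_size_example(1000, 20) == 25  # 1000 // 40
assert partition_size_example(10, 20) == 1     # fewer objects than jobs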
@@ -311,7 +311,9 @@ def dask_get_partition_size(cluster, n_objects):
     return n_objects // (max_jobs * 2) if n_objects > max_jobs else 1


-def checkpoint_vanilla_biometrics(pipeline, base_dir, biometric_algorithm_dir=None):
+def checkpoint_vanilla_biometrics(
+    pipeline, base_dir, biometric_algorithm_dir=None, hash_fn=None
+):
     """
     Given a :any:`bob.bio.base.pipelines.vanilla_biometrics.VanillaBiometricsPipeline`, wraps :any:`bob.bio.base.pipelines.vanilla_biometrics.VanillaBiometricsPipeline` and
     :any:`bob.bio.base.pipelines.vanilla_biometrics.BioAlgorithm` to be checkpointed
@@ -331,7 +333,12 @@ def checkpoint_vanilla_biometrics(pipeline, base_dir, biometric_algorithm_dir=No
         This is useful when it's suitable to have the transformed data path, and biometric references and scores
         in different paths.
+    hash_fn
+        Pointer to a hash function. This hash function maps
+        `sample.key` to a hash code, and that hash code becomes the
+        relative directory under which the single `sample` is
+        checkpointed. This is useful to avoid directories that
+        contain more than a certain number of files.
     """
     sk_pipeline = pipeline.transformer
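Concretely, the docstring implies an on-disk layout like the following. This is a minimal sketch only; the helper name, the `.h5` extension, and the exact join order are assumptions for illustration:

import os

def checkpointed_path(features_dir, key, hash_fn=None):
    # Hypothetical helper showing where a sample would land: with a
    # hash_fn, samples nest one level deeper under hash_fn(key).
    if hash_fn is None:
        return os.path.join(features_dir, key + ".h5")
    return os.path.join(features_dir, hash_fn(key), key + ".h5")

# e.g. checkpointed_path("out/extractor", "subj1-sample3", lambda k: k[:2])
# -> "out/extractor/su/subj1-sample3.h5"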
@@ -372,6 +379,7 @@ def checkpoint_vanilla_biometrics(pipeline, base_dir, biometric_algorithm_dir=No
             features_dir=os.path.join(base_dir, name),
             load_func=load_func,
             save_func=save_func,
+            hash_fn=hash_fn,
         )
         sk_pipeline.steps[i] = (name, wraped_estimator)
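Putting the pieces together, a hedged usage sketch of the extended signature (the output path and the two-character bucketing lambda are placeholders, not values from this merge request):

# Checkpoint every step of an existing vanilla-biometrics pipeline,
# sharding each step's files by the first two characters of sample.key.
pipeline = checkpoint_vanilla_biometrics(
    pipeline,
    base_dir="/path/to/output",
    hash_fn=lambda key: key[:2],
)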
@@ -24,7 +24,7 @@ from bob.bio.base.pipelines.vanilla_biometrics import (
     dask_get_partition_size,
     FourColumnsScoreWriter,
     CSVScoreWriter,
-    is_checkpointed
+    is_checkpointed,
 )
 from bob.pipelines.utils import isinstance_nested
 from .vanilla_biometrics import (
@@ -228,7 +228,8 @@ def vanilla_biometrics_ztnorm(
     # Check if it's already checkpointed
     if checkpoint and not is_checkpointed(pipeline):
-        pipeline = checkpoint_vanilla_biometrics(pipeline, output)
+        hash_fn = database.hash_fn if hasattr(database, "hash_fn") else None
+        pipeline = checkpoint_vanilla_biometrics(pipeline, output, hash_fn=hash_fn)
     # Patching the pipeline in case of ZNorm and checkpointing it
     pipeline = ZTNormPipeline(pipeline)