Commit 21eed431 authored by Laurent COLBOIS

Refactor and add script for cross database evaluation

parent 4ed66d1b
Branches cross-db
Pipeline #54967 passed with stage in 15 minutes and 49 seconds
@@ -32,6 +32,22 @@ def post_process_scores(pipeline, scores, path):
     return pipeline.post_process(written_scores, path)
 
 
+def setup_writing(pipeline, output, write_metadata_scores, checkpoint, hash_fn):
+    if not os.path.exists(output):
+        os.makedirs(output, exist_ok=True)
+
+    if write_metadata_scores:
+        pipeline.score_writer = CSVScoreWriter(os.path.join(output, "./tmp"))
+    else:
+        pipeline.score_writer = FourColumnsScoreWriter(os.path.join(output, "./tmp"))
+
+    # Check if it's already checkpointed
+    if checkpoint and not is_checkpointed(pipeline):
+        pipeline = checkpoint_vanilla_biometrics(pipeline, output, hash_fn=hash_fn)
+
+    return pipeline
+
+
 def execute_vanilla_biometrics(
     pipeline,
     database,
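
For orientation, a minimal sketch of how the new setup_writing helper could be called on its own. This is not part of the commit; my_pipeline, my_database and the output path are placeholders:

    # Hypothetical direct use of setup_writing; all names are placeholders.
    pipeline = setup_writing(
        my_pipeline,
        "./results",   # output directory
        True,          # write_metadata_scores -> CSVScoreWriter
        True,          # checkpoint every pipeline step
        hash_fn=getattr(my_database, "hash_fn", None),
    )

The getattr call is equivalent to the `database.hash_fn if hasattr(database, "hash_fn") else None` idiom used in the diff below.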
@@ -80,18 +96,13 @@ def execute_vanilla_biometrics(
     checkpoint: bool
         Whether checkpoint files will be created for every step of the pipelines.
     """
-    if not os.path.exists(output):
-        os.makedirs(output, exist_ok=True)
-
-    if write_metadata_scores:
-        pipeline.score_writer = CSVScoreWriter(os.path.join(output, "./tmp"))
-    else:
-        pipeline.score_writer = FourColumnsScoreWriter(os.path.join(output, "./tmp"))
-
-    # Check if it's already checkpointed
-    if checkpoint and not is_checkpointed(pipeline):
-        hash_fn = database.hash_fn if hasattr(database, "hash_fn") else None
-        pipeline = checkpoint_vanilla_biometrics(pipeline, output, hash_fn=hash_fn)
+    pipeline = setup_writing(
+        pipeline,
+        output,
+        write_metadata_scores,
+        checkpoint,
+        hash_fn=database.hash_fn if hasattr(database, "hash_fn") else None,
+    )
 
     # Load the background model samples only if the transformer requires fitting
     if all([is_estimator_stateless(step) for step in pipeline.transformer]):
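
The comment above states the intent: background samples are only needed when at least one transformer step is stateful. A paraphrase of that guard, assuming the branch elided by the hunk simply skips loading (database follows the Database API referenced in the docstrings further down):

    # Illustration only, not the literal diff content:
    background_model_samples = (
        []
        if all(is_estimator_stateless(step) for step in pipeline.transformer)
        else database.background_model_samples()
    )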
@@ -114,24 +125,6 @@ def execute_vanilla_biometrics(
             )
             continue
-        if dask_client is not None and not isinstance_nested(
-            pipeline.biometric_algorithm, "biometric_algorithm", BioAlgorithmDaskWrapper
-        ):
-            # Scaling up
-            if dask_n_workers is not None and not isinstance(dask_client, str):
-                dask_client.cluster.scale(dask_n_workers)
-
-            n_objects = max(
-                len(background_model_samples), len(biometric_references), len(probes)
-            )
-            partition_size = None
-            if not isinstance(dask_client, str):
-                partition_size = dask_get_partition_size(dask_client.cluster, n_objects)
-            if dask_partition_size is not None:
-                partition_size = dask_partition_size
-
-            pipeline = dask_vanilla_biometrics(pipeline, partition_size=partition_size,)
-
         logger.info(f"Running vanilla biometrics for group {group}")
 
         allow_scoring_with_all_biometric_references = (
             database.allow_scoring_with_all_biometric_references
@@ -139,15 +132,159 @@
             else False
         )
 
-        result = pipeline(
+        execute_single_vanilla_biometrics(
+            pipeline,
             background_model_samples,
             biometric_references,
             probes,
-            allow_scoring_with_all_biometric_references=allow_scoring_with_all_biometric_references,
+            dask_client,
+            score_file_name,
+            allow_scoring_with_all_biometric_references,
+            dask_partition_size,
+            dask_n_workers,
+            **kwargs,
         )
 
-        post_processed_scores = post_process_scores(pipeline, result, score_file_name)
-        _ = compute_scores(post_processed_scores, dask_client)
+
+
+def execute_cross_db_vanilla_biometrics(
+    pipeline,
+    biometric_references,
+    probes,
+    dask_client,
+    output,
+    score_file_name,
+    write_metadata_scores,
+    checkpoint,
+    allow_scoring_with_all_biometric_references,
+    dask_partition_size,
+    dask_n_workers,
+    background_model_samples=[],
+    hash_fn=None,
+    **kwargs,
+):
+    """
+    Executes a vanilla-biometrics evaluation pipeline. Instead of working with a
+    :py:class:`bob.bio.base.pipelines.vanilla_biometrics.abstract_class.Database`,
+    this function takes independent ``background_model_samples``,
+    ``biometric_references`` and ``probes``, which makes cross-database
+    experiments possible.
+
+    Parameters
+    ----------
+
+    pipeline: Instance of :py:class:`bob.bio.base.pipelines.vanilla_biometrics.VanillaBiometricsPipeline`
+        A constructed vanilla-biometrics pipeline. It is set up for score writing
+        internally via the
+        :py:func:`bob.bio.base.pipelines.vanilla_biometrics.vanilla_biometrics.setup_writing` function.
+
+    background_model_samples: List of :py:class:`bob.pipelines.samples.SampleSet` used for training the model
+        (such as the output of
+        :py:meth:`bob.bio.base.pipelines.vanilla_biometrics.abstract_class.Database.background_model_samples`).
+
+    biometric_references: List of :py:class:`bob.pipelines.samples.SampleSet` used for enrolling the biometric references
+        (such as the output of
+        :py:meth:`bob.bio.base.pipelines.vanilla_biometrics.abstract_class.Database.references`).
+
+    probes: List of :py:class:`bob.pipelines.samples.SampleSet` used to probe the model
+        (such as the output of
+        :py:meth:`bob.bio.base.pipelines.vanilla_biometrics.abstract_class.Database.probes`).
+
+    dask_client: instance of :py:class:`dask.distributed.Client` or ``None``
+        A Dask client instance used to run the experiment in parallel on multiple
+        machines, or locally.
+        Basic configs can be found in ``bob.pipelines.config.distributed``.
+
+    output: str
+        Path where the results and checkpoints will be saved to.
+
+    score_file_name: str
+        Name of the output scores file.
+
+    write_metadata_scores: bool
+        Use the CSVScoreWriter instead of the FourColumnsScoreWriter when True.
+
+    checkpoint: bool
+        Whether checkpoint files will be created for every step of the pipelines.
+    """
+    pipeline = setup_writing(
+        pipeline, output, write_metadata_scores, checkpoint, hash_fn
+    )
+
+    execute_single_vanilla_biometrics(
+        pipeline,
+        background_model_samples,
+        biometric_references,
+        probes,
+        dask_client,
+        score_file_name,
+        allow_scoring_with_all_biometric_references,
+        dask_partition_size,
+        dask_n_workers,
+        **kwargs,
+    )
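
A hedged usage sketch of the new cross-database entry point (not part of the commit): enroll references from one database and probe with samples from another. database_a, database_b and the paths are placeholders, assuming the Database API referenced in the docstring above:

    # Hypothetical cross-database experiment; all names are placeholders.
    execute_cross_db_vanilla_biometrics(
        pipeline,
        biometric_references=database_a.references(group="dev"),
        probes=database_b.probes(group="dev"),
        dask_client=None,                 # run locally
        output="./cross_db_results",
        score_file_name="./cross_db_results/scores-dev",
        write_metadata_scores=True,
        checkpoint=True,
        allow_scoring_with_all_biometric_references=False,
        dask_partition_size=None,
        dask_n_workers=None,
        background_model_samples=database_a.background_model_samples(),
    )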
+
+
+def execute_single_vanilla_biometrics(
+    pipeline,
+    background_model_samples,
+    biometric_references,
+    probes,
+    dask_client,
+    score_file_name,
+    allow_scoring_with_all_biometric_references,
+    dask_partition_size,
+    dask_n_workers,
+    **kwargs,
+):
+    """
+    Utility function that executes a single vanilla-biometrics evaluation pipeline.
+
+    Parameters
+    ----------
+
+    pipeline: Instance of :py:class:`bob.bio.base.pipelines.vanilla_biometrics.VanillaBiometricsPipeline`
+        A constructed vanilla-biometrics pipeline. The pipeline should already have
+        been set up for writing using the
+        :py:func:`bob.bio.base.pipelines.vanilla_biometrics.vanilla_biometrics.setup_writing` function.
+
+    background_model_samples: List of :py:class:`bob.pipelines.samples.SampleSet` used for training the model
+        (such as the output of
+        :py:meth:`bob.bio.base.pipelines.vanilla_biometrics.abstract_class.Database.background_model_samples`).
+
+    biometric_references: List of :py:class:`bob.pipelines.samples.SampleSet` used for enrolling the biometric references
+        (such as the output of
+        :py:meth:`bob.bio.base.pipelines.vanilla_biometrics.abstract_class.Database.references`).
+
+    probes: List of :py:class:`bob.pipelines.samples.SampleSet` used to probe the model
+        (such as the output of
+        :py:meth:`bob.bio.base.pipelines.vanilla_biometrics.abstract_class.Database.probes`).
+
+    dask_client: instance of :py:class:`dask.distributed.Client` or ``None``
+        A Dask client instance used to run the experiment in parallel on multiple
+        machines, or locally.
+        Basic configs can be found in ``bob.pipelines.config.distributed``.
+
+    score_file_name: str
+        Name of the output scores file.
+    """
+    if dask_client is not None and not isinstance_nested(
+        pipeline.biometric_algorithm, "biometric_algorithm", BioAlgorithmDaskWrapper
+    ):
+        # Scaling up
+        if dask_n_workers is not None and not isinstance(dask_client, str):
+            dask_client.cluster.scale(dask_n_workers)
+
+        n_objects = max(
+            len(background_model_samples), len(biometric_references), len(probes)
+        )
+        partition_size = None
+        if not isinstance(dask_client, str):
+            partition_size = dask_get_partition_size(dask_client.cluster, n_objects)
+        if dask_partition_size is not None:
+            partition_size = dask_partition_size
+
+        pipeline = dask_vanilla_biometrics(pipeline, partition_size=partition_size)
+
+    result = pipeline(
+        background_model_samples,
+        biometric_references,
+        probes,
+        allow_scoring_with_all_biometric_references=allow_scoring_with_all_biometric_references,
+    )
+
+    post_processed_scores = post_process_scores(pipeline, result, score_file_name)
+    _ = compute_scores(post_processed_scores, dask_client)
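
As the block above shows, when a Dask client is supplied and the pipeline is not already wrapped in BioAlgorithmDaskWrapper, the partition size is derived from the cluster unless dask_partition_size overrides it. A hedged sketch of a distributed run with a fixed partition size (all names are placeholders):

    from dask.distributed import Client, LocalCluster

    # Local cluster for illustration; ready-made configurations live in
    # bob.pipelines.config.distributed, per the docstring above.
    dask_client = Client(LocalCluster(n_workers=4))

    execute_single_vanilla_biometrics(
        pipeline,                  # already passed through setup_writing
        background_model_samples,
        biometric_references,
        probes,
        dask_client,
        "./results/scores-dev",    # score_file_name
        False,                     # allow_scoring_with_all_biometric_references
        200,                       # dask_partition_size overrides the derived size
        None,                      # dask_n_workers
    )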
+
+
 def execute_vanilla_biometrics_ztnorm(
......