Skip to content
Snippets Groups Projects
Commit ebc2cb4c authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Patched with new VanillaBiometricsPipeline

parent ba36a5cf
No related branches found
No related tags found
2 merge requests!185Wrappers and aggregators,!180[dask] Preparing bob.bio.base for dask pipelines
Pipeline #39605 failed
...@@ -14,6 +14,14 @@ from bob.extension.scripts.click_helper import ( ...@@ -14,6 +14,14 @@ from bob.extension.scripts.click_helper import (
) )
import logging import logging
import os
import itertools
import dask.bag
from bob.bio.base.pipelines.vanilla_biometrics import (
VanillaBiometricsPipeline,
BioAlgorithmCheckpointWrapper,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -96,9 +104,7 @@ TODO: Work out this help ...@@ -96,9 +104,7 @@ TODO: Work out this help
help="Name of output directory", help="Name of output directory",
) )
@verbosity_option(cls=ResourceOption) @verbosity_option(cls=ResourceOption)
def vanilla_biometrics( def vanilla_biometrics(pipeline, database, dask_client, groups, output, **kwargs):
pipeline, database, dask_client, groups, output, **kwargs
):
"""Runs the simplest biometrics pipeline. """Runs the simplest biometrics pipeline.
Such pipeline consists into three sub-pipelines. Such pipeline consists into three sub-pipelines.
...@@ -143,43 +149,42 @@ def vanilla_biometrics( ...@@ -143,43 +149,42 @@ def vanilla_biometrics(
""" """
from bob.bio.base.pipelines.vanilla_biometrics.pipeline import VanillaBiometrics
import dask.bag
import itertools
import os
from bob.pipelines.sample import Sample, DelayedSample
if not os.path.exists(output): if not os.path.exists(output):
os.makedirs(output, exist_ok=True) os.makedirs(output, exist_ok=True)
for group in groups: for group in groups:
with open(os.path.join(output, f"scores-{group}"), "w") as f: score_file_name = os.path.join(output, f"scores-{group}.txt")
biometric_references = database.references(group=group) biometric_references = database.references(group=group)
logger.info(f"Running vanilla biometrics for group {group}") logger.info(f"Running vanilla biometrics for group {group}")
allow_scoring_with_all_biometric_references = ( allow_scoring_with_all_biometric_references = (
database.allow_scoring_with_all_biometric_references database.allow_scoring_with_all_biometric_references
if hasattr(database, "allow_scoring_with_all_biometric_references") if hasattr(database, "allow_scoring_with_all_biometric_references")
else False else False
) )
result = pipeline(database.background_model_samples(), result = pipeline(
biometric_references, database.background_model_samples(),
database.probes(group=group), biometric_references,
allow_scoring_with_all_biometric_references=allow_scoring_with_all_biometric_references database.probes(group=group),
) allow_scoring_with_all_biometric_references=allow_scoring_with_all_biometric_references,
)
if isinstance(result, dask.bag.core.Bag):
if dask_client is not None: if isinstance(result, dask.bag.core.Bag):
result = result.compute(scheduler=dask_client) if dask_client is not None:
else: result = result.compute(scheduler=dask_client)
logger.warning( else:
"`dask_client` not set. Your pipeline will run locally" logger.warning(
) "`dask_client` not set. Your pipeline will run locally"
result = result.compute(scheduler="single-threaded") )
result = result.compute(scheduler="single-threaded")
# Check if there's a score writer hooked in
if isinstance(pipeline.biometric_algorithm, BioAlgorithmCheckpointWrapper):
pipeline.biometric_algorithm.score_writer.concatenate_write_scores(result, score_file_name)
else:
# Flatting out the list # Flatting out the list
result = itertools.chain(*result) result = itertools.chain(*result)
for probe in result: for probe in result:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment