Skip to content
Snippets Groups Projects
Commit 4b238a25 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Comparison app

parent 9b7a8633
No related branches found
No related tags found
2 merge requests!192Redoing baselines,!180[dask] Preparing bob.bio.base for dask pipelines
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
"""Executes biometric pipeline"""
import click
from bob.extension.scripts.click_helper import (
verbosity_option,
ResourceOption,
ConfigCommand,
)
import bob.io.base
import bob.io.image
import logging
import os
import itertools
import dask.bag
# from bob.bio.base.pipelines.vanilla_biometrics import (
# VanillaBiometricsPipeline,
# BioAlgorithmCheckpointWrapper,
# BioAlgorithmDaskWrapper,
# checkpoint_vanilla_biometrics,
# dask_vanilla_biometrics,
# dask_get_partition_size,
# FourColumnsScoreWriter,
# CSVScoreWriter,
# )
# from dask.delayed import Delayed
# import pkg_resources
from bob.extension.config import load as chain_load
from bob.pipelines.utils import isinstance_nested
from bob.bio.base.utils import get_resource_filename
from .vanilla_biometrics import compute_scores, load_database_pipeline
from bob.pipelines import Sample, SampleSet
logger = logging.getLogger(__name__)
EPILOG = """\b
Command line examples\n
-----------------------
"""
@click.command(epilog=EPILOG)
@click.argument("samples", nargs=-1)
@click.option(
"--pipeline",
"-p",
required=True,
cls=ResourceOption,
entry_point_group="bob.bio.pipeline",
help="Vanilla biometrics pipeline composed of a scikit-learn Pipeline and a BioAlgorithm",
)
@click.option(
"--dask-client",
"-l",
required=False,
cls=ResourceOption,
help="Dask client for the execution of the pipeline.",
)
@verbosity_option(cls=ResourceOption)
def compare_samples(
samples, pipeline, dask_client, **kwargs,
):
"""Compare several samples all vs all using one vanilla biometrics pipeline
"""
if len(samples) == 1:
raise ValueError(
"It's necessary to have at least two samples for the comparison"
)
sample_sets = [
SampleSet([Sample(bob.io.base.load(s), key=str(i))], key=str(i))
for i, s in enumerate(samples)
]
import ipdb; ipdb.set_trace()
for e in sample_sets:
biometric_references = pipeline.create_biometric_reference([e])
scores = pipeline.compute_scores(biometric_references, sample_sets)
pass
# B = bob.io.base.load(p)
# pipeline.biometric_algorithm
if dask_client is not None:
dask_client.shutdown()
......@@ -74,16 +74,32 @@ def post_process_scores(pipeline, scores, path):
return pipeline.post_process(writed_scores, path)
def load_database_pipeline(database, pipeline):
# It's necessary to chain load 2 resources together
pipeline_config = get_resource_filename(pipeline, "bob.bio.pipeline")
if database is None:
vanilla_pipeline = chain_load([pipeline_config])
if hasattr(vanilla_pipeline, "database"):
return vanilla_pipeline.database, vanilla_pipeline.pipeline
else:
raise ValueError("Database was not set. Please look in `bob bio pipelines vanilla-biometrics --help` for more information")
else:
database_config = get_resource_filename(database, "bob.bio.database")
vanilla_pipeline = chain_load([database_config, pipeline_config])
return vanilla_pipeline.database, vanilla_pipeline.pipeline
@click.command(
entry_point_group="bob.bio.pipeline.config", cls=ConfigCommand, epilog=EPILOG,
)
@click.option(
"--pipeline", "-p", required=True, help="Vanilla biometrics pipeline",
"--pipeline", "-p", required=True, help="Vanilla biometrics pipeline composed of a scikit-learn Pipeline and a BioAlgorithm",
)
@click.option(
"--database",
"-d",
required=True,
required=False,
help="Biometric Database connector (class that implements the methods: `background_model_samples`, `references` and `probes`)",
)
@click.option(
......@@ -185,15 +201,12 @@ def vanilla_biometrics(
if not os.path.exists(output):
os.makedirs(output, exist_ok=True)
# It's necessary to chain load 2 resources together
pipeline_config = get_resource_filename(pipeline, "bob.bio.pipeline")
database_config = get_resource_filename(database, "bob.bio.database")
vanilla_pipeline = chain_load([database_config, pipeline_config])
dask_client = chain_load([dask_client]).dask_client
# Picking the resources
database = vanilla_pipeline.database
pipeline = vanilla_pipeline.pipeline
database, pipeline = load_database_pipeline(database, pipeline)
if dask_client is not None:
dask_client = chain_load([dask_client]).dask_client
if write_metadata_scores:
pipeline.score_writer = CSVScoreWriter(os.path.join(output, "./tmp"))
else:
......
......@@ -34,7 +34,7 @@ from dask.delayed import Delayed
from bob.bio.base.utils import get_resource_filename
from bob.extension.config import load as chain_load
from bob.pipelines.utils import isinstance_nested
from .vanilla_biometrics import compute_scores, post_process_scores
from .vanilla_biometrics import compute_scores, post_process_scores, load_database_pipeline
import copy
logger = logging.getLogger(__name__)
......@@ -67,7 +67,7 @@ EPILOG = """\b
entry_point_group="bob.pipelines.config", cls=ConfigCommand, epilog=EPILOG,
)
@click.option(
"--pipeline", "-p", required=True, help="An entry point or a configuration file containing a `VanillaBiometricsPipeline`.",
"--pipeline", "-p", required=True, help="Vanilla biometrics pipeline composed of a scikit-learn Pipeline and a BioAlgorithm",
)
@click.option(
"--database",
......@@ -199,16 +199,12 @@ def vanilla_biometrics_ztnorm(
os.makedirs(output, exist_ok=True)
# It's necessary to chain load 2 resources together
pipeline_config = get_resource_filename(pipeline, "bob.bio.pipeline")
database_config = get_resource_filename(database, "bob.bio.database")
vanilla_pipeline = chain_load([database_config, pipeline_config])
# Picking the resources
database, pipeline = load_database_pipeline(database, pipeline)
if dask_client is not None:
dask_client = chain_load([dask_client]).dask_client
# Picking the resources
database = vanilla_pipeline.database
pipeline = vanilla_pipeline.pipeline
if write_metadata_scores:
pipeline.score_writer = CSVScoreWriter(os.path.join(output, "./tmp"))
else:
......
......@@ -133,6 +133,7 @@ setup(
'baseline = bob.bio.base.script.baseline:baseline',
'sort = bob.bio.base.script.sort:sort',
'pipelines = bob.bio.base.script.pipelines:pipelines',
'compare-samples = bob.bio.base.script.compare_samples:compare_samples',
],
# annotators
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment