Skip to content
Snippets Groups Projects
Commit c644b503 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Comparison app. WIP

parent 05a86fa3
No related branches found
No related tags found
No related merge requests found
Pipeline #40566 passed
...@@ -19,7 +19,8 @@ import logging ...@@ -19,7 +19,8 @@ import logging
import os import os
import itertools import itertools
import dask.bag import dask.bag
#from bob.bio.base.pipelines.vanilla_biometrics import (
# from bob.bio.base.pipelines.vanilla_biometrics import (
# VanillaBiometricsPipeline, # VanillaBiometricsPipeline,
# BioAlgorithmCheckpointWrapper, # BioAlgorithmCheckpointWrapper,
# BioAlgorithmDaskWrapper, # BioAlgorithmDaskWrapper,
...@@ -28,13 +29,14 @@ import dask.bag ...@@ -28,13 +29,14 @@ import dask.bag
# dask_get_partition_size, # dask_get_partition_size,
# FourColumnsScoreWriter, # FourColumnsScoreWriter,
# CSVScoreWriter, # CSVScoreWriter,
#) # )
#from dask.delayed import Delayed # from dask.delayed import Delayed
#import pkg_resources # import pkg_resources
from bob.extension.config import load as chain_load from bob.extension.config import load as chain_load
from bob.pipelines.utils import isinstance_nested from bob.pipelines.utils import isinstance_nested
from bob.bio.base.utils import get_resource_filename from bob.bio.base.utils import get_resource_filename
from .vanilla_biometrics import compute_scores, load_database_pipeline from .vanilla_biometrics import compute_scores, load_database_pipeline
from bob.pipelines import Sample, SampleSet
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -50,16 +52,15 @@ EPILOG = """\b ...@@ -50,16 +52,15 @@ EPILOG = """\b
""" """
@click.command( @click.command(epilog=EPILOG)
entry_point_group="bob.bio.pipeline", cls=ConfigCommand, epilog=EPILOG, @click.argument("samples", nargs=-1)
)
@click.option( @click.option(
"--pipeline", "--pipeline",
"-p", "-p",
required=True, required=True,
cls=ResourceOption, cls=ResourceOption,
entry_point_group="bob.bio.pipeline", entry_point_group="bob.bio.pipeline",
help="Vanilla biometrics pipeline composed of a scikit-learn Pipeline and a BioAlgorithm" help="Vanilla biometrics pipeline composed of a scikit-learn Pipeline and a BioAlgorithm",
) )
@click.option( @click.option(
"--dask-client", "--dask-client",
...@@ -68,32 +69,32 @@ EPILOG = """\b ...@@ -68,32 +69,32 @@ EPILOG = """\b
cls=ResourceOption, cls=ResourceOption,
help="Dask client for the execution of the pipeline.", help="Dask client for the execution of the pipeline.",
) )
@click.option(
"--samples", "-s", required=True, help="Vanilla biometrics pipeline", multiple=True,
)
@verbosity_option(cls=ResourceOption) @verbosity_option(cls=ResourceOption)
def compare_samples( def compare_samples(
pipeline, samples, pipeline, dask_client, **kwargs,
dask_client,
samples,
**kwargs,
): ):
"""Compare several samples all vs all using one vanilla biometrics pipeline """Compare several samples all vs all using one vanilla biometrics pipeline
""" """
if len(samples) == 1: if len(samples) == 1:
raise ValueError("It's necessary to have at least two samples for the comparison") raise ValueError(
"It's necessary to have at least two samples for the comparison"
#for e in samples: )
#A = bob.io.base.load(e)
#for p in samples: sample_sets = [
#B = bob.io.base.load(p) SampleSet([Sample(bob.io.base.load(s), key=str(i))], key=str(i))
for i, s in enumerate(samples)
]
import ipdb; ipdb.set_trace()
for e in sample_sets:
biometric_references = pipeline.create_biometric_reference([e])
scores = pipeline.compute_scores(biometric_references, sample_sets)
pass
# B = bob.io.base.load(p)
# pipeline.biometric_algorithm
if dask_client is not None: if dask_client is not None:
dask_client.shutdown() dask_client.shutdown()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment