From 4b238a25903d8abf9fb0b94a42094c9bb700e648 Mon Sep 17 00:00:00 2001
From: Tiago Freitas Pereira <tiagofrepereira@gmail.com>
Date: Fri, 19 Jun 2020 09:08:41 +0200
Subject: [PATCH] Comparison app

---
 bob/bio/base/script/compare_samples.py        | 100 ++++++++++++++++++
 bob/bio/base/script/vanilla_biometrics.py     |  33 ++++--
 .../base/script/vanilla_biometrics_ztnorm.py  |  14 +--
 setup.py                                      |   1 +
 4 files changed, 129 insertions(+), 19 deletions(-)
 create mode 100644 bob/bio/base/script/compare_samples.py

diff --git a/bob/bio/base/script/compare_samples.py b/bob/bio/base/script/compare_samples.py
new file mode 100644
index 00000000..bab5de05
--- /dev/null
+++ b/bob/bio/base/script/compare_samples.py
@@ -0,0 +1,100 @@
+#!/usr/bin/env python
+# vim: set fileencoding=utf-8 :
+# Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
+
+
+"""Compares samples all vs all using a vanilla biometrics pipeline"""
+
+import click
+
+from bob.extension.scripts.click_helper import (
+    verbosity_option,
+    ResourceOption,
+    ConfigCommand,
+)
+import bob.io.base
+import bob.io.image
+
+import logging
+import os
+import itertools
+import dask.bag
+
+# from bob.bio.base.pipelines.vanilla_biometrics import (
+#    VanillaBiometricsPipeline,
+#    BioAlgorithmCheckpointWrapper,
+#    BioAlgorithmDaskWrapper,
+#    checkpoint_vanilla_biometrics,
+#    dask_vanilla_biometrics,
+#    dask_get_partition_size,
+#    FourColumnsScoreWriter,
+#    CSVScoreWriter,
+# )
+# from dask.delayed import Delayed
+# import pkg_resources
+from bob.extension.config import load as chain_load
+from bob.pipelines.utils import isinstance_nested
+from bob.bio.base.utils import get_resource_filename
+from .vanilla_biometrics import compute_scores, load_database_pipeline
+from bob.pipelines import Sample, SampleSet
+
+
+logger = logging.getLogger(__name__)
+
+
+EPILOG = """\b
+
+
+ Command line examples\n
+ -----------------------
+
+
+"""
+
+
+@click.command(epilog=EPILOG)
+@click.argument("samples", nargs=-1)
+@click.option(
+    "--pipeline",
+    "-p",
+    required=True,
+    cls=ResourceOption,
+    entry_point_group="bob.bio.pipeline",
+    help="Vanilla biometrics pipeline composed of a scikit-learn Pipeline and a BioAlgorithm",
+)
+@click.option(
+    "--dask-client",
+    "-l",
+    required=False,
+    cls=ResourceOption,
+    help="Dask client for the execution of the pipeline.",
+)
+@verbosity_option(cls=ResourceOption)
+def compare_samples(
+    samples, pipeline, dask_client, **kwargs,
+):
+    """Compare several samples all vs all using one vanilla biometrics pipeline
+
+    Each file in SAMPLES is loaded with `bob.io.base.load` and wrapped in a
+    single-element `SampleSet` keyed by its position on the command line.
+    """
+
+    # `nargs=-1` also accepts zero arguments, so reject anything below two
+    if len(samples) < 2:
+        raise ValueError(
+            "It's necessary to have at least two samples for the comparison"
+        )
+
+    try:
+        sample_sets = [
+            SampleSet([Sample(bob.io.base.load(s), key=str(i))], key=str(i))
+            for i, s in enumerate(samples)
+        ]
+        for e in sample_sets:
+            biometric_references = pipeline.create_biometric_reference([e])
+            # NOTE(review): scores are not reported yet; argument order unverified
+            scores = pipeline.compute_scores(biometric_references, sample_sets)
+    finally:
+        # Release the Dask client even when loading or scoring fails
+        if dask_client is not None:
+            dask_client.shutdown()
diff --git a/bob/bio/base/script/vanilla_biometrics.py b/bob/bio/base/script/vanilla_biometrics.py
index 74be6c84..88054034 100644
--- a/bob/bio/base/script/vanilla_biometrics.py
+++ b/bob/bio/base/script/vanilla_biometrics.py
@@ -74,16 +74,32 @@ def post_process_scores(pipeline, scores, path):
     return pipeline.post_process(writed_scores, path)
 
 
+def load_database_pipeline(database, pipeline):
+    """Chain-load the pipeline resource, preceded by the database resource when given.
+
+    Returns ``(database, pipeline)``; raises ``ValueError`` if neither input defines a database.
+    """
+    resources = [get_resource_filename(pipeline, "bob.bio.pipeline")]
+    if database is not None:
+        # The database configuration is loaded ahead of the pipeline one
+        resources.insert(0, get_resource_filename(database, "bob.bio.database"))
+
+    loaded = chain_load(resources)
+    if database is None and not hasattr(loaded, "database"):
+        raise ValueError("Database was not set. Please look in `bob bio pipelines vanilla-biometrics --help` for more information")
+    return loaded.database, loaded.pipeline
+
+
 @click.command(
     entry_point_group="bob.bio.pipeline.config", cls=ConfigCommand, epilog=EPILOG,
 )
 @click.option(
-    "--pipeline", "-p", required=True, help="Vanilla biometrics pipeline",
+    "--pipeline", "-p", required=True, help="Vanilla biometrics pipeline composed of a scikit-learn Pipeline and a BioAlgorithm",
 )
 @click.option(
     "--database",
     "-d",
-    required=True,
+    required=False,
     help="Biometric Database connector (class that implements the methods: `background_model_samples`, `references` and `probes`)",
 )
 @click.option(
@@ -185,15 +201,12 @@ def vanilla_biometrics(
     if not os.path.exists(output):
         os.makedirs(output, exist_ok=True)
 
-    # It's necessary to chain load 2 resources together
-    pipeline_config = get_resource_filename(pipeline, "bob.bio.pipeline")
-    database_config = get_resource_filename(database, "bob.bio.database")
-    vanilla_pipeline = chain_load([database_config, pipeline_config])
-    dask_client = chain_load([dask_client]).dask_client
-
     # Picking the resources
-    database = vanilla_pipeline.database
-    pipeline = vanilla_pipeline.pipeline
+    database, pipeline = load_database_pipeline(database, pipeline)
+
+    if dask_client is not None:
+        dask_client = chain_load([dask_client]).dask_client
+
     if write_metadata_scores:
         pipeline.score_writer = CSVScoreWriter(os.path.join(output, "./tmp"))
     else:
diff --git a/bob/bio/base/script/vanilla_biometrics_ztnorm.py b/bob/bio/base/script/vanilla_biometrics_ztnorm.py
index 57f5bbb9..ce00590a 100644
--- a/bob/bio/base/script/vanilla_biometrics_ztnorm.py
+++ b/bob/bio/base/script/vanilla_biometrics_ztnorm.py
@@ -34,7 +34,7 @@ from dask.delayed import Delayed
 from bob.bio.base.utils import get_resource_filename
 from bob.extension.config import load as chain_load
 from bob.pipelines.utils import isinstance_nested
-from .vanilla_biometrics import compute_scores, post_process_scores
+from .vanilla_biometrics import compute_scores, post_process_scores, load_database_pipeline
 import copy
 
 logger = logging.getLogger(__name__)
@@ -67,7 +67,7 @@ EPILOG = """\b
     entry_point_group="bob.pipelines.config", cls=ConfigCommand, epilog=EPILOG,
 )
 @click.option(
-    "--pipeline", "-p", required=True, help="An entry point or a configuration file containing a `VanillaBiometricsPipeline`.",
+    "--pipeline", "-p", required=True, help="Vanilla biometrics pipeline composed of a scikit-learn Pipeline and a BioAlgorithm",
 )
 @click.option(
     "--database",
@@ -199,16 +199,12 @@ def vanilla_biometrics_ztnorm(
         os.makedirs(output, exist_ok=True)
 
     # It's necessary to chain load 2 resources together
-    pipeline_config = get_resource_filename(pipeline, "bob.bio.pipeline")
-    database_config = get_resource_filename(database, "bob.bio.database")
-    vanilla_pipeline = chain_load([database_config, pipeline_config])
+    # Picking the resources
+    database, pipeline = load_database_pipeline(database, pipeline)
+
     if dask_client is not None:
         dask_client = chain_load([dask_client]).dask_client
 
-    # Picking the resources
-    database = vanilla_pipeline.database
-    pipeline = vanilla_pipeline.pipeline
-
     if write_metadata_scores:
         pipeline.score_writer = CSVScoreWriter(os.path.join(output, "./tmp"))
     else:
diff --git a/setup.py b/setup.py
index 57102217..70ff0ebe 100644
--- a/setup.py
+++ b/setup.py
@@ -133,6 +133,7 @@ setup(
         'baseline          = bob.bio.base.script.baseline:baseline',
         'sort              = bob.bio.base.script.sort:sort',
         'pipelines         = bob.bio.base.script.pipelines:pipelines',
+        'compare-samples   = bob.bio.base.script.compare_samples:compare_samples',
       ],
 
       # annotators
-- 
GitLab