From ebc2cb4c0e70f579ef5f9c50419b7177c92197eb Mon Sep 17 00:00:00 2001
From: Tiago Freitas Pereira <tiagofrepereira@gmail.com>
Date: Fri, 1 May 2020 19:37:41 +0200
Subject: [PATCH] Patched with new VanillaBiometricsPipeline

---
 bob/bio/base/script/vanilla_biometrics.py | 77 ++++++++++++-----------
 1 file changed, 41 insertions(+), 36 deletions(-)

diff --git a/bob/bio/base/script/vanilla_biometrics.py b/bob/bio/base/script/vanilla_biometrics.py
index ddcd841d..b5d76595 100644
--- a/bob/bio/base/script/vanilla_biometrics.py
+++ b/bob/bio/base/script/vanilla_biometrics.py
@@ -14,6 +14,14 @@ from bob.extension.scripts.click_helper import (
 )
 
 import logging
+import os
+import itertools
+import dask.bag
+from bob.bio.base.pipelines.vanilla_biometrics import (
+    VanillaBiometricsPipeline,
+    BioAlgorithmCheckpointWrapper,
+)
+
 
 logger = logging.getLogger(__name__)
 
@@ -96,9 +104,7 @@ TODO: Work out this help
     help="Name of output directory",
 )
 @verbosity_option(cls=ResourceOption)
-def vanilla_biometrics(
-    pipeline, database, dask_client, groups, output, **kwargs
-):
+def vanilla_biometrics(pipeline, database, dask_client, groups, output, **kwargs):
     """Runs the simplest biometrics pipeline.
 
     Such pipeline consists into three sub-pipelines.
@@ -143,43 +149,42 @@ def vanilla_biometrics(
 
     """
 
-    from bob.bio.base.pipelines.vanilla_biometrics.pipeline import VanillaBiometrics
-    import dask.bag
-    import itertools
-    import os
-    from bob.pipelines.sample import Sample, DelayedSample
-
     if not os.path.exists(output):
-        os.makedirs(output, exist_ok=True)    
+        os.makedirs(output, exist_ok=True)
 
     for group in groups:
 
-        with open(os.path.join(output, f"scores-{group}"), "w") as f:
-            biometric_references = database.references(group=group)
-
-            logger.info(f"Running vanilla biometrics for group {group}")
-
-            allow_scoring_with_all_biometric_references = (
-                database.allow_scoring_with_all_biometric_references
-                if hasattr(database, "allow_scoring_with_all_biometric_references")
-                else False
-            )
-
-            result = pipeline(database.background_model_samples(),
-                              biometric_references,
-                              database.probes(group=group),
-                              allow_scoring_with_all_biometric_references=allow_scoring_with_all_biometric_references
-                              )
-
-            if isinstance(result, dask.bag.core.Bag):
-                if dask_client is not None:
-                    result = result.compute(scheduler=dask_client)
-                else:
-                    logger.warning(
-                        "`dask_client` not set. Your pipeline will run locally"
-                    )
-                    result = result.compute(scheduler="single-threaded")
-
+        score_file_name = os.path.join(output, f"scores-{group}.txt")
+        biometric_references = database.references(group=group)
+
+        logger.info(f"Running vanilla biometrics for group {group}")
+
+        allow_scoring_with_all_biometric_references = (
+            database.allow_scoring_with_all_biometric_references
+            if hasattr(database, "allow_scoring_with_all_biometric_references")
+            else False
+        )
+
+        result = pipeline(
+            database.background_model_samples(),
+            biometric_references,
+            database.probes(group=group),
+            allow_scoring_with_all_biometric_references=allow_scoring_with_all_biometric_references,
+        )
+
+        if isinstance(result, dask.bag.core.Bag):
+            if dask_client is not None:
+                result = result.compute(scheduler=dask_client)
+            else:
+                logger.warning(
+                    "`dask_client` not set. Your pipeline will run locally"
+                )
+                result = result.compute(scheduler="single-threaded")
+
+        # Check if there's a score writer hooked in
+        if isinstance(pipeline.biometric_algorithm, BioAlgorithmCheckpointWrapper):
+            pipeline.biometric_algorithm.score_writer.concatenate_write_scores(result, score_file_name)
+        else:
             # Flatting out the list
             result = itertools.chain(*result)
             for probe in result:
-- 
GitLab