From 90dc9a6bbb058a5fce20a97b8a0694d449732040 Mon Sep 17 00:00:00 2001
From: Tiago Freitas Pereira <tiagofrepereira@gmail.com>
Date: Tue, 17 Mar 2020 17:50:50 +0100
Subject: [PATCH] Implemented checkpoint mechanism

---
 bob/bio/base/config/baselines/pca_atnt.py     |  6 +-
 .../vanilla_biometrics/biometric_algorithm.py | 60 ++++++++++++++-----
 bob/bio/base/script/vanilla_biometrics.py     |  3 +-
 3 files changed, 50 insertions(+), 19 deletions(-)

diff --git a/bob/bio/base/config/baselines/pca_atnt.py b/bob/bio/base/config/baselines/pca_atnt.py
index 1aad9f24..981e7fe6 100644
--- a/bob/bio/base/config/baselines/pca_atnt.py
+++ b/bob/bio/base/config/baselines/pca_atnt.py
@@ -25,6 +25,6 @@ extractor = Pipeline(steps=[('0',CheckpointSampleLinearize(features_dir="./examp
 #extractor = dask_it(extractor)
 
 from bob.bio.base.pipelines.vanilla_biometrics.biometric_algorithm import Distance, BiometricAlgorithmCheckpointMixin
-class CheckpointDistance(BiometricAlgorithmCheckpointMixin, Distance):  pass
-algorithm = CheckpointDistance(features_dir="./example/models")
-#algorithm = Distance()
+#class CheckpointDistance(BiometricAlgorithmCheckpointMixin, Distance):  pass
+#algorithm = CheckpointDistance(features_dir="./example/")
+algorithm = Distance()
diff --git a/bob/bio/base/pipelines/vanilla_biometrics/biometric_algorithm.py b/bob/bio/base/pipelines/vanilla_biometrics/biometric_algorithm.py
index 53316312..35239aa3 100644
--- a/bob/bio/base/pipelines/vanilla_biometrics/biometric_algorithm.py
+++ b/bob/bio/base/pipelines/vanilla_biometrics/biometric_algorithm.py
@@ -66,7 +66,7 @@ class BiometricAlgorithm(object):
         return Sample(self.enroll(data), parent=sampleset)
 
 
-    def enroll(self, data,  **kwargs):
+    def enroll(self, data, extractor=None, **kwargs):
         """
         It handles the creation of ONE biometric reference for the vanilla ppipeline
 
@@ -135,6 +135,7 @@ class BiometricAlgorithm(object):
         for subprobe_id, (s, parent) in enumerate(zip(data, sampleset.samples)):
             # Creating one sample per comparison
             subprobe_scores = []
+
             for ref in [r for r in biometric_references if r.key in sampleset.references]:
                 subprobe_scores.append(
                     Sample(self.score(ref.data, s, extractor), parent=ref)
@@ -190,40 +191,57 @@ class BiometricAlgorithmCheckpointMixin(CheckpointMixin):
 
     """
 
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.biometric_reference_dir = os.path.join(self.features_dir, "biometric_references")
+        self.score_dir = os.path.join(self.features_dir, "scores")
+
+
+    def save(self, sample, path):
+        return bob.io.base.save(sample.data, path, create_directories=True)
+
 
     def _enroll_sample_set(self, sampleset):
         """
         Enroll a sample set with checkpointing
         """
 
-
-        path = self.make_path(sampleset)
+        # Amending `models` directory
+        path = os.path.join(self.biometric_reference_dir, str(sampleset.key) + self.extension)
         if path is None or not os.path.isfile(path):
 
             # Enrolling the sample
             enrolled_sample = super()._enroll_sample_set(sampleset)
 
             # saving the new sample
-            self.save(enrolled_sample)
+            self.save(enrolled_sample, path)
 
             # Dealaying it.
-            # This seems inefficient, but it's crucial with large datasets
-            enrolled_sample = DelayedSample(functools.partial(bob.io.base.load, path), enrolled_sample)
+            # This seems inefficient, but it's crucial for large datasets
+            delayed_enrolled_sample = DelayedSample(functools.partial(bob.io.base.load, path), enrolled_sample)
 
         else:
             # If sample already there, just load
-            enrolled_sample = self.load(path)
+            delayed_enrolled_sample = self.load(path)
+            delayed_enrolled_sample.key = sampleset.key
 
-        return enrolled_sample
 
+        return delayed_enrolled_sample
 
-    #def _score_sample_set(self, sampleset, biometric_references, extractor):
-    #   """Given a sampleset for probing, compute the scores and retures a sample set with the scores
-    #    """
 
-    #    scored_sample = 
+    def _score_sample_set(self, sampleset, biometric_references, extractor):
+        """Given a sampleset for probing, compute the scores and return a sample set with the scores
+        """
+        # Computing score
+        scored_sample_set = super()._score_sample_set(sampleset, biometric_references, extractor)
 
-    #    return subprobe
+        # Checkpointing score
+        path = os.path.join(self.score_dir, str(sampleset.key) + ".txt")
+        bob.io.base.create_directories_safe(os.path.dirname(path))
+
+        delayed_scored_sample = save_scores_four_columns(path, scored_sample_set)
+        scored_sample_set.samples = [delayed_scored_sample]
+        return scored_sample_set
 
 
 import scipy.spatial.distance
@@ -259,7 +277,7 @@ class Distance(BiometricAlgorithm):
         return numpy.mean(enroll_features, axis=0)
 
 
-    def score(self, model, probe,  **kwargs):
+    def score(self, model, probe, extractor=None, **kwargs):
         """score(model, probe) -> float
 
         Computes the distance of the model to the probe using the distance function specified in the constructor.
@@ -283,3 +301,17 @@ class Distance(BiometricAlgorithm):
         probe = probe.flatten()
         # return the negative distance (as a similarity measure)
         return self.factor * self.distance_function(model, probe)
+
+
+def save_scores_four_columns(path, probe):
+    """
+    Write scores in the four-column format
+    """
+    
+    with open(path, "w") as f:
+        for biometric_reference in probe.samples:
+            line = "{0} {1} {2} {3}\n".format(biometric_reference.key, probe.key, probe.path, biometric_reference.data)
+            f.write(line)
+
+    return DelayedSample(functools.partial(open, path))
+
diff --git a/bob/bio/base/script/vanilla_biometrics.py b/bob/bio/base/script/vanilla_biometrics.py
index d0be19a6..bbf193c1 100644
--- a/bob/bio/base/script/vanilla_biometrics.py
+++ b/bob/bio/base/script/vanilla_biometrics.py
@@ -182,10 +182,9 @@ def vanilla_biometrics(
             )
         
             if dask_client is not None:
-                result = result.compute(scheduler=dask_client)
+                #result = result.compute(scheduler=dask_client)
                 result = result.compute(scheduler="single-threaded")
 
-            #import ipdb; ipdb.set_trace()
             for probe in result:
                 for sample in probe.samples:
                     
-- 
GitLab