From 0414d3be2b0656a69a513501f158a65699e13c72 Mon Sep 17 00:00:00 2001
From: Tiago Freitas Pereira <tiagofrepereira@gmail.com>
Date: Fri, 22 May 2020 22:02:46 +0200
Subject: [PATCH] Polishing score serialization. It's impossible to do it in
 HDF5

---
 .../pipelines/vanilla_biometrics/wrappers.py  | 39 +++++++++++++------
 1 file changed, 28 insertions(+), 11 deletions(-)

diff --git a/bob/bio/base/pipelines/vanilla_biometrics/wrappers.py b/bob/bio/base/pipelines/vanilla_biometrics/wrappers.py
index 115c5ac9..157c0def 100644
--- a/bob/bio/base/pipelines/vanilla_biometrics/wrappers.py
+++ b/bob/bio/base/pipelines/vanilla_biometrics/wrappers.py
@@ -5,9 +5,10 @@ import dask
 import functools
 from .score_writers import FourColumnsScoreWriter
 from .abstract_classes import BioAlgorithm
-import pickle
 import bob.pipelines as mario
 import numpy as np
+import h5py
+import cloudpickle
 from .zt_norm import ZTNormPipeline, ZTNormDaskWrapper
 
 
@@ -45,6 +46,7 @@ class BioAlgorithmCheckpointWrapper(BioAlgorithm):
         self.biometric_algorithm = biometric_algorithm
         self.force = force
         self._biometric_reference_extension = ".hdf5"
+        self._score_extension = ".pkl"
         self.base_dir = base_dir
 
     def enroll(self, enroll_features):
@@ -63,7 +65,9 @@ class BioAlgorithmCheckpointWrapper(BioAlgorithm):
 
     def write_scores(self, samples, path):
         os.makedirs(os.path.dirname(path), exist_ok=True)
-        open(path, "wb").write(pickle.dumps(samples))
+        # cloudpickle the samples; `with` closes the file once written
+        with open(path, "wb") as f:
+            f.write(cloudpickle.dumps(samples))
 
     def _enroll_sample_set(self, sampleset):
         """
@@ -99,17 +103,21 @@ class BioAlgorithmCheckpointWrapper(BioAlgorithm):
         """
 
         def _load(path):
-            return pickle.loads(open(path, "rb").read())
+            # Scores are cloudpickled, not HDF5; unpickling runs arbitrary
+            # code, so only load checkpoints from a trusted score_dir.
+            with open(path, "rb") as f:
+                return cloudpickle.loads(f.read())
 
         def _make_name(sampleset, biometric_references):
             # The score file name is composed by sampleset key and the
             # first 3 biometric_references
+            subject = str(sampleset.subject)
             name = str(sampleset.key)
             suffix = "_".join([str(s.key) for s in biometric_references[0:3]])
-            return name + suffix
+            return os.path.join(subject, name + suffix)
 
         path = os.path.join(
-            self.score_dir, _make_name(sampleset, biometric_references) + ".pkl"
+            self.score_dir, _make_name(sampleset, biometric_references) + self._score_extension
         )
 
         if self.force or not os.path.exists(path):
@@ -123,11 +131,12 @@ class BioAlgorithmCheckpointWrapper(BioAlgorithm):
             self.write_scores(scored_sample_set.samples, path)
 
             scored_sample_set = SampleSet(
-                [DelayedSample(functools.partial(_load, path), parent=sampleset)],
+                DelayedSample(functools.partial(_load, path), parent=sampleset),
                 parent=sampleset,
             )
         else:
-            scored_sample_set = SampleSet(_load(path), parent=sampleset)
+            samples = _load(path)
+            scored_sample_set = SampleSet(samples, parent=sampleset)
 
         return scored_sample_set
 
@@ -182,7 +191,7 @@ class BioAlgorithmDaskWrapper(BioAlgorithm):
         )
 
 
-def dask_vanilla_biometrics(pipeline, npartitions=None):
+def dask_vanilla_biometrics(pipeline, npartitions=None, partition_size=None):
     """
     Given a :any:`VanillaBiometrics`, wraps :any:`VanillaBiometrics.transformer` and
     :any:`VanillaBiometrics.biometric_algorithm` to be executed with dask
@@ -195,6 +204,9 @@ def dask_vanilla_biometrics(pipeline, npartitions=None):
 
     npartitions: int
        Number of partitions for the initial :any:`dask.bag`
+
+    partition_size: int
+       Partition size for the initial :any:`dask.bag`; for non-ZT-norm pipelines it replaces `npartitions`
     """
 
     if isinstance(pipeline, ZTNormPipeline):
@@ -207,9 +219,14 @@ def dask_vanilla_biometrics(pipeline, npartitions=None):
 
     else:
 
-        pipeline.transformer = mario.wrap(
-            ["dask"], pipeline.transformer, npartitions=npartitions
-        )
+        if partition_size is None:
+            pipeline.transformer = mario.wrap(
+                ["dask"], pipeline.transformer, npartitions=npartitions
+            )
+        else:
+            pipeline.transformer = mario.wrap(
+                ["dask"], pipeline.transformer, partition_size=partition_size
+            )
         pipeline.biometric_algorithm = BioAlgorithmDaskWrapper(
             pipeline.biometric_algorithm
         )
-- 
GitLab