From 2625d541679f4bcce246f04d9e63392037ea92ca Mon Sep 17 00:00:00 2001
From: Tiago Freitas Pereira <tiagofrepereira@gmail.com>
Date: Fri, 20 Mar 2020 12:22:51 +0100
Subject: [PATCH] Finished legacy Mixins

---
 .../base/config/baselines/lda_atnt_legacy.py  |  3 +-
 bob/bio/base/mixins/legacy.py                 | 93 ++++++++++---------
 2 files changed, 52 insertions(+), 44 deletions(-)

diff --git a/bob/bio/base/config/baselines/lda_atnt_legacy.py b/bob/bio/base/config/baselines/lda_atnt_legacy.py
index 76e859cd..a60c2cd6 100644
--- a/bob/bio/base/config/baselines/lda_atnt_legacy.py
+++ b/bob/bio/base/config/baselines/lda_atnt_legacy.py
@@ -67,7 +67,8 @@ extractor = Pipeline(
         ),
     ]
 )
-# extractor = dask_it(extractor)
+
+extractor = dask_it(extractor)
 
 from bob.bio.base.pipelines.vanilla_biometrics.biometric_algorithm import (
     Distance,
diff --git a/bob/bio/base/mixins/legacy.py b/bob/bio/base/mixins/legacy.py
index 6f34efc6..63e03b90 100644
--- a/bob/bio/base/mixins/legacy.py
+++ b/bob/bio/base/mixins/legacy.py
@@ -11,11 +11,15 @@ from bob.pipelines.mixins import CheckpointMixin, SampleMixin
 from sklearn.base import TransformerMixin, BaseEstimator
 from sklearn.utils.validation import check_array
 from bob.pipelines.sample import Sample, DelayedSample, SampleSet
+from bob.pipelines.utils import is_picklable
 import numpy
 import logging
 import os
+import bob.io.base
+import functools
 logger = logging.getLogger(__name__)
 
+
 def scikit_to_bob_supervised(X, Y):
     """
     Given an input data ready for :py:method:`scikit.estimator.BaseEstimator.fit`,
@@ -85,30 +89,24 @@ class LegacyProcessorMixin(TransformerMixin):
 
 from bob.pipelines.mixins import CheckpointMixin, SampleMixin
 class LegacyAlgorithmMixin(CheckpointMixin,SampleMixin,BaseEstimator):
-    """Class that wraps :py:class:`bob.bio.base.algorithm.Algoritm` and
+    """Class that wraps :py:class:`bob.bio.base.algorithm.Algorithm`
     
-    LegacyAlgorithmrMixin.fit maps :py:method:`bob.bio.base.algorithm.Algoritm.train_projector`
+    :py:meth:`LegacyAlgorithmMixin.fit` maps to :py:meth:`bob.bio.base.algorithm.Algorithm.train_projector`
 
-    LegacyAlgorithmrMixin.transform maps :py:method:`bob.bio.base.algorithm.Algoritm.project`
+    :py:meth:`LegacyAlgorithmMixin.transform` maps to :py:meth:`bob.bio.base.algorithm.Algorithm.project`
 
-    THIS HAS TO BE SAMPABLE AND CHECKPOINTABLE
+    .. warning:: THIS HAS TO BE SAMPABLE AND CHECKPOINTABLE
 
 
     Example
     -------
 
-        Wrapping preprocessor with functtools
-        >>> from bob.bio.base.mixins.legacy import LegacyProcessorMixin
-        >>> from bob.bio.face.preprocessor import FaceCrop
+        Wrapping LDA algorithm with functools
+        >>> from bob.bio.base.mixins.legacy import LegacyAlgorithmMixin
+        >>> from bob.bio.base.algorithm import LDA
         >>> import functools
-        >>> transformer = LegacyProcessorMixin(functools.partial(FaceCrop, cropped_image_size=(10,10)))
+        >>> transformer = LegacyAlgorithmMixin(functools.partial(LDA, use_pinv=True, pca_subspace_dimension=0.90))
 
-    Example
-    -------
-        Wrapping extractor 
-        >>> from bob.bio.base.mixins.legacy import LegacyProcessorMixin
-        >>> from bob.bio.face.extractor import Linearize
-        >>> transformer = LegacyProcessorMixin(Linearize)
 
 
     Parameters
@@ -121,11 +119,13 @@ class LegacyAlgorithmMixin(CheckpointMixin,SampleMixin,BaseEstimator):
     def __init__(self, callable=None, **kwargs):
         super().__init__(**kwargs)
         self.callable = callable
-        self.instance = None
-        self.projector_file = os.path.join(self.model_path, "Projector.hdf5")
+        self.instance = None        
+        self.projector_file = None
+
 
     def fit(self, X, y=None, **fit_params):
         
+        self.projector_file = os.path.join(self.model_path, "Projector.hdf5")
         if os.path.exists(self.projector_file):
             return self
 
@@ -147,6 +147,21 @@ class LegacyAlgorithmMixin(CheckpointMixin,SampleMixin,BaseEstimator):
 
     def transform(self, X):
 
+        def _project_save_sample(sample):
+            # Project
+            projected_data = self.instance.project(sample.data)
+
+            #Checkpointing
+            path = self.make_path(sample)
+            bob.io.base.create_directories_safe(os.path.dirname(path))
+            f = bob.io.base.HDF5File(path, "w")
+
+            self.instance.write_feature(projected_data, f)
+            reader = self._get_reader(self.instance.read_feature, path)
+
+            return DelayedSample(reader, parent=sample)
+
+        self.projector_file = os.path.join(self.model_path, "Projector.hdf5")
         if not isinstance(X, list):
             raise ValueError("It's expected a list, not %s" % type(X))
 
@@ -155,41 +170,33 @@ class LegacyAlgorithmMixin(CheckpointMixin,SampleMixin,BaseEstimator):
             self.instance = self.callable()
         self.instance.load_projector(self.projector_file)
 
-        import ipdb; ipdb.set_trace()
-
         if isinstance(X[0], Sample) or isinstance(X[0], DelayedSample):
-            #samples = []
-            for s in X:
-                projected_data = self.instance.project(s.data)
-        
-            #raw_X = [s.data for s in X]
-        elif isinstance(X[0], SampleSet):
+            samples = []
+            for sample in X:
+                samples.append(_project_save_sample(sample))
+            return samples
 
+        elif isinstance(X[0], SampleSet):
+            # Projecting and checkpointing sampleset
             sample_sets = []
             for sset in X:
-
                 samples = []
                 for sample in sset.samples:
+                    samples.append(_project_save_sample(sample))
+                sample_sets.append(SampleSet(samples=samples, parent=sset))
+            return sample_sets
 
-                    # Project
-                    projected_data = self.instance.project(sample.data)
-
-                    #Checkpointing
-                    path = self.make_path(sample)
-                    self.instance.write_feature(path)
-
-                    samples.append(DelayedSample())
-
-
-                    pass
-                    #bob.io.base.save(projected_data)
-
-
-
-
-            #raw_X = [x.data for s in X for x in s.samples]
         else:
             raise ValueError("Type not allowed %s" % type(X[0]))
 
 
-        return self.instance.project(raw_X)
+    def _get_reader(self, reader, path):
+        if(is_picklable(self.instance.read_feature)):
+            return functools.partial(reader, path)
+        else:
+            logger.warning(
+                        f"The method {reader} is not picklable. Shiping its unbounded method to `DelayedSample`."
+                    )
+            reader = reader.__func__  # The reader object might not be picklable
+            return functools.partial(reader, None, path)
+
-- 
GitLab