From 2625d541679f4bcce246f04d9e63392037ea92ca Mon Sep 17 00:00:00 2001 From: Tiago Freitas Pereira <tiagofrepereira@gmail.com> Date: Fri, 20 Mar 2020 12:22:51 +0100 Subject: [PATCH] Finished legacy Mixins --- .../base/config/baselines/lda_atnt_legacy.py | 3 +- bob/bio/base/mixins/legacy.py | 93 ++++++++++--------- 2 files changed, 52 insertions(+), 44 deletions(-) diff --git a/bob/bio/base/config/baselines/lda_atnt_legacy.py b/bob/bio/base/config/baselines/lda_atnt_legacy.py index 76e859cd..a60c2cd6 100644 --- a/bob/bio/base/config/baselines/lda_atnt_legacy.py +++ b/bob/bio/base/config/baselines/lda_atnt_legacy.py @@ -67,7 +67,8 @@ extractor = Pipeline( ), ] ) -# extractor = dask_it(extractor) + +extractor = dask_it(extractor) from bob.bio.base.pipelines.vanilla_biometrics.biometric_algorithm import ( Distance, diff --git a/bob/bio/base/mixins/legacy.py b/bob/bio/base/mixins/legacy.py index 6f34efc6..63e03b90 100644 --- a/bob/bio/base/mixins/legacy.py +++ b/bob/bio/base/mixins/legacy.py @@ -11,11 +11,15 @@ from bob.pipelines.mixins import CheckpointMixin, SampleMixin from sklearn.base import TransformerMixin, BaseEstimator from sklearn.utils.validation import check_array from bob.pipelines.sample import Sample, DelayedSample, SampleSet +from bob.pipelines.utils import is_picklable import numpy import logging import os +import bob.io.base +import functools logger = logging.getLogger(__name__) + def scikit_to_bob_supervised(X, Y): """ Given an input data ready for :py:method:`scikit.estimator.BaseEstimator.fit`, @@ -85,30 +89,24 @@ class LegacyProcessorMixin(TransformerMixin): from bob.pipelines.mixins import CheckpointMixin, SampleMixin class LegacyAlgorithmMixin(CheckpointMixin,SampleMixin,BaseEstimator): - """Class that wraps :py:class:`bob.bio.base.algorithm.Algoritm` and + """Class that wraps :py:class:`bob.bio.base.algorithm.Algoritm` - LegacyAlgorithmrMixin.fit maps :py:method:`bob.bio.base.algorithm.Algoritm.train_projector` + :py:method:`LegacyAlgorithmrMixin.fit` maps to :py:method:`bob.bio.base.algorithm.Algoritm.train_projector` - LegacyAlgorithmrMixin.transform maps :py:method:`bob.bio.base.algorithm.Algoritm.project` + :py:method:`LegacyAlgorithmrMixin.transform` maps :py:method:`bob.bio.base.algorithm.Algoritm.project` - THIS HAS TO BE SAMPABLE AND CHECKPOINTABLE + .. warning THIS HAS TO BE SAMPABLE AND CHECKPOINTABLE Example ------- - Wrapping preprocessor with functtools - >>> from bob.bio.base.mixins.legacy import LegacyProcessorMixin - >>> from bob.bio.face.preprocessor import FaceCrop + Wrapping LDA algorithm with functtools + >>> from bob.bio.base.mixins.legacy import LegacyAlgorithmMixin + >>> from bob.bio.base.algorithm import LDA >>> import functools - >>> transformer = LegacyProcessorMixin(functools.partial(FaceCrop, cropped_image_size=(10,10))) + >>> transformer = LegacyAlgorithmMixin(functools.partial(LDA, use_pinv=True, pca_subspace_dimension=0.90)) - Example - ------- - Wrapping extractor - >>> from bob.bio.base.mixins.legacy import LegacyProcessorMixin - >>> from bob.bio.face.extractor import Linearize - >>> transformer = LegacyProcessorMixin(Linearize) Parameters @@ -121,11 +119,13 @@ class LegacyAlgorithmMixin(CheckpointMixin,SampleMixin,BaseEstimator): def __init__(self, callable=None, **kwargs): super().__init__(**kwargs) self.callable = callable - self.instance = None - self.projector_file = os.path.join(self.model_path, "Projector.hdf5") + self.instance = None + self.projector_file = None + def fit(self, X, y=None, **fit_params): + self.projector_file = os.path.join(self.model_path, "Projector.hdf5") if os.path.exists(self.projector_file): return self @@ -147,6 +147,21 @@ class LegacyAlgorithmMixin(CheckpointMixin,SampleMixin,BaseEstimator): def transform(self, X): + def _project_save_sample(sample): + # Project + projected_data = self.instance.project(sample.data) + + #Checkpointing + path = self.make_path(sample) + bob.io.base.create_directories_safe(os.path.dirname(path)) + f = bob.io.base.HDF5File(path, "w") + + self.instance.write_feature(projected_data, f) + reader = self._get_reader(self.instance.read_feature, path) + + return DelayedSample(reader, parent=sample) + + self.projector_file = os.path.join(self.model_path, "Projector.hdf5") if not isinstance(X, list): raise ValueError("It's expected a list, not %s" % type(X)) @@ -155,41 +170,33 @@ class LegacyAlgorithmMixin(CheckpointMixin,SampleMixin,BaseEstimator): self.instance = self.callable() self.instance.load_projector(self.projector_file) - import ipdb; ipdb.set_trace() - if isinstance(X[0], Sample) or isinstance(X[0], DelayedSample): - #samples = [] - for s in X: - projected_data = self.instance.project(s.data) - - #raw_X = [s.data for s in X] - elif isinstance(X[0], SampleSet): + samples = [] + for sample in X: + samples.append(_project_save_sample(sample)) + return samples + elif isinstance(X[0], SampleSet): + # Projecting and checkpointing sampleset sample_sets = [] for sset in X: - samples = [] for sample in sset.samples: + samples.append(_project_save_sample(sample)) + sample_sets.append(SampleSet(samples=samples, parent=sset)) + return sample_sets - # Project - projected_data = self.instance.project(sample.data) - - #Checkpointing - path = self.make_path(sample) - self.instance.write_feature(path) - - samples.append(DelayedSample()) - - - pass - #bob.io.base.save(projected_data) - - - - - #raw_X = [x.data for s in X for x in s.samples] else: raise ValueError("Type not allowed %s" % type(X[0])) - return self.instance.project(raw_X) + def _get_reader(self, reader, path): + if(is_picklable(self.instance.read_feature)): + return functools.partial(reader, path) + else: + logger.warning( + f"The method {reader} is not picklable. Shiping its unbounded method to `DelayedSample`." + ) + reader = reader.__func__ # The reader object might not be picklable + return functools.partial(reader, None, path) + -- GitLab