Skip to content
Snippets Groups Projects
Commit 2625d541 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira Committed by Amir MOHAMMADI
Browse files

Finished legacy Mixins

parent de935b20
No related branches found
No related tags found
1 merge request!180[dask] Preparing bob.bio.base for dask pipelines
......@@ -67,7 +67,8 @@ extractor = Pipeline(
),
]
)
# extractor = dask_it(extractor)
extractor = dask_it(extractor)
from bob.bio.base.pipelines.vanilla_biometrics.biometric_algorithm import (
Distance,
......
......@@ -11,11 +11,15 @@ from bob.pipelines.mixins import CheckpointMixin, SampleMixin
from sklearn.base import TransformerMixin, BaseEstimator
from sklearn.utils.validation import check_array
from bob.pipelines.sample import Sample, DelayedSample, SampleSet
from bob.pipelines.utils import is_picklable
import numpy
import logging
import os
import bob.io.base
import functools
logger = logging.getLogger(__name__)
def scikit_to_bob_supervised(X, Y):
"""
Given an input data ready for :py:method:`scikit.estimator.BaseEstimator.fit`,
......@@ -85,30 +89,24 @@ class LegacyProcessorMixin(TransformerMixin):
from bob.pipelines.mixins import CheckpointMixin, SampleMixin
class LegacyAlgorithmMixin(CheckpointMixin,SampleMixin,BaseEstimator):
"""Class that wraps :py:class:`bob.bio.base.algorithm.Algoritm` and
"""Class that wraps :py:class:`bob.bio.base.algorithm.Algoritm`
LegacyAlgorithmrMixin.fit maps :py:method:`bob.bio.base.algorithm.Algoritm.train_projector`
:py:method:`LegacyAlgorithmrMixin.fit` maps to :py:method:`bob.bio.base.algorithm.Algoritm.train_projector`
LegacyAlgorithmrMixin.transform maps :py:method:`bob.bio.base.algorithm.Algoritm.project`
:py:method:`LegacyAlgorithmrMixin.transform` maps :py:method:`bob.bio.base.algorithm.Algoritm.project`
THIS HAS TO BE SAMPABLE AND CHECKPOINTABLE
.. warning THIS HAS TO BE SAMPABLE AND CHECKPOINTABLE
Example
-------
Wrapping preprocessor with functtools
>>> from bob.bio.base.mixins.legacy import LegacyProcessorMixin
>>> from bob.bio.face.preprocessor import FaceCrop
Wrapping LDA algorithm with functtools
>>> from bob.bio.base.mixins.legacy import LegacyAlgorithmMixin
>>> from bob.bio.base.algorithm import LDA
>>> import functools
>>> transformer = LegacyProcessorMixin(functools.partial(FaceCrop, cropped_image_size=(10,10)))
>>> transformer = LegacyAlgorithmMixin(functools.partial(LDA, use_pinv=True, pca_subspace_dimension=0.90))
Example
-------
Wrapping extractor
>>> from bob.bio.base.mixins.legacy import LegacyProcessorMixin
>>> from bob.bio.face.extractor import Linearize
>>> transformer = LegacyProcessorMixin(Linearize)
Parameters
......@@ -121,11 +119,13 @@ class LegacyAlgorithmMixin(CheckpointMixin,SampleMixin,BaseEstimator):
def __init__(self, callable=None, **kwargs):
super().__init__(**kwargs)
self.callable = callable
self.instance = None
self.projector_file = os.path.join(self.model_path, "Projector.hdf5")
self.instance = None
self.projector_file = None
def fit(self, X, y=None, **fit_params):
self.projector_file = os.path.join(self.model_path, "Projector.hdf5")
if os.path.exists(self.projector_file):
return self
......@@ -147,6 +147,21 @@ class LegacyAlgorithmMixin(CheckpointMixin,SampleMixin,BaseEstimator):
def transform(self, X):
def _project_save_sample(sample):
# Project
projected_data = self.instance.project(sample.data)
#Checkpointing
path = self.make_path(sample)
bob.io.base.create_directories_safe(os.path.dirname(path))
f = bob.io.base.HDF5File(path, "w")
self.instance.write_feature(projected_data, f)
reader = self._get_reader(self.instance.read_feature, path)
return DelayedSample(reader, parent=sample)
self.projector_file = os.path.join(self.model_path, "Projector.hdf5")
if not isinstance(X, list):
raise ValueError("It's expected a list, not %s" % type(X))
......@@ -155,41 +170,33 @@ class LegacyAlgorithmMixin(CheckpointMixin,SampleMixin,BaseEstimator):
self.instance = self.callable()
self.instance.load_projector(self.projector_file)
import ipdb; ipdb.set_trace()
if isinstance(X[0], Sample) or isinstance(X[0], DelayedSample):
#samples = []
for s in X:
projected_data = self.instance.project(s.data)
#raw_X = [s.data for s in X]
elif isinstance(X[0], SampleSet):
samples = []
for sample in X:
samples.append(_project_save_sample(sample))
return samples
elif isinstance(X[0], SampleSet):
# Projecting and checkpointing sampleset
sample_sets = []
for sset in X:
samples = []
for sample in sset.samples:
samples.append(_project_save_sample(sample))
sample_sets.append(SampleSet(samples=samples, parent=sset))
return sample_sets
# Project
projected_data = self.instance.project(sample.data)
#Checkpointing
path = self.make_path(sample)
self.instance.write_feature(path)
samples.append(DelayedSample())
pass
#bob.io.base.save(projected_data)
#raw_X = [x.data for s in X for x in s.samples]
else:
raise ValueError("Type not allowed %s" % type(X[0]))
return self.instance.project(raw_X)
def _get_reader(self, reader, path):
if(is_picklable(self.instance.read_feature)):
return functools.partial(reader, path)
else:
logger.warning(
f"The method {reader} is not picklable. Shiping its unbounded method to `DelayedSample`."
)
reader = reader.__func__ # The reader object might not be picklable
return functools.partial(reader, None, path)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment