Skip to content
Snippets Groups Projects
Commit d594ae7f authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Make SampleMixin more generic and add some useful transformers

parent 3157dbea
Branches
Tags
1 merge request: !9 "SampleMixin now accepts extra arguments"
Pipeline #38347 failed
# see https://docs.python.org/3/library/pkgutil.html
from pkgutil import extend_path
__path__ = extend_path(__path__, __name__)
...@@ -12,7 +12,14 @@ from sklearn.pipeline import Pipeline ...@@ -12,7 +12,14 @@ from sklearn.pipeline import Pipeline
from dask import delayed from dask import delayed
import dask.bag import dask.bag
def dask_it(o, fit_tag=None, transform_tag=None, npartitions=None):
def estimator_dask_it(
o,
fit_tag=None,
transform_tag=None,
npartitions=None,
mix_for_each_step_in_pipelines=True,
):
""" """
Mix up any :py:class:`sklearn.pipeline.Pipeline` or :py:class:`sklearn.estimator.Base` with Mix up any :py:class:`sklearn.pipeline.Pipeline` or :py:class:`sklearn.estimator.Base` with
:py:class`DaskEstimatorMixin` :py:class`DaskEstimatorMixin`
...@@ -42,7 +49,7 @@ def dask_it(o, fit_tag=None, transform_tag=None, npartitions=None): ...@@ -42,7 +49,7 @@ def dask_it(o, fit_tag=None, transform_tag=None, npartitions=None):
Vanilla example Vanilla example
>>> pipeline = dask_it(pipeline) # Take some pipeline and make the methods `fit`and `transform` run over dask >>> pipeline = estimator_dask_it(pipeline) # Take some pipeline and make the methods `fit`and `transform` run over dask
>>> pipeline.fit(samples).compute() >>> pipeline.fit(samples).compute()
...@@ -50,12 +57,12 @@ def dask_it(o, fit_tag=None, transform_tag=None, npartitions=None): ...@@ -50,12 +57,12 @@ def dask_it(o, fit_tag=None, transform_tag=None, npartitions=None):
Hence, we can set the `dask.delayed.compute` method to place some Hence, we can set the `dask.delayed.compute` method to place some
delayeds to be executed in particular resources delayeds to be executed in particular resources
>>> pipeline = dask_it(pipeline, fit_tag=[(1, "GPU")]) # Take some pipeline and make the methods `fit`and `transform` run over dask >>> pipeline = estimator_dask_it(pipeline, fit_tag=[(1, "GPU")]) # Take some pipeline and make the methods `fit`and `transform` run over dask
>>> fit = pipeline.fit(samples) >>> fit = pipeline.fit(samples)
>>> fit.compute(resources=pipeline.dask_tags()) >>> fit.compute(resources=pipeline.dask_tags())
Taging estimator Taging estimator
>>> estimator = dask_it(estimator) >>> estimator = estimator_dask_it(estimator)
>>> transf = estimator.transform(samples) >>> transf = estimator.transform(samples)
>>> transf.compute(resources=estimator.dask_tags()) >>> transf.compute(resources=estimator.dask_tags())
...@@ -67,7 +74,7 @@ def dask_it(o, fit_tag=None, transform_tag=None, npartitions=None): ...@@ -67,7 +74,7 @@ def dask_it(o, fit_tag=None, transform_tag=None, npartitions=None):
""" """
resource_tags = dict() resource_tags = dict()
if isinstance(self, Pipeline): if isinstance(self, Pipeline):
for i in range(1,len(self.steps)): for i in range(1, len(self.steps)):
resource_tags.update(o[i].resource_tags) resource_tags.update(o[i].resource_tags)
else: else:
resource_tags.update(self.resource_tags) resource_tags.update(self.resource_tags)
...@@ -75,11 +82,15 @@ def dask_it(o, fit_tag=None, transform_tag=None, npartitions=None): ...@@ -75,11 +82,15 @@ def dask_it(o, fit_tag=None, transform_tag=None, npartitions=None):
return resource_tags return resource_tags
if isinstance(o, Pipeline): if isinstance(o, Pipeline):
#Adding a daskbag in the tail of the pipeline # Adding a daskbag in the tail of the pipeline
o.steps.insert(0, ('0', DaskBagMixin(npartitions=npartitions))) o.steps.insert(0, ("0", DaskBagMixin(npartitions=npartitions)))
# Patching dask_resources # Patching dask_resources
dasked = mix_me_up(DaskEstimatorMixin, o) dasked = mix_me_up(
DaskEstimatorMixin,
o,
mix_for_each_step_in_pipelines=mix_for_each_step_in_pipelines,
)
# Tagging each element in a pipeline # Tagging each element in a pipeline
if isinstance(o, Pipeline): if isinstance(o, Pipeline):
...@@ -97,12 +108,12 @@ def dask_it(o, fit_tag=None, transform_tag=None, npartitions=None): ...@@ -97,12 +108,12 @@ def dask_it(o, fit_tag=None, transform_tag=None, npartitions=None):
dasked.transform_tag = transform_tag dasked.transform_tag = transform_tag
# Bounding the method # Bounding the method
dasked.dask_tags = types.MethodType( _fetch_resource_tape, dasked ) dasked.dask_tags = types.MethodType(_fetch_resource_tape, dasked)
return dasked return dasked
def mix_me_up(bases, o): def mix_me_up(bases, o, mix_for_each_step_in_pipelines=True):
""" """
Dynamically creates a new class from :any:`object` or :any:`class`. Dynamically creates a new class from :any:`object` or :any:`class`.
For instance, mix_me_up((A,B), class_c) is equal to `class ABC(A,B,C) pass:` For instance, mix_me_up((A,B), class_c) is equal to `class ABC(A,B,C) pass:`
...@@ -156,7 +167,7 @@ def mix_me_up(bases, o): ...@@ -156,7 +167,7 @@ def mix_me_up(bases, o):
# If it is a scikit pipeline, mixIN everything inside of # If it is a scikit pipeline, mixIN everything inside of
# Pipeline.steps # Pipeline.steps
if isinstance(o, Pipeline): if isinstance(o, Pipeline) and mix_for_each_step_in_pipelines:
# mixing all pipelines # mixing all pipelines
for i in range(len(o.steps)): for i in range(len(o.steps)):
# checking if it's not the bag transformer # checking if it's not the bag transformer
...@@ -174,7 +185,6 @@ def _is_estimator_stateless(estimator): ...@@ -174,7 +185,6 @@ def _is_estimator_stateless(estimator):
return estimator._get_tags()["stateless"] return estimator._get_tags()["stateless"]
class SampleMixin: class SampleMixin:
"""Mixin class to make scikit-learn estimators work in :any:`Sample`-based """Mixin class to make scikit-learn estimators work in :any:`Sample`-based
pipelines. pipelines.
...@@ -185,37 +195,42 @@ class SampleMixin: ...@@ -185,37 +195,42 @@ class SampleMixin:
https://scikit-learn.org/stable/developers/develop.html#apis-of-scikit-learn-objects https://scikit-learn.org/stable/developers/develop.html#apis-of-scikit-learn-objects
""" """
def __init__(self, extra_arguments=None, supervised_fit=False, y_attribute_name=None, **kwargs):
def __init__(
self, transform_extra_arguments=None, fit_extra_arguments=None, **kwargs
):
super().__init__(**kwargs) super().__init__(**kwargs)
self.extra_arguments = extra_arguments or [] self.transform_extra_arguments = transform_extra_arguments or tuple()
self.supervised_fit = supervised_fit self.fit_extra_arguments = fit_extra_arguments or tuple()
self.y_attribute_name = y_attribute_name
def transform(self, samples): def transform(self, samples):
#if not isinstance(samples, list): # Transform either samples or samplesets
# samples = [samples]
# Transform eith samples or samplesets
if isinstance(samples[0], Sample) or isinstance(samples[0], DelayedSample): if isinstance(samples[0], Sample) or isinstance(samples[0], DelayedSample):
kwargs = {arg: [getattr(s, arg) for s in samples] for arg in self.extra_arguments} kwargs = {
arg: [getattr(s, attr) for s in samples]
for arg, attr in self.transform_extra_arguments
}
features = super().transform([s.data for s in samples], **kwargs) features = super().transform([s.data for s in samples], **kwargs)
new_samples = [Sample(data, parent=s) for data, s in zip(features, samples)] new_samples = [Sample(data, parent=s) for data, s in zip(features, samples)]
return new_samples return new_samples
elif isinstance(samples[0], SampleSet): elif isinstance(samples[0], SampleSet):
return [SampleSet(self.transform(sset.samples), parent=sset) return [
for sset in samples] SampleSet(self.transform(sset.samples), parent=sset) for sset in samples
]
else: else:
raise ValueError("Type for sample not supported %s" % type(samples)) raise ValueError("Type for sample not supported %s" % type(samples))
def fit(self, samples, y=None): def fit(self, samples, y=None):
# IF THE SUPER METHOD IS NOT FITTABLE, # if the super method is not fittable,
# THERE'S NO REASON TO STACK THOSE SAMPLES # there's no reason to stack those samples
if( hasattr(super(), "fit")): if hasattr(super(), "fit"):
if self.supervised_fit: kwargs = {
y = [getattr(s, self.y_attribute_name) for s in samples] arg: [getattr(s, attr) for s in samples]
return super().fit([s.data for s in samples], y=y) for arg, attr in self.fit_extra_arguments
}
return super().fit([s.data for s in samples], **kwargs)
return self return self
...@@ -224,15 +239,24 @@ class CheckpointMixin: ...@@ -224,15 +239,24 @@ class CheckpointMixin:
"""Mixin class that allows :any:`Sample`-based estimators save their results into """Mixin class that allows :any:`Sample`-based estimators save their results into
disk.""" disk."""
def __init__(self, model_path=None, features_dir=None, extension=".h5", **kwargs): def __init__(
self,
model_path=None,
features_dir=None,
extension=".h5",
save_func=None,
load_func=None,
**kwargs
):
super().__init__(**kwargs) super().__init__(**kwargs)
self.model_path = model_path self.model_path = model_path
self.features_dir = features_dir self.features_dir = features_dir
self.extension = extension self.extension = extension
self.save_func = save_func or bob.io.base.save
self.load_func = load_func or bob.io.base.load
def transform_one_sample(self, sample): def transform_one_sample(self, sample):
# Check if the sample is already processed. # Check if the sample is already processed.
path = self.make_path(sample) path = self.make_path(sample)
if path is None or not os.path.isfile(path): if path is None or not os.path.isfile(path):
...@@ -248,7 +272,6 @@ class CheckpointMixin: ...@@ -248,7 +272,6 @@ class CheckpointMixin:
samples = [self.transform_one_sample(s) for s in sample_set.samples] samples = [self.transform_one_sample(s) for s in sample_set.samples]
return SampleSet(samples, parent=sample_set) return SampleSet(samples, parent=sample_set)
def transform(self, samples): def transform(self, samples):
if not isinstance(samples, list): if not isinstance(samples, list):
raise ValueError("It's expected a list, not %s" % type(samples)) raise ValueError("It's expected a list, not %s" % type(samples))
...@@ -260,12 +283,11 @@ class CheckpointMixin: ...@@ -260,12 +283,11 @@ class CheckpointMixin:
else: else:
raise ValueError("Type not allowed %s" % type(samples[0])) raise ValueError("Type not allowed %s" % type(samples[0]))
def fit(self, samples, y=None): def fit(self, samples, y=None):
# IF THE SUPER METHOD IS NOT FITTABLE, # IF THE SUPER METHOD IS NOT FITTABLE,
# THERE'S NO REASON TO STACK THOSE SAMPLES # THERE'S NO REASON TO STACK THOSE SAMPLES
if( not hasattr(super(), "fit") ): if not hasattr(super(), "fit"):
return self return self
if self.model_path is not None and os.path.isfile(self.model_path): if self.model_path is not None and os.path.isfile(self.model_path):
...@@ -274,7 +296,6 @@ class CheckpointMixin: ...@@ -274,7 +296,6 @@ class CheckpointMixin:
super().fit(samples, y=y) super().fit(samples, y=y)
return self.save_model() return self.save_model()
def fit_transform(self, samples, y=None): def fit_transform(self, samples, y=None):
return self.fit(samples, y=y).transform(samples) return self.fit(samples, y=y).transform(samples)
...@@ -293,14 +314,13 @@ class CheckpointMixin: ...@@ -293,14 +314,13 @@ class CheckpointMixin:
def save(self, sample): def save(self, sample):
if isinstance(sample, Sample): if isinstance(sample, Sample):
path = self.make_path(sample) path = self.make_path(sample)
return bob.io.base.save(sample.data, path, create_directories=True) return self.save_func(sample.data, path, create_directories=True)
elif isinstance(sample, SampleSet): elif isinstance(sample, SampleSet):
for s in sample.samples: for s in sample.samples:
path = self.make_path(s) path = self.make_path(s)
return bob.io.base.save(s.data, path, create_directories=True) return self.save_func(s.data, path, create_directories=True)
else: else:
raise ValueError("Type for sample not supported %s" % type(sample) ) raise ValueError("Type for sample not supported %s" % type(sample))
def load(self, path): def load(self, path):
key = self.recover_key_from_path(path) key = self.recover_key_from_path(path)
...@@ -308,7 +328,7 @@ class CheckpointMixin: ...@@ -308,7 +328,7 @@ class CheckpointMixin:
# instead of a normal (preloaded) sample. This allows the next # instead of a normal (preloaded) sample. This allows the next
# phase to avoid loading it would it be unnecessary (e.g. next # phase to avoid loading it would it be unnecessary (e.g. next
# phase is already check-pointed) # phase is already check-pointed)
return DelayedSample(functools.partial(bob.io.base.load, path), key=key) return DelayedSample(functools.partial(self.load_func, path), key=key)
def load_model(self): def load_model(self):
if _is_estimator_stateless(self): if _is_estimator_stateless(self):
...@@ -343,6 +363,7 @@ class CheckpointSampleFunctionTransformer( ...@@ -343,6 +363,7 @@ class CheckpointSampleFunctionTransformer(
Furthermore, it makes it checkpointable Furthermore, it makes it checkpointable
""" """
pass pass
...@@ -386,7 +407,6 @@ class NonPicklableMixin: ...@@ -386,7 +407,6 @@ class NonPicklableMixin:
return self.instance.transform(X) return self.instance.transform(X)
class DaskEstimatorMixin: class DaskEstimatorMixin:
"""Wraps Scikit estimators into Daskable objects """Wraps Scikit estimators into Daskable objects
......
"""Base definition of sample""" """Base definition of sample"""
def samplesets_to_samples(samplesets):
    """
    Flatten a list of :py:class:`SampleSet` into a list of samples plus a list of keys.

    The result is meant to be used as the ``X`` and ``y`` inputs of
    :py:meth:`sklearn.estimator.BaseEstimator.fit`.
    Check here https://scikit-learn.org/stable/developers/develop.html for more info

    Parameters
    ----------

    samplesets: list
        List of :py:class:`SampleSet`

    Return
    ------
    X and y used in :py:meth:`sklearn.estimator.BaseEstimator.fit`

    ``X`` is the concatenation of the ``samples`` of every sample set, in
    order. ``y`` holds the ``key`` of every sample set, i.e. one entry per
    sample set (not one entry per sample).

    """
    # TODO: Is there a way to make this operation more efficient? numpy.arrays?
    all_samples = []
    set_keys = []

    for sample_set in samplesets:
        # The samples of all sets end up in one flat list ...
        all_samples.extend(sample_set.samples)
        # ... while only one key is recorded for each set.
        set_keys.append(sample_set.key)

    return all_samples, set_keys
def _copy_attributes(s, d): def _copy_attributes(s, d):
"""Copies attributes from a dictionary to self """Copies attributes from a dictionary to self
""" """
......
...@@ -14,7 +14,7 @@ from bob.pipelines.mixins import ( ...@@ -14,7 +14,7 @@ from bob.pipelines.mixins import (
DaskEstimatorMixin, DaskEstimatorMixin,
DaskBagMixin, DaskBagMixin,
mix_me_up, mix_me_up,
dask_it estimator_dask_it,
) )
from bob.pipelines.mixins import _is_estimator_stateless from bob.pipelines.mixins import _is_estimator_stateless
from sklearn.base import TransformerMixin, BaseEstimator from sklearn.base import TransformerMixin, BaseEstimator
...@@ -50,7 +50,7 @@ class DummyWithFit(TransformerMixin, BaseEstimator): ...@@ -50,7 +50,7 @@ class DummyWithFit(TransformerMixin, BaseEstimator):
if X.shape[1] != self.n_features_: if X.shape[1] != self.n_features_:
raise ValueError( raise ValueError(
"Shape of input is different from what was seen" "in `fit`" "Shape of input is different from what was seen" "in `fit`"
) )
return X @ self.model_ return X @ self.model_
...@@ -74,7 +74,7 @@ class DummyTransformer(TransformerMixin, BaseEstimator): ...@@ -74,7 +74,7 @@ class DummyTransformer(TransformerMixin, BaseEstimator):
# Input validation # Input validation
X = check_array(X) X = check_array(X)
# Check that the input is of the same shape as the one passed # Check that the input is of the same shape as the one passed
# during fit. # during fit.
return _offset_add_func(X) return _offset_add_func(X)
...@@ -218,7 +218,7 @@ def _build_transformer(path, i, picklable=True, dask_enabled=True): ...@@ -218,7 +218,7 @@ def _build_transformer(path, i, picklable=True, dask_enabled=True):
import functools import functools
if dask_enabled: if dask_enabled:
estimator_cls = dask_it(estimator_cls) estimator_cls = estimator_dask_it(estimator_cls)
return NonPicklableMixin( return NonPicklableMixin(
functools.partial( functools.partial(
...@@ -256,7 +256,7 @@ def test_checkpoint_transform_pipeline(): ...@@ -256,7 +256,7 @@ def test_checkpoint_transform_pipeline():
[(f"{i}", _build_transformer(d, i)) for i in range(offset)] [(f"{i}", _build_transformer(d, i)) for i in range(offset)]
) )
if dask_enabled: if dask_enabled:
pipeline = dask_it(pipeline) pipeline = estimator_dask_it(pipeline)
transformed_samples = pipeline.transform(samples_transform).compute( transformed_samples = pipeline.transform(samples_transform).compute(
scheduler="single-threaded" scheduler="single-threaded"
) )
...@@ -280,8 +280,8 @@ def test_checkpoint_fit_transform_pipeline(): ...@@ -280,8 +280,8 @@ def test_checkpoint_fit_transform_pipeline():
fitter = ("0", _build_estimator(d, 0)) fitter = ("0", _build_estimator(d, 0))
transformer = ("1", _build_transformer(d, 1)) transformer = ("1", _build_transformer(d, 1))
pipeline = Pipeline([fitter, transformer]) pipeline = Pipeline([fitter, transformer])
if dask_enabled: if dask_enabled:
pipeline = dask_it(pipeline, fit_tag=[(1, "GPU")]) pipeline = estimator_dask_it(pipeline, fit_tag=[(1, "GPU")])
pipeline = pipeline.fit(samples) pipeline = pipeline.fit(samples)
tags = pipeline.dask_tags() tags = pipeline.dask_tags()
...@@ -331,7 +331,7 @@ def test_checkpoint_fit_transform_pipeline_with_dask_non_pickle(): ...@@ -331,7 +331,7 @@ def test_checkpoint_fit_transform_pipeline_with_dask_non_pickle():
pipeline = Pipeline([fitter, transformer]) pipeline = Pipeline([fitter, transformer])
if dask_enabled: if dask_enabled:
dask_client = _get_local_client() dask_client = _get_local_client()
pipeline = dask_it(pipeline) pipeline = estimator_dask_it(pipeline)
pipeline = pipeline.fit(samples) pipeline = pipeline.fit(samples)
transformed_samples = pipeline.transform(samples_transform).compute( transformed_samples = pipeline.transform(samples_transform).compute(
scheduler=dask_client scheduler=dask_client
...@@ -348,10 +348,10 @@ def test_checkpoint_fit_transform_pipeline_with_dask_non_pickle(): ...@@ -348,10 +348,10 @@ def test_checkpoint_fit_transform_pipeline_with_dask_non_pickle():
def test_dask_checkpoint_transform_pipeline(): def test_dask_checkpoint_transform_pipeline():
X = np.ones(shape=(10, 2), dtype=int) X = np.ones(shape=(10, 2), dtype=int)
samples_transform = [Sample(data, key=str(i)) for i, data in enumerate(X)] samples_transform = [Sample(data, key=str(i)) for i, data in enumerate(X)]
with tempfile.TemporaryDirectory() as d: with tempfile.TemporaryDirectory() as d:
bag_transformer = DaskBagMixin() bag_transformer = DaskBagMixin()
estimator = dask_it(_build_transformer(d, 0), transform_tag="CPU") estimator = estimator_dask_it(_build_transformer(d, 0), transform_tag="CPU")
X_tr = estimator.transform(bag_transformer.transform(samples_transform)) X_tr = estimator.transform(bag_transformer.transform(samples_transform))
assert len(estimator.dask_tags()) == 1 assert len(estimator.dask_tags()) == 1
assert len(X_tr.compute(scheduler="single-threaded")) == 10 assert len(X_tr.compute(scheduler="single-threaded")) == 10
from .linearize import Linearize, SampleLinearize, CheckpointSampleLinearize
from .pca import CheckpointSamplePCA, SamplePCA
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
from bob.pipelines.mixins import CheckpointMixin, SampleMixin
from sklearn.preprocessing import FunctionTransformer
import numpy as np
def linearize(X):
    """Flatten every element of ``X`` into a single row.

    ``X`` is converted to a numpy array; its first axis is kept and all
    remaining axes are collapsed into one, so the result is 2D with
    ``X.shape[0]`` rows.
    """
    batch = np.asarray(X)
    n_rows = batch.shape[0]
    # -1 lets numpy infer the length of the flattened axis
    return batch.reshape(n_rows, -1)
class Linearize(FunctionTransformer):
    """Feature extractor that flattens each input element into one long vector.

    This is a :py:class:`sklearn.preprocessing.FunctionTransformer` whose
    ``func`` is fixed to :py:func:`linearize`.
    """

    def __init__(self, **kwargs):
        # ``func`` is always ``linearize``; every other keyword argument is
        # forwarded untouched to the next ``__init__`` in the MRO.
        super().__init__(func=linearize, **kwargs)
class SampleLinearize(SampleMixin, Linearize):
    """:py:class:`Linearize` combined with :py:class:`SampleMixin`.

    No behaviour is added here; everything comes from the two base classes.
    """
class CheckpointSampleLinearize(CheckpointMixin, SampleMixin, Linearize):
    """:py:class:`Linearize` combined with :py:class:`CheckpointMixin` and
    :py:class:`SampleMixin`.

    No behaviour is added here; everything comes from the base classes.
    """
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
from bob.pipelines.mixins import CheckpointMixin, SampleMixin
from sklearn.decomposition import PCA
class SamplePCA(SampleMixin, PCA):
    """
    scikit-learn PCA combined with :py:class:`SampleMixin` to enable sample handling.

    See https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.PCA.html
    """
class CheckpointSamplePCA(CheckpointMixin, SampleMixin, PCA):
    """
    scikit-learn PCA combined with :py:class:`CheckpointMixin` and
    :py:class:`SampleMixin` to enable sample handling and checkpointing.

    See https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.PCA.html
    """
#!/usr/bin/env python #!/usr/bin/env python
# vim: set fileencoding=utf-8 : # vim: set fileencoding=utf-8 :
def is_picklable(obj): def is_picklable(obj):
""" """
Test if an object is picklable or not Test if an object is picklable or not
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment