mixins.py 16.8 KB
Newer Older
1
2
# vim: set fileencoding=utf-8 :

3
from .sample import Sample, DelayedSample, SampleSet
4
import os
5
import types
6
import cloudpickle
7
import functools
8
9
import bob.io.base
from sklearn.preprocessing import FunctionTransformer
10
from sklearn.base import TransformerMixin, BaseEstimator
11
from sklearn.pipeline import Pipeline
12
13
from dask import delayed
import dask.bag
14
import os
15

16
17
18
19
20
21
22
23

def estimator_dask_it(
    o,
    fit_tag=None,
    transform_tag=None,
    npartitions=None,
    mix_for_each_step_in_pipelines=True,
):
    """
    Mix up any :py:class:`sklearn.pipeline.Pipeline` or :py:class:`sklearn.estimator.Base` with
    :py:class:`DaskEstimatorMixin`, so `fit` and `transform` run lazily over dask.

    Parameters
    ----------

      o: :py:class:`sklearn.pipeline.Pipeline` or :py:class:`sklearn.estimator.Base`
        Any :py:class:`sklearn.pipeline.Pipeline` or :py:class:`sklearn.estimator.Base` to be dask mixed

      fit_tag: list(tuple()) or "str"
         Tag the `fit` method. This is useful to tag dask tasks to run in specific workers https://distributed.dask.org/en/latest/resources.html
         If `o` is :py:class:`sklearn.pipeline.Pipeline`, this parameter should contain a list of tuples
         containing the pipeline.step index and the `str` tag for `fit`.
         If `o` is :py:class:`sklearn.estimator.Base`, this parameter should contain just the tag for `fit`

      transform_tag: list(tuple()) or "str"
         Tag the `transform` method. This is useful to tag dask tasks to run in specific workers https://distributed.dask.org/en/latest/resources.html
         If `o` is :py:class:`sklearn.pipeline.Pipeline`, this parameter should contain a list of tuples
         containing the pipeline.step index and the `str` tag for `transform`.
         If `o` is :py:class:`sklearn.estimator.Base`, this parameter should contain just the tag for `transform`

      npartitions: int
         Number of partitions of the :py:class:`dask.bag` inserted at the head
         of a pipeline (ignored for plain estimators).

      mix_for_each_step_in_pipelines: bool
         If ``True`` and `o` is a pipeline, every estimator in ``o.steps`` is
         mixed individually; otherwise the pipeline object itself is mixed.

    Examples
    --------

      Vanilla example

      >>> pipeline = estimator_dask_it(pipeline) # Take some pipeline and make the methods `fit`and `transform` run over dask
      >>> pipeline.fit(samples).compute()


      In this example we will "mark" the fit method with a particular tag
      Hence, we can set the `dask.delayed.compute` method to place some
      delayeds to be executed in particular resources

      >>> pipeline = estimator_dask_it(pipeline, fit_tag=[(1, "GPU")]) # Take some pipeline and make the methods `fit`and `transform` run over dask
      >>> fit = pipeline.fit(samples)
      >>> fit.compute(resources=pipeline.dask_tags())

      Taging estimator
      >>> estimator = estimator_dask_it(estimator)
      >>> transf = estimator.transform(samples)
      >>> transf.compute(resources=estimator.dask_tags())

    """

    def _fetch_resource_tape(self):
        """
        Gather all resource tags registered by the dask-mixed estimators,
        keyed by delayed object, suitable for ``compute(resources=...)``.
        """
        resource_tags = dict()
        if isinstance(self, Pipeline):
            # Step 0 is the DaskBagMixin inserted below; it carries no tags,
            # so start at 1.  NOTE: the previous version indexed the
            # closed-over `o` here instead of `self` (the bound object).
            for i in range(1, len(self.steps)):
                resource_tags.update(self.steps[i][1].resource_tags)
        else:
            resource_tags.update(self.resource_tags)

        return resource_tags

    if isinstance(o, Pipeline):
        # Adding a dask bag at the HEAD of the pipeline, so input samples are
        # partitioned before reaching the actual estimators
        o.steps.insert(0, ("0", DaskBagMixin(npartitions=npartitions)))

    # Mixing `o` (or each of its steps) with DaskEstimatorMixin
    dasked = mix_me_up(
        DaskEstimatorMixin,
        o,
        mix_for_each_step_in_pipelines=mix_for_each_step_in_pipelines,
    )

    # Tagging each element in a pipeline
    if isinstance(o, Pipeline):

        # Tagging each element for fitting and transforming
        if fit_tag is not None:
            for t in fit_tag:
                o.steps[t[0]][1].fit_tag = t[1]

        if transform_tag is not None:
            for t in transform_tag:
                o.steps[t[0]][1].transform_tag = t[1]
    else:
        dasked.fit_tag = fit_tag
        dasked.transform_tag = transform_tag

    # Bounding the method
    dasked.dask_tags = types.MethodType(_fetch_resource_tape, dasked)

    return dasked
115
116


117
def mix_me_up(bases, o, mix_for_each_step_in_pipelines=True):
    """
    Dynamically creates a new class from :any:`object` or :any:`class`.
    For instance, mix_me_up((A,B), class_c) is equal to `class ABC(A,B,C) pass:`

    Example
    -------

       >>> my_mixed_class = mix_me_up([MixInA, MixInB], OriginalClass)
       >>> mixed_object = my_mixed_class(*args)

    It's also possible to mix up an instance:

    Example
    -------

       >>> instance = OriginalClass()
       >>> mixed_object = mix_me_up([MixInA, MixInB], instance)

    It's also possible to mix up a :py:class:`sklearn.pipeline.Pipeline`.
    In this case, every estimator inside of :py:meth:`sklearn.pipeline.Pipeline.steps`
    will be mixed up


    Parameters
    ----------
      bases: :any:`class` or :any:`tuple`
        Base classes to be mixed in

      o: :any:`class`, :any:`object` or :py:class:`sklearn.pipeline.Pipeline`
        Base element to be extended

    """

    def _mix(bases, o):
        # Normalize `bases` to a tuple so `type()` accepts it.
        bases = bases if isinstance(bases, tuple) else tuple([bases])
        class_name = "".join([c.__name__ for c in bases])
        # Bug fix: `types.ClassType` was removed in Python 3 (it only ever
        # described old-style classes); `isinstance(o, type)` is the correct
        # check for "is `o` a class" and previously this raised AttributeError.
        if isinstance(o, type):
            # If it's a class, just merge them
            class_name += o.__name__
            new_type = type(class_name, bases + tuple([o]), {})
        else:
            # If it's an object, creates a new class and copy the state of the current object
            class_name += o.__class__.__name__
            new_type = type(class_name, bases + tuple([o.__class__]), o.__dict__)()

            # new_type.__dict__ is made in the descending order of the classes
            # so the values of o.__dict__ are overwritten by the lower ones
            # here we are copying them back
            for k in o.__dict__:
                new_type.__dict__[k] = o.__dict__[k]
        return new_type

    # If it is a scikit pipeline, mixIN everything inside of
    # Pipeline.steps
    if isinstance(o, Pipeline) and mix_for_each_step_in_pipelines:

        # mixing all pipelines
        for i in range(len(o.steps)):
            # checking if it's not the bag transformer
            if isinstance(o.steps[i][1], DaskBagMixin):
                continue
            o.steps[i] = (str(i), _mix(bases, o.steps[i][1]))

        return o
    else:
        return _mix(bases, o)

182

183
184
185
186
187
188
def _is_estimator_stateless(estimator):
    if not hasattr(estimator, "_get_tags"):
        return False
    return estimator._get_tags()["stateless"]


189
190
191
192
193
def _make_kwargs_from_samples(samples, arg_attr_list):
    kwargs = {arg: [getattr(s, attr) for s in samples] for arg, attr in arg_attr_list}
    return kwargs


194
class SampleMixin(BaseEstimator):
    """Mixin class to make scikit-learn estimators work in :any:`Sample`-based
    pipelines.
    Do not use this class except for scikit-learn estimators.

    .. todo::

        Also implement ``predict``, ``predict_proba``, and ``score``. See:
        https://scikit-learn.org/stable/developers/develop.html#apis-of-scikit-learn-objects

    Attributes
    ----------
    fit_extra_arguments : [tuple], optional
        Use this option if you want to pass extra arguments to the fit method of the
        mixed instance. The format is a list of two value tuples. The first value in
        tuples is the name of the argument that fit accepts, like ``y``, and the second
        value is the name of the attribute that samples carry. For example, if you are
        passing samples to the fit method and want to pass ``subject`` attributes of
        samples as the ``y`` argument to the fit method, you can provide ``[("y",
        "subject")]`` as the value for this attribute.
    transform_extra_arguments : [tuple], optional
        Similar to ``fit_extra_arguments`` but for the transform method.
    """

    def __init__(
        self, transform_extra_arguments=None, fit_extra_arguments=None, **kwargs
    ):
        super().__init__(**kwargs)
        self.transform_extra_arguments = transform_extra_arguments or tuple()
        self.fit_extra_arguments = fit_extra_arguments or tuple()

    def transform(self, samples):
        """Transform a list of samples (or sample sets), wrapping each result
        in a new :any:`Sample` whose parent is the input sample."""
        first = samples[0]
        if isinstance(first, (Sample, DelayedSample)):
            kwargs = _make_kwargs_from_samples(samples, self.transform_extra_arguments)
            # Hoist the parent-class transform: zero-argument super() is not
            # available inside the comprehension's scope.
            inner_transform = super().transform
            features = [inner_transform(sample.data, **kwargs) for sample in samples]
            return [
                Sample(feature, parent=sample)
                for feature, sample in zip(features, samples)
            ]
        elif isinstance(first, SampleSet):
            # Recurse into each sample set, preserving the set structure.
            return [
                SampleSet(self.transform(sset.samples), parent=sset) for sset in samples
            ]
        else:
            raise ValueError("Type for sample not supported %s" % type(samples))

    def fit(self, samples, y=None):
        """Fit the wrapped estimator on the samples' raw ``data``, unless the
        estimator is stateless or declares it does not require fitting."""
        # See: https://scikit-learn.org/stable/developers/develop.html
        # if the estimator does not require fit or is stateless don't call fit
        tags = self._get_tags()
        if tags["stateless"] or not tags.get("requires_fit", True):
            return self

        # if the estimator needs to be fitted.
        kwargs = _make_kwargs_from_samples(samples, self.fit_extra_arguments)
        return super().fit([sample.data for sample in samples], **kwargs)
255
256


257
258
259
class CheckpointMixin:
    """Mixin class that allows :any:`Sample`-based estimators save their results into
    disk.

    Parameters
    ----------
    model_path : str, optional
        Path where the fitted model is check-pointed with :any:`cloudpickle`.
    features_dir : str, optional
        Directory where transformed samples are check-pointed, one file per
        sample ``key``.
    extension : str, optional
        File extension appended to each sample's ``key`` (default ``".h5"``).
    save_func : callable, optional
        ``save_func(data, path)`` used to write a sample's data
        (defaults to ``bob.io.base.save``).
    load_func : callable, optional
        ``load_func(path)`` used to read a sample's data back
        (defaults to ``bob.io.base.load``).
    """

    def __init__(
        self,
        model_path=None,
        features_dir=None,
        extension=".h5",
        save_func=None,
        load_func=None,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.model_path = model_path
        self.features_dir = features_dir
        self.extension = extension
        self.save_func = save_func or bob.io.base.save
        self.load_func = load_func or bob.io.base.load

    def transform_one_sample(self, sample):
        """Transform a single sample, loading the check-pointed result from
        disk instead when it already exists."""
        # Check if the sample is already processed.
        path = self.make_path(sample)
        if path is None or not os.path.isfile(path):
            new_sample = super().transform([sample])[0]
            # save the new sample
            self.save(new_sample)
        else:
            new_sample = self.load(path)

        return new_sample

    def transform_one_sample_set(self, sample_set):
        """Transform every sample of a set (reusing checkpoints), keeping the
        sample-set structure."""
        samples = [self.transform_one_sample(s) for s in sample_set.samples]
        return SampleSet(samples, parent=sample_set)

    def transform(self, samples):
        """Transform a list of samples or sample sets with check-pointing.

        Raises
        ------
        ValueError
            If ``samples`` is not a list or contains unsupported elements.
        """
        if not isinstance(samples, list):
            raise ValueError("It's expected a list, not %s" % type(samples))

        if isinstance(samples[0], (Sample, DelayedSample)):
            return [self.transform_one_sample(s) for s in samples]
        elif isinstance(samples[0], SampleSet):
            return [self.transform_one_sample_set(s) for s in samples]
        else:
            raise ValueError("Type not allowed %s" % type(samples[0]))

    def fit(self, samples, y=None):
        """Fit the wrapped estimator, loading a previously check-pointed
        model when one exists and saving the model afterwards otherwise."""
        if self.model_path is not None and os.path.isfile(self.model_path):
            return self.load_model()

        super().fit(samples, y=y)
        return self.save_model()

    def fit_transform(self, samples, y=None):
        """Fit then transform the same samples (both with check-pointing)."""
        return self.fit(samples, y=y).transform(samples)

    def make_path(self, sample):
        """Return the checkpoint path for ``sample`` (``features_dir`` +
        ``sample.key`` + ``extension``).

        Raises
        ------
        ValueError
            If ``features_dir`` was not configured.
        """
        if self.features_dir is None:
            raise ValueError("`features_dir` is not in %s" % CheckpointMixin.__name__)

        return os.path.join(self.features_dir, str(sample.key) + self.extension)

    def recover_key_from_path(self, path):
        """Invert :any:`make_path`: strip the features-directory prefix and
        the extension to recover the sample key."""
        key = path.replace(os.path.abspath(self.features_dir), "")
        # Bug fix: the extension was previously stripped from the full
        # ``path`` (discarding the prefix removal above) instead of ``key``.
        key = key[: -len(self.extension)]
        return key

    def save(self, sample):
        """Persist a sample's data (or every sample of a set) with
        ``save_func``, creating directories as needed.

        Raises
        ------
        ValueError
            If ``sample`` is neither a :any:`Sample` nor a :any:`SampleSet`.
        """
        if isinstance(sample, Sample):
            path = self.make_path(sample)
            os.makedirs(os.path.dirname(path), exist_ok=True)
            return self.save_func(sample.data, path)
        elif isinstance(sample, SampleSet):
            # Bug fix: save *every* sample of the set; the previous version
            # returned inside the loop, so only the first sample was saved.
            for s in sample.samples:
                path = self.make_path(s)
                os.makedirs(os.path.dirname(path), exist_ok=True)
                self.save_func(s.data, path)
        else:
            raise ValueError("Type for sample not supported %s" % type(sample))

    def load(self, path):
        """Wrap a check-pointed file into a lazily-loaded sample."""
        key = self.recover_key_from_path(path)
        # because we are checkpointing, we return a DelayedSample
        # instead of a normal (preloaded) sample. This allows the next
        # phase to avoid loading it would it be unnecessary (e.g. next
        # phase is already check-pointed)
        return DelayedSample(functools.partial(self.load_func, path), key=key)

    def load_model(self):
        """Load the check-pointed model and copy its state onto ``self``
        (no-op for stateless estimators)."""
        if _is_estimator_stateless(self):
            return self
        with open(self.model_path, "rb") as f:
            model = cloudpickle.load(f)
            self.__dict__.update(model.__dict__)
            return model

    def save_model(self):
        """Check-point ``self`` with cloudpickle (no-op for stateless
        estimators or when no ``model_path`` is configured)."""
        if _is_estimator_stateless(self) or self.model_path is None:
            return self
        os.makedirs(os.path.dirname(self.model_path), exist_ok=True)
        with open(self.model_path, "wb") as f:
            cloudpickle.dump(self, f)
        return self
364
365


366
class SampleFunctionTransformer(SampleMixin, FunctionTransformer):
    """A scikit-learn :py:class:`sklearn.preprocessing.FunctionTransformer`
    (https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.FunctionTransformer.html)
    that operates on :any:`Sample`-based pipelines.
    """

    pass
372

Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
373

374
375
376
class CheckpointSampleFunctionTransformer(
    CheckpointMixin, SampleMixin, FunctionTransformer
):
    """A scikit-learn :py:class:`sklearn.preprocessing.FunctionTransformer`
    (https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.FunctionTransformer.html)
    that operates on :any:`Sample`-based pipelines and, in addition,
    check-points its results to disk.
    """

    pass
384

385

386
class NonPicklableMixin:
    """Class that wraps estimators that are not picklable

    The wrapped estimator is built lazily — on the first call to ``fit`` or
    ``transform`` — from a picklable callable, so this wrapper itself can
    cross process boundaries.

    Example
    -------
        >>> from bob.pipelines.processor import NonPicklableMixin
        >>> wrapper = NonPicklableMixin(my_non_picklable_class_callable)

    Example
    -------
        >>> from bob.pipelines.processor import NonPicklableMixin
        >>> import functools
        >>> wrapper = NonPicklableMixin(functools.partial(MyNonPicklableClass, arg1, arg2))

    Parameters
    ----------
      callable: callable
         Calleble function that instantiates the scikit estimator

    """

    def __init__(self, callable=None):
        self.callable = callable
        self.instance = None

    def _materialize(self):
        # Build the wrapped estimator on first use and cache it.
        if self.instance is None:
            self.instance = self.callable()
        return self.instance

    def fit(self, X, y=None, **fit_params):
        # Instantiates (if needed) and delegates the "real" fit.
        return self._materialize().fit(X, y=y, **fit_params)

    def transform(self, X):
        # Instantiates (if needed) and delegates the "real" transform.
        return self._materialize().transform(X)


426
427
class DaskEstimatorMixin:
    """Wraps Scikit estimators into Daskable objects

    ``fit`` becomes lazy (a :py:func:`dask.delayed` node) and ``transform``
    maps over the partitions of a dask bag (as produced by
    :any:`DaskBagMixin` — TODO confirm against callers).

    Parameters
    ----------

       fit_tag: str
           Mark the delayed(self.fit) with this value. This can be used in
           a future delayed(self.fit).compute(resources=resource_tape) so
           dask scheduler can place this task in a particular resource
           (e.g GPU)

       transform_tag: str
           Mark the delayed(self.transform) with this value. This can be used in
           a future delayed(self.transform).compute(resources=resource_tape) so
           dask scheduler can place this task in a particular resource
           (e.g GPU)

    """

    def __init__(self, fit_tag=None, transform_tag=None, **kwargs):
        super().__init__(**kwargs)
        # Before fit is called, the "fitted state" is the estimator itself;
        # after fit it is replaced by the delayed fit result.
        self._dask_state = self
        # Maps delayed objects -> tag, consumed by estimator_dask_it's
        # dask_tags() / compute(resources=...).
        self.resource_tags = dict()
        self.fit_tag = fit_tag
        self.transform_tag = transform_tag

    def fit(self, X, y=None, **fit_params):
        # Defer the parent-class fit; nothing is computed until .compute().
        self._dask_state = delayed(super().fit)(X, y, **fit_params)
        if self.fit_tag is not None:
            self.resource_tags[self._dask_state] = self.fit_tag

        return self

    def transform(self, X):
        # `dask_state` is either this estimator (never fitted) or the delayed
        # fitted estimator; the two-argument super() call dispatches
        # transform to the class *after* DaskEstimatorMixin in the MRO.
        def _transf(X_line, dask_state):
            return super(DaskEstimatorMixin, dask_state).transform(X_line)

        # X is expected to support map_partitions (i.e. a dask bag).
        map_partitions = X.map_partitions(_transf, self._dask_state)
        if self.transform_tag is not None:
            self.resource_tags[map_partitions] = self.transform_tag

        return map_partitions
469
470


471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
class DaskBagMixin(TransformerMixin):
    """Transform an arbitrary iterator into a :py:class:`dask.bag`


    Parameters
    ----------

      npartitions: int
        Number of partitions used it :py:meth:`dask.bag.npartitions`


    Example
    -------

    >>> transformer = DaskBagMixin()
    >>> dask_bag = transformer.transform([1,2,3])
    >>> dask_bag.map_partitions.....

    """

    def __init__(self, npartitions=None, **kwargs):
        super().__init__(**kwargs)
        self.npartitions = npartitions

    def fit(self, X, y=None, **kwargs):
        # Stateless transformer: there is nothing to learn.
        return self

    def transform(self, X, **kwargs):
        # Wrap the input sequence into a (lazily evaluated) dask bag.
        return dask.bag.from_sequence(X, npartitions=self.npartitions)