algorithm.py 2.71 KB
Newer Older
1
2
3
4
5
6
7
8
9
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :

from sklearn.base import TransformerMixin, BaseEstimator
from bob.bio.base.algorithm import Algorithm
from bob.pipelines.utils import is_picklable
from . import split_X_by_y
import os

10

11
class AlgorithmTransformer(TransformerMixin, BaseEstimator):
    """Class that wraps :py:class:`bob.bio.base.algorithm.Algorithm`

    :any:`AlgorithmTransformer.fit` maps to :py:meth:`bob.bio.base.algorithm.Algorithm.train_projector`

    :any:`AlgorithmTransformer.transform` maps to :py:meth:`bob.bio.base.algorithm.Algorithm.project`

    Example
    -------

        Wrapping an LDA algorithm

        >>> from bob.bio.base.pipelines.vanilla_biometrics import AlgorithmTransformer
        >>> from bob.bio.base.algorithm import LDA
        >>> transformer = AlgorithmTransformer(LDA(use_pinv=True, pca_subspace_dimension=0.90), projector_file="Projector.hdf5")


    Parameters
    ----------
    instance: object
        An instance of bob.bio.base.algorithm.Algorithm

    projector_file: str
        Path passed to ``instance.train_projector`` during :any:`fit`.
        Mandatory (non-empty) when ``instance.requires_projector_training``
        is true; may be left as ``None`` otherwise.

    kwargs:
        Extra keyword arguments forwarded to the parent class constructor.

    Raises
    ------
    ValueError
        If ``instance`` is not an ``Algorithm``, if it requires projector
        training and no ``projector_file`` is given, or if it is not
        picklable.

    """

    def __init__(
        self, instance, projector_file=None, **kwargs,
    ):

        if not isinstance(instance, Algorithm):
            raise ValueError(
                "`instance` should be an instance of `bob.bio.base.algorithm.Algorithm`"
            )

        if instance.requires_projector_training and (
            projector_file is None or projector_file == ""
        ):
            raise ValueError(
                f"`projector_file` needs to be set if algorithm {instance} requires training"
            )

        if not is_picklable(instance):
            raise ValueError(f"{instance} needs to be picklable")

        self.instance = instance
        self.projector_file = projector_file
        super().__init__(**kwargs)

    def fit(self, X, y=None):
        """Train the projector of the wrapped algorithm, if it needs one.

        Does nothing (and returns ``self``) when
        ``instance.requires_projector_training`` is false.  Otherwise ``X`` is
        handed to ``instance.train_projector`` together with
        ``self.projector_file``; when
        ``instance.split_training_features_by_client`` is set, ``X`` is first
        regrouped with ``split_X_by_y(X, y)``.

        Returns
        -------
        self
        """
        if not self.instance.requires_projector_training:
            return self

        training_data = X
        if self.instance.split_training_features_by_client:
            training_data = split_X_by_y(X, y)

        # A bare file name has an empty dirname, and ``os.makedirs("")``
        # raises FileNotFoundError; only create the directory when there is
        # actually one to create.
        projector_dir = os.path.dirname(self.projector_file)
        if projector_dir:
            os.makedirs(projector_dir, exist_ok=True)

        self.instance.train_projector(training_data, self.projector_file)
        return self

    def transform(self, X, metadata=None):
        """Project every element of ``X`` with ``instance.project``.

        When ``metadata`` is given, it is iterated in lockstep with ``X`` and
        each item is passed as the second argument of ``instance.project``.

        Returns
        -------
        list
            One projected item per element consumed from ``X``.
        """
        if metadata is None:
            return [self.instance.project(data) for data in X]

        # Distinct loop-variable name so the ``metadata`` parameter is not
        # shadowed inside the comprehension.
        return [
            self.instance.project(data, data_metadata)
            for data, data_metadata in zip(X, metadata)
        ]

    def _more_tags(self):
        """Return estimator tags derived from the wrapped algorithm.

        ``stateless``/``requires_fit`` mirror whether the algorithm requires
        projector training; the ``bob_features_*`` entries expose the
        algorithm's own feature writer and reader.
        """
        requires_training = self.instance.requires_projector_training
        return {
            "stateless": not requires_training,
            "requires_fit": requires_training,
            "bob_features_save_fn": self.instance.write_feature,
            "bob_features_load_fn": self.instance.read_feature,
        }