opencv.py 3.07 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# Yu Linghu & Xinyi Zhang <yu.linghu@uzh.ch, xinyi.zhang@uzh.ch>
# Tiago de Freitas Pereira <tiago.pereira@idiap.ch>

import bob.bio.base
import numpy as np
from sklearn.base import TransformerMixin, BaseEstimator
from sklearn.utils import check_array
import os
from bob.extension.download import get_file


class OpenCVTransformer(TransformerMixin, BaseEstimator):
    """
    Base Transformer using the OpenCV interface.

    The network is read with ``cv2.dnn.readNet`` the first time
    :py:meth:`transform` is called and is kept in ``self.model`` afterwards.
    The loaded network is dropped when the object is pickled (see
    :py:meth:`__getstate__`) and is read again on the next call to
    :py:meth:`transform`.

    .. note::
       This class supports Caffe ``.caffemodel``, Tensorflow ``.pb``, Torch ``.t7`` ``.net``, Darknet ``.weights``, DLDT ``.bin``, and ONNX ``.onnx``


    Parameters
    ----------

    checkpoint_path: str
       Path containing the checkpoint

    config: str
        Path containing some configuration file (e.g. .json, .prototxt)

    **kwargs
        Extra keyword arguments forwarded to the parent constructor.
    """

    def __init__(self, checkpoint_path=None, config=None, **kwargs):
        super().__init__(**kwargs)
        self.checkpoint_path = checkpoint_path
        self.config = config
        # Filled in lazily by `_load_model`; kept as None until first use.
        self.model = None

    def _load_model(self):
        """Read the network from ``checkpoint_path`` / ``config`` into ``self.model``."""
        # Imported here so the module can be imported without OpenCV installed.
        import cv2

        self.model = cv2.dnn.readNet(self.checkpoint_path, self.config)

    def transform(self, X):
        """Run the network on ``X`` and return its output.

        Parameters
        ----------

        X : array-like
          Input data. It is converted to a :py:class:`numpy.ndarray`, divided
          by 255 and passed to the network as its input blob.
          NOTE(review): presumably a batch of images with values in
          [0, 255], already laid out the way the network expects -- confirm
          against the callers.

        Returns
        -------

        :py:class:`numpy.ndarray`
          Whatever ``self.model.forward()`` returns for this input.
        """

        if self.model is None:
            self._load_model()

        img = np.asarray(X) / 255
        # NOTE(review): the division above yields float64; cv2.dnn is expected
        # to work on 32-bit float blobs, so cast explicitly before handing the
        # data over -- confirm against the OpenCV version in use.
        img = img.astype(np.float32)

        self.model.setInput(img)

        return self.model.forward()

    def __getstate__(self):
        """Return the instance state for pickling, with ``model`` set to None."""
        # Handling unpicklable objects: the loaded network is replaced by None
        # so that `transform` reloads it after unpickling.
        state = self.__dict__.copy()
        state["model"] = None
        return state

    def _more_tags(self):
        """Return the extra estimator tags declared by this transformer."""
        return {"stateless": True, "requires_fit": False}


class VGG16_Oxford(OpenCVTransformer):
    """
    Original VGG16 model from the paper: https://www.robots.ox.ac.uk/~vgg/publications/2015/Parkhi15/parkhi15.pdf

    On construction, the ``vgg_face_caffe.tar.gz`` archive is requested through
    ``get_file`` (with ``extract=True``), and the paths of
    ``VGG_FACE_deploy.prototxt`` and ``VGG_FACE.caffemodel`` inside its
    ``vgg_face_caffe`` folder are passed to :py:class:`OpenCVTransformer`.
    Model loading is inherited from :py:class:`OpenCVTransformer`.
    """

    def __init__(self):
        urls = [
            "https://www.robots.ox.ac.uk/~vgg/software/vgg_face/src/vgg_face_caffe.tar.gz",
            "http://bobconda.lab.idiap.ch/public-upload/data/bob/bob.bio.face/master/caffe/vgg_face_caffe.tar.gz",
        ]

        filename = get_file(
            "vgg_face_caffe.tar.gz",
            urls,
            cache_subdir="data/caffe/vgg_face_caffe",
            file_hash="ee707ac6e890bc148cb155adeaad12be",
            extract=True,
        )

        # The model files live in a `vgg_face_caffe` folder located next to
        # the path returned by `get_file`.
        model_dir = os.path.join(os.path.dirname(filename), "vgg_face_caffe")
        config = os.path.join(model_dir, "VGG_FACE_deploy.prototxt")
        checkpoint_path = os.path.join(model_dir, "VGG_FACE.caffemodel")

        super().__init__(checkpoint_path=checkpoint_path, config=config)