FaceNet.py 6.49 KB
Newer Older
Amir MOHAMMADI's avatar
Amir MOHAMMADI committed
1
2
3
4
5
6
7
8
from __future__ import division
import os
import re
import logging
import numpy
import tensorflow as tf
from bob.ip.color import gray_to_rgb
from bob.io.image import to_matplotlib
9
from bob.extension import rc
10
11
import bob.extension.download
import bob.io.base
Amir MOHAMMADI's avatar
Amir MOHAMMADI committed
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48

# Module-level logger, named after this module so log output can be filtered
# per module through the standard logging hierarchy.
logger = logging.getLogger(name=__name__)


def prewhiten(img):
    """Standardise ``img`` to zero mean and (approximately) unit variance.

    The standard deviation used for scaling is clamped from below at
    ``1 / sqrt(img.size)``, so a constant image does not trigger a division
    by zero (it simply maps to all zeros).

    Parameters
    ----------
    img : numpy.ndarray
        Image of any shape and numeric dtype; ``img.size`` must be non-zero.

    Returns
    -------
    numpy.ndarray
        Floating-point array with the same shape as ``img``.
    """
    centred = img - numpy.mean(img)
    floor = 1.0 / numpy.sqrt(img.size)
    # Multiply by the reciprocal (rather than divide) to keep results
    # bit-identical to the historical implementation.
    scale = 1 / numpy.maximum(numpy.std(img), floor)
    return centred * scale


def get_model_filenames(model_dir):
    """Locate the metagraph and the latest checkpoint in a model directory.

    Adapted from https://github.com/davidsandberg/facenet

    Parameters
    ----------
    model_dir : str
        Directory that must contain exactly one ``*.meta`` file and at least
        one checkpoint file named ``model-<name>.ckpt-<step>...``.

    Returns
    -------
    meta_file : str
        Basename of the single ``.meta`` file.
    ckpt_file : str
        Checkpoint prefix ``model-<name>.ckpt-<step>`` with the largest
        integer ``<step>`` (relative to ``model_dir``).

    Raises
    ------
    ValueError
        If there is no ``.meta`` file, more than one ``.meta`` file, or no
        checkpoint file in ``model_dir``.
    """
    files = os.listdir(model_dir)
    meta_files = [s for s in files if s.endswith('.meta')]
    if not meta_files:
        raise ValueError(
            'No meta file found in the model directory (%s)' % model_dir)
    if len(meta_files) > 1:
        raise ValueError(
            'There should not be more than one meta file in the model '
            'directory (%s)' % model_dir)
    meta_file = meta_files[0]

    # The dot before "ckpt" is escaped so that only real
    # "<prefix>.ckpt-<step>" names match (an unescaped "." matches any char).
    ckpt_pattern = re.compile(r'^(model-[\w\- ]+\.ckpt-(\d+))')
    max_step = -1
    ckpt_file = None
    for f in files:
        match = ckpt_pattern.match(f)
        if match is None:
            continue
        # Compare steps numerically, so "ckpt-10" beats "ckpt-9".
        step = int(match.group(2))
        if step > max_step:
            max_step = step
            ckpt_file = match.group(1)
    if ckpt_file is None:
        # Without this check the function used to die with an
        # UnboundLocalError; report it like the other layout errors.
        raise ValueError(
            'No checkpoint file found in the model directory (%s)' % model_dir)
    return meta_file, ckpt_file


class FaceNet(object):
    """Wrapper for the free FaceNet variant:

    https://github.com/davidsandberg/facenet

    To use this class as a bob.bio.base extractor::

        from bob.bio.base.extractor import Extractor
        class FaceNetExtractor(FaceNet, Extractor):
            pass
        extractor = FaceNetExtractor()

    And for a preprocessor you can use::

        from bob.bio.face.preprocessor import FaceCrop
        # This is the size of the image that this model expects
        CROPPED_IMAGE_HEIGHT = 160
        CROPPED_IMAGE_WIDTH = 160
        # eye positions for frontal images
        RIGHT_EYE_POS = (46, 53)
        LEFT_EYE_POS = (46, 107)
        # Crops the face using eye annotations
        preprocessor = FaceCrop(
            cropped_image_size=(CROPPED_IMAGE_HEIGHT, CROPPED_IMAGE_WIDTH),
            cropped_positions={'leye': LEFT_EYE_POS, 'reye': RIGHT_EYE_POS},
            color_channel='rgb'
        )

    The TensorFlow graph and session are created lazily, on the first call
    to the extractor or to :py:meth:`load_model`. If ``model_path`` does not
    exist, the pretrained 20170512-110547 model is downloaded into it.

    Parameters
    ----------
    model_path : str or None
        Either a serialized ``GraphDef`` file or a checkpoint directory (one
        ``.meta`` file plus ``model-*.ckpt-<step>*`` files). ``~`` is
        expanded. The default is read from the Bob configuration
        (``rc``) when this class is defined; if it is ``None``,
        :py:meth:`get_modelpath` is used when the model is loaded.
    image_size : int
        Expected height and width of the input images.
    layer_name : str
        Name of the tensor whose values are returned as features.
    **kwargs
        Accepted for compatibility and ignored.
    """

    def __init__(
            self,
            model_path=rc["bob.ip.tensorflow_extractor.facenet_modelpath"],
            image_size=160,
            layer_name='embeddings:0',
            **kwargs):
        super(FaceNet, self).__init__()
        self.model_path = model_path
        self.image_size = image_size
        self.layer_name = layer_name
        # TensorFlow objects are created lazily (see _ensure_session and
        # load_model), so constructing the extractor stays cheap.
        self.graph = None
        self.session = None
        # ``embeddings`` doubles as the "model is loaded" flag in __call__.
        self.embeddings = None

    def _check_feature(self, img):
        """Convert one image into a prewhitened batch of size one.

        ``img`` is presumably in Bob's channels-first layout, i.e. ``(H, W)``
        for grayscale or ``(3, H, W)`` for color, as the shape checks on the
        last two axes suggest. Grayscale images are converted to RGB, the
        result is transposed with ``to_matplotlib``, prewhitened, and a
        leading batch axis is added.
        """
        img = numpy.ascontiguousarray(img)
        if img.ndim == 2:
            img = gray_to_rgb(img)
        assert img.shape[-1] == self.image_size, (
            "expected image width %d, got shape %s"
            % (self.image_size, img.shape))
        assert img.shape[-2] == self.image_size, (
            "expected image height %d, got shape %s"
            % (self.image_size, img.shape))
        img = to_matplotlib(img)
        img = prewhiten(img)
        return img[None, ...]

    def _ensure_session(self):
        """Create this object's private graph and session if there is none."""
        if self.session is None:
            self.graph = tf.Graph()
            self.session = tf.Session(graph=self.graph)

    @staticmethod
    def _download_model(model_dir):
        """Download and unzip the pretrained 20170512-110547 model.

        The archive is saved as ``<model_dir>/20170512-110547.zip`` and
        unpacked by ``bob.extension.download.download_and_unzip``.
        """
        bob.io.base.create_directories_safe(model_dir)
        zip_file = os.path.join(model_dir, "20170512-110547.zip")
        urls = [
            # This is a private link at Idiap to save bandwidth.
            "http://beatubulatest.lab.idiap.ch/private/wheels/gitlab/"
            "facenet_model2_20170512-110547.zip",
            # this link to dropbox would work for everybody
            # previous link to gogle drive would require cookies
            "https://www.dropbox.com/s/"
            "k7bhxe58q7d48g7/facenet_model2_20170512-110547.zip?dl=1",
        ]
        bob.extension.download.download_and_unzip(urls, zip_file)

    def load_model(self):
        """Load the model into this object's graph, downloading it if needed.

        Creates the graph/session when they do not exist yet, so this method
        can be called directly, before the first call to the extractor. After
        it returns, ``images_placeholder``, ``phase_train_placeholder`` and
        ``embeddings`` refer to tensors of ``self.graph``.
        """
        if self.model_path is None:
            self.model_path = self.get_modelpath()
        # Expand "~" *before* testing for existence: os.path.exists does not
        # expand it, so an existing "~/..." model would otherwise be
        # downloaded again.
        model_exp = os.path.expanduser(self.model_path)
        if not os.path.exists(model_exp):
            # Download into the requested location, which is where the model
            # is loaded from below.
            self._download_model(model_exp)

        self._ensure_session()

        # code from https://github.com/davidsandberg/facenet
        with self.graph.as_default():
            if os.path.isfile(model_exp):
                logger.info('Model filename: %s', model_exp)
                with tf.gfile.FastGFile(model_exp, 'rb') as f:
                    graph_def = tf.GraphDef()
                    graph_def.ParseFromString(f.read())
                    tf.import_graph_def(graph_def, name='')
            else:
                logger.info('Model directory: %s', model_exp)
                meta_file, ckpt_file = get_model_filenames(model_exp)

                logger.info('Metagraph file: %s', meta_file)
                logger.info('Checkpoint file: %s', ckpt_file)

                saver = tf.train.import_meta_graph(
                    os.path.join(model_exp, meta_file))
                saver.restore(self.session,
                              os.path.join(model_exp, ckpt_file))

        # Look up every tensor before publishing any of them, so a missing
        # tensor cannot leave ``embeddings`` set on a half-loaded object.
        images_placeholder = self.graph.get_tensor_by_name("input:0")
        phase_train_placeholder = self.graph.get_tensor_by_name(
            "phase_train:0")
        embeddings = self.graph.get_tensor_by_name(self.layer_name)
        self.images_placeholder = images_placeholder
        self.phase_train_placeholder = phase_train_placeholder
        self.embeddings = embeddings
        logger.info("Successfully loaded the model.")

    def __call__(self, img):
        """Extract features for a single image.

        Returns the values of the ``layer_name`` tensor for ``img`` as a
        flattened 1-D array. The model is loaded on first use.
        """
        images = self._check_feature(img)
        if self.session is None or self.embeddings is None:
            self.load_model()
        feed_dict = {self.images_placeholder: images,
                     self.phase_train_placeholder: False}
        features = self.session.run(
            self.embeddings, feed_dict=feed_dict)
        return features.flatten()

    @staticmethod
    def get_rcvariable():
        """
        Variable name used in the Bob Global Configuration System
        https://www.idiap.ch/software/bob/docs/bob/bob.extension/stable/rc.html
        """
        return "bob.ip.tensorflow_extractor.facenet_modelpath"

    @staticmethod
    def get_modelpath():
        """
        Get default model path.

        First we try to read this path from the Global Configuration System.
        If it is not set there, we fall back to the
        ``data/FaceNet/20170512-110547`` directory inside this package.
        """
        # Priority to the RC path
        model_path = rc[FaceNet.get_rcvariable()]

        if model_path is None:
            import pkg_resources
            model_path = pkg_resources.resource_filename(
                __name__, 'data/FaceNet/20170512-110547')

        return model_path