FaceNet.py 7.12 KB
Newer Older
Amir MOHAMMADI's avatar
Amir MOHAMMADI committed
1
2
3
4
5
6
7
8
from __future__ import division
import os
import re
import logging
import numpy
import tensorflow as tf
from bob.ip.color import gray_to_rgb
from bob.io.image import to_matplotlib
9
from bob.extension import rc
10
11
import bob.extension.download
import bob.io.base
Amir MOHAMMADI's avatar
Amir MOHAMMADI committed
12
13
14

logger = logging.getLogger(__name__)

# Key in bob's global configuration (``bob.extension.rc``) under which the
# location of the FaceNet model can be stored by the user.
FACENET_MODELPATH_KEY = "bob.ip.tensorflow_extractor.facenet_modelpath"

def prewhiten(img):
    """Standardise an image to zero mean and (roughly) unit variance.

    The statistics are computed over all elements of ``img``. The standard
    deviation is floored at ``1 / sqrt(img.size)`` so that (nearly) constant
    images are not divided by a value close to zero.

    Parameters
    ----------
    img : numpy.ndarray
        The image to normalise; every element contributes to the statistics.

    Returns
    -------
    numpy.ndarray
        ``(img - mean(img)) / max(std(img), 1 / sqrt(img.size))``.
    """
    std_floor = 1.0 / numpy.sqrt(img.size)
    inv_scale = 1 / numpy.maximum(numpy.std(img), std_floor)
    centred = numpy.subtract(img, numpy.mean(img))
    return numpy.multiply(centred, inv_scale)


def get_model_filenames(model_dir):
    """Locate the metagraph and the most recent checkpoint in ``model_dir``.

    Adapted from https://github.com/davidsandberg/facenet.

    Parameters
    ----------
    model_dir : str
        Directory expected to contain exactly one ``*.meta`` file and one or
        more checkpoint files named ``model-<name>.ckpt-<step>...``.

    Returns
    -------
    tuple of str
        ``(meta_file, ckpt_file)``: the ``.meta`` file name and the
        checkpoint prefix ``model-<name>.ckpt-<step>`` with the largest
        ``<step>``. Both are file names relative to ``model_dir``.

    Raises
    ------
    ValueError
        If the directory holds no ``.meta`` file, more than one ``.meta``
        file, or no file matching the checkpoint naming pattern.
    """
    # code from https://github.com/davidsandberg/facenet
    files = os.listdir(model_dir)
    meta_files = [s for s in files if s.endswith(".meta")]
    if not meta_files:
        raise ValueError("No meta file found in the model directory (%s)" % model_dir)
    if len(meta_files) > 1:
        raise ValueError(
            "There should not be more than one meta file in the model "
            "directory (%s)" % model_dir
        )
    meta_file = meta_files[0]

    max_step = -1
    ckpt_file = None
    for f in files:
        # e.g. "model-xyz.ckpt-100.index" -> prefix "model-xyz.ckpt-100",
        # step 100. The dot before "ckpt" is escaped so that it only matches
        # a literal "." rather than any character.
        match = re.match(r"(^model-[\w\- ]+\.ckpt-(\d+))", f)
        if match is None:
            continue
        step = int(match.group(2))
        if step > max_step:
            max_step = step
            ckpt_file = match.group(1)

    if ckpt_file is None:
        # Fail with an explicit message instead of referencing an unset name.
        raise ValueError(
            "No checkpoint file found in the model directory (%s)" % model_dir
        )
    return meta_file, ckpt_file

class FaceNet(object):
    """Wrapper for the free FaceNet variant:

    https://github.com/davidsandberg/facenet

    To use this class as a bob.bio.base extractor::

        from bob.bio.base.extractor import Extractor
        class FaceNetExtractor(FaceNet, Extractor):
            pass
        extractor = FaceNetExtractor()

    And for a preprocessor you can use::

        from bob.bio.face.preprocessor import FaceCrop
        # This is the size of the image that this model expects
        CROPPED_IMAGE_HEIGHT = 160
        CROPPED_IMAGE_WIDTH = 160
        # eye positions for frontal images
        RIGHT_EYE_POS = (46, 53)
        LEFT_EYE_POS = (46, 107)
        # Crops the face using eye annotations
        preprocessor = FaceCrop(
            cropped_image_size=(CROPPED_IMAGE_HEIGHT, CROPPED_IMAGE_WIDTH),
            cropped_positions={'leye': LEFT_EYE_POS, 'reye': RIGHT_EYE_POS},
            color_channel='rgb'
        )

    Parameters
    ----------
    model_path : str or None
        Either a single serialized ``GraphDef`` file or a directory holding
        one ``.meta`` file plus its checkpoints. If ``None``, it is resolved
        with :py:meth:`get_modelpath` when the model is loaded. If the path
        does not exist, the pretrained model archive is downloaded there.
        Defaults to the value of ``rc[FACENET_MODELPATH_KEY]`` read at import
        time.
    image_size : int
        Expected height and width of the input images.
    layer_name : str
        Name of the graph tensor whose value is returned as the feature.
    **kwargs
        Accepted for compatibility; ignored.

    The model is loaded lazily on the first call. TensorFlow objects are not
    pickled; an unpickled instance re-loads the model on its next call.
    """

    def __init__(
        self,
        model_path=rc[FACENET_MODELPATH_KEY],
        image_size=160,
        layer_name="embeddings:0",
        **kwargs
    ):
        super(FaceNet, self).__init__()
        self.model_path = model_path
        self.image_size = image_size
        self.layer_name = layer_name
        self.loaded = False
        self._clean_unpicklables()

    def _clean_unpicklables(self):
        """Reset every TensorFlow handle (none of them can be pickled)."""
        self.session = None
        self.embeddings = None
        self.graph = None
        self.images_placeholder = None
        self.phase_train_placeholder = None

    def _check_feature(self, img):
        """Validate and normalise one image into a batch of size one.

        A 2D (gray) image is converted with ``gray_to_rgb``. The last two
        dimensions must both equal ``self.image_size``. The image is then
        converted with ``to_matplotlib`` (presumably bob's channels-first
        layout to channels-last -- confirm) and whitened with
        :py:func:`prewhiten`.

        Raises
        ------
        ValueError
            If the spatial size of the image is not
            ``image_size x image_size``.
        """
        img = numpy.ascontiguousarray(img)
        if img.ndim == 2:
            img = gray_to_rgb(img)
        # Explicit check instead of ``assert`` so it also runs under ``-O``.
        if tuple(img.shape[-2:]) != (self.image_size, self.image_size):
            raise ValueError(
                "Expected an image of size %dx%d in its last two dimensions, "
                "got shape %s" % (self.image_size, self.image_size, img.shape)
            )
        img = to_matplotlib(img)
        img = prewhiten(img)
        # Add a leading batch axis.
        return img[None, ...]

    @staticmethod
    def _download_model(model_dir):
        """Download the pretrained 20170512-110547 archive into ``model_dir``.

        The archive is presumably extracted next to the zip file by
        ``bob.extension.download.download_and_unzip`` -- confirm.
        """
        bob.io.base.create_directories_safe(model_dir)
        zip_file = os.path.join(model_dir, "20170512-110547.zip")
        urls = [
            # This link only works in Idiap CI to save bandwidth.
            "http://www.idiap.ch/private/wheels/gitlab/"
            "facenet_model2_20170512-110547.zip",
            # this link to dropbox would work for everybody
            "https://www.dropbox.com/s/"
            "k7bhxe58q7d48g7/facenet_model2_20170512-110547.zip?dl=1",
        ]
        bob.extension.download.download_and_unzip(urls, zip_file)

    def load_model(self):
        """Create a fresh graph and session and load the FaceNet model into them.

        Resolves ``self.model_path`` when it is ``None`` and downloads the
        model when the (user-expanded) path does not exist. Then it imports
        either a serialized ``GraphDef`` file or a metagraph plus the latest
        checkpoint (see :py:func:`get_model_filenames`). Finally it looks up
        the input, output and ``phase_train`` tensors.
        """
        tf.compat.v1.reset_default_graph()

        # One intra-op and one inter-op thread. ``tf.ConfigProto`` does not
        # exist in TensorFlow 2, so use the compat.v1 alias like the rest of
        # this module.
        session_conf = tf.compat.v1.ConfigProto(
            intra_op_parallelism_threads=1, inter_op_parallelism_threads=1
        )
        self.graph = tf.Graph()
        self.session = tf.compat.v1.Session(graph=self.graph, config=session_conf)

        if self.model_path is None:
            self.model_path = self.get_modelpath()
        # Use the expanded path both for the existence check / download and
        # for loading, so "~"-relative paths behave consistently.
        model_exp = os.path.expanduser(self.model_path)
        if not os.path.exists(model_exp):
            # Download to the path that is actually loaded below, not to the
            # default location (which differs for a custom ``model_path``).
            self._download_model(model_exp)

        # code from https://github.com/davidsandberg/facenet
        with self.graph.as_default():
            if os.path.isfile(model_exp):
                logger.info("Model filename: %s", model_exp)
                with tf.compat.v1.gfile.FastGFile(model_exp, "rb") as f:
                    graph_def = tf.compat.v1.GraphDef()
                    graph_def.ParseFromString(f.read())
                    tf.import_graph_def(graph_def, name="")
            else:
                logger.info("Model directory: %s", model_exp)
                meta_file, ckpt_file = get_model_filenames(model_exp)

                logger.info("Metagraph file: %s", meta_file)
                logger.info("Checkpoint file: %s", ckpt_file)

                saver = tf.compat.v1.train.import_meta_graph(
                    os.path.join(model_exp, meta_file)
                )
                saver.restore(self.session, os.path.join(model_exp, ckpt_file))

        # Get input and output tensors
        self.images_placeholder = self.graph.get_tensor_by_name("input:0")
        self.embeddings = self.graph.get_tensor_by_name(self.layer_name)
        self.phase_train_placeholder = self.graph.get_tensor_by_name("phase_train:0")
        logger.info("Successfully loaded the model.")
        self.loaded = True

    def __call__(self, img):
        """Compute the features of a single image.

        The model is loaded on first use. The network runs with
        ``phase_train`` set to ``False``, and the value of the
        ``self.layer_name`` tensor is returned flattened to 1D.
        """
        images = self._check_feature(img)
        if not self.loaded:
            self.load_model()

        feed_dict = {
            self.images_placeholder: images,
            self.phase_train_placeholder: False,
        }
        features = self.session.run(self.embeddings, feed_dict=feed_dict)
        return features.flatten()

    @staticmethod
    def get_modelpath():
        """Get the default model path.

        The path stored in the global configuration (``rc``) under
        ``FACENET_MODELPATH_KEY`` takes priority. If it is not set, fall back
        to ``data/FaceNet/20170512-110547`` resolved with
        ``pkg_resources.resource_filename`` relative to this module.
        """
        # Priority to the RC path
        model_path = rc[FACENET_MODELPATH_KEY]

        if model_path is None:
            import pkg_resources

            model_path = pkg_resources.resource_filename(
                __name__, "data/FaceNet/20170512-110547"
            )

        return model_path

    def __setstate__(self, d):
        """Restore pickled state; TensorFlow handles are reset to ``None``.

        The model is re-loaded lazily on the next call.
        """
        self.__dict__ = d
        self._clean_unpicklables()
        self.loaded = False

    def __getstate__(self):
        """Return a picklable *copy* of the instance state.

        TensorFlow objects are removed from the copy and ``loaded`` is reset
        there. The live instance keeps its session and graph; the old code
        mutated ``self.__dict__`` directly and dropped them.
        """
        state = self.__dict__.copy()
        state["loaded"] = False
        for key in (
            "session",
            "embeddings",
            "graph",
            "images_placeholder",
            "phase_train_placeholder",
        ):
            state.pop(key, None)
        return state