Commit 53e0a0e4 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

Add facenet extractor

parent d773cb99
......@@ -11,5 +11,4 @@ src
develop-eggs
sphinx
dist
bob/ip/caffe_extractor/data/face_verification_experiment-master/
bob/ip/caffe_extractor/data/vgg_face_caffe/
bob/ip/tensorflow_extractor/data/FaceNet/
from __future__ import division
import os
import re
import logging
import numpy
import tensorflow as tf
from bob.ip.color import gray_to_rgb
from bob.io.image import to_matplotlib
from . import download_file
logger = logging.getLogger(__name__)
def prewhiten(img):
mean = numpy.mean(img)
std = numpy.std(img)
std_adj = numpy.maximum(std, 1.0 / numpy.sqrt(img.size))
y = numpy.multiply(numpy.subtract(img, mean), 1 / std_adj)
return y
def get_model_filenames(model_dir):
# code from https://github.com/davidsandberg/facenet
files = os.listdir(model_dir)
meta_files = [s for s in files if s.endswith('.meta')]
if len(meta_files) == 0:
raise ValueError(
'No meta file found in the model directory (%s)' % model_dir)
elif len(meta_files) > 1:
raise ValueError(
'There should not be more than one meta file in the model '
'directory (%s)' % model_dir)
meta_file = meta_files[0]
max_step = -1
for f in files:
step_str = re.match(r'(^model-[\w\- ]+.ckpt-(\d+))', f)
if step_str is not None and len(step_str.groups()) >= 2:
step = int(step_str.groups()[1])
if step > max_step:
max_step = step
ckpt_file = step_str.groups()[0]
return meta_file, ckpt_file
class FaceNet(object):
"""Wrapper for the free FaceNet variant:
https://github.com/davidsandberg/facenet"""
def __init__(self,
model_path=None,
image_size=160,
**kwargs):
super(FaceNet, self).__init__()
self.model_path = model_path
self.image_size = image_size
self.session = None
self.embeddings = None
def _check_feature(self, img):
img = numpy.ascontiguousarray(img)
if img.ndim == 2:
img = gray_to_rgb(img)
assert img.shape[-1] == self.image_size
assert img.shape[-2] == self.image_size
img = to_matplotlib(img)
img = prewhiten(img)
return img[None, ...]
def load_model(self):
if self.model_path is None:
self.model_path = self.get_modelpath()
if not os.path.exists(self.model_path):
self.download_model()
# code from https://github.com/davidsandberg/facenet
model_exp = os.path.expanduser(self.model_path)
if (os.path.isfile(model_exp)):
logger.info('Model filename: %s' % model_exp)
with tf.gfile.FastGFile(model_exp, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
else:
logger.info('Model directory: %s' % model_exp)
meta_file, ckpt_file = get_model_filenames(model_exp)
logger.info('Metagraph file: %s' % meta_file)
logger.info('Checkpoint file: %s' % ckpt_file)
saver = tf.train.import_meta_graph(
os.path.join(model_exp, meta_file))
saver.restore(tf.get_default_session(),
os.path.join(model_exp, ckpt_file))
# Get input and output tensors
self.images_placeholder = self.graph.get_tensor_by_name("input:0")
self.embeddings = self.graph.get_tensor_by_name("embeddings:0")
self.phase_train_placeholder = self.graph.get_tensor_by_name(
"phase_train:0")
logger.info("Successfully loaded the model.")
def __call__(self, img):
images = self._check_feature(img)
if self.session is None:
self.session = tf.InteractiveSession()
self.graph = tf.get_default_graph()
if self.embeddings is None:
self.load_model()
feed_dict = {self.images_placeholder: images,
self.phase_train_placeholder: False}
features = self.session.run(
self.embeddings, feed_dict=feed_dict)
return features.flatten()
def __del__(self):
tf.reset_default_graph()
@staticmethod
def get_modelpath():
import pkg_resources
return pkg_resources.resource_filename(__name__,
'data/FaceNet/20170512-110547')
@staticmethod
def download_model():
"""
Download and extract the FaceNet files in bob/ip/tensorflow_extractor
"""
import zipfile
zip_file = os.path.join(FaceNet.get_modelpath(),
"20170512-110547.zip")
urls = [
# This is a private link at Idiap to save bandwidth.
"http://beatubulatest.lab.idiap.ch/private/wheels/gitlab/"
"facenet_model2_20170512-110547.zip",
# this works for everybody
"https://drive.google.com/uc?export=download&id="
"0B5MzpY9kBtDVZ2RpVDYwWmxoSUk",
]
for url in urls:
try:
logger.info(
"Downloading the FaceNet model from "
"{} ...".format(url))
download_file(url, zip_file)
break
except Exception as e:
logger.warning(
"Could not download from the %s url", url, exc_info=True)
else: # else is for the for loop
if not os.path.isfile(zip_file):
raise RuntimeError("Could not download the zip file.")
# Unzip
logger.info("Unziping in {0}".format(FaceNet.get_modelpath()))
with zipfile.ZipFile(zip_file) as myzip:
myzip.extractall(os.path.dirname(FaceNet.get_modelpath()))
# delete extra files
os.unlink(zip_file)
#!/usr/bin/env python
from .Extractor import Extractor
from .InceptionResNet_v1 import InceptionResNet_v1
def scratch_network(inputs, end_point="fc1", reuse = False):
def scratch_network(inputs, end_point="fc1", reuse=False):
import tensorflow as tf
slim = tf.contrib.slim
......@@ -13,76 +7,86 @@ def scratch_network(inputs, end_point="fc1", reuse = False):
# Creating a random network
initializer = tf.contrib.layers.xavier_initializer(seed=10)
end_points = dict()
graph = slim.conv2d(inputs, 10, [3, 3], activation_fn=tf.nn.relu, stride=1, scope='conv1',
weights_initializer=initializer, reuse=reuse)
graph = slim.conv2d(inputs, 10, [3, 3], activation_fn=tf.nn.relu, stride=1,
scope='conv1', weights_initializer=initializer,
reuse=reuse)
end_points["conv1"] = graph
graph = slim.max_pool2d(graph, [4, 4], scope='pool1')
end_points["pool1"] = graph
graph = slim.flatten(graph, scope='flatten1')
end_points["flatten1"] = graph
graph = slim.fully_connected(graph, 10, activation_fn=None, scope='fc1',
weights_initializer=initializer, reuse=reuse)
end_points["fc1"] = graph
end_points["fc1"] = graph
return end_points[end_point]
def download_file(url, out_file):
"""Downloads a file from a given url
Parameters
----------
url : str
The url to download form.
out_file : str
Where to save the file.
"""
import sys
if sys.version_info[0] < 3:
# python2 technique for downloading a file
from urllib2 import urlopen
with open(out_file, 'wb') as f:
response = urlopen(url)
f.write(response.read())
else:
# python3 technique for downloading a file
from urllib.request import urlopen
from shutil import copyfileobj
with urlopen(url) as response:
with open(out_file, 'wb') as f:
copyfileobj(response, f)
"""Downloads a file from a given url
Parameters
----------
url : str
The url to download form.
out_file : str
Where to save the file.
"""
from bob.io.base import create_directories_safe
import os
create_directories_safe(os.path.dirname(out_file))
import sys
if sys.version_info[0] < 3:
# python2 technique for downloading a file
from urllib2 import urlopen
with open(out_file, 'wb') as f:
response = urlopen(url)
f.write(response.read())
else:
# python3 technique for downloading a file
from urllib.request import urlopen
from shutil import copyfileobj
with urlopen(url) as response:
with open(out_file, 'wb') as f:
copyfileobj(response, f)
def get_config():
"""Returns a string containing the configuration information.
"""
import bob.extension
return bob.extension.get_config(__name__)
"""Returns a string containing the configuration information.
"""
import bob.extension
return bob.extension.get_config(__name__)
from .Extractor import Extractor
from .FaceNet import FaceNet
# gets sphinx autodoc done right - don't remove it
def __appropriate__(*args):
"""Says object was actually declared here, and not in the import module.
Fixing sphinx warnings of not being able to find classes, when path is shortened.
Parameters:
"""Says object was actually declared here, and not in the import module.
Fixing sphinx warnings of not being able to find classes, when path is
shortened. Parameters:
*args: An iterable of objects to modify
*args: An iterable of objects to modify
Resolves `Sphinx referencing issues
<https://github.com/sphinx-doc/sphinx/issues/3048>`
"""
Resolves `Sphinx referencing issues
<https://github.com/sphinx-doc/sphinx/issues/3048>`
"""
for obj in args:
obj.__module__ = __name__
for obj in args:
obj.__module__ = __name__
__appropriate__(
Extractor,
Extractor,
FaceNet,
)
# gets sphinx autodoc done right - don't remove it
......
......@@ -14,30 +14,39 @@ from . import scratch_network
def test_output():
# Loading MNIST model
filename = os.path.join( pkg_resources.resource_filename(__name__, 'data'), 'model.ckp')
# Loading MNIST model
filename = os.path.join(pkg_resources.resource_filename(
__name__, 'data'), 'model.ckp')
inputs = tf.placeholder(tf.float32, shape=(None, 28, 28, 1))
# Testing the last output
graph = scratch_network(inputs)
extractor = bob.ip.tensorflow_extractor.Extractor(filename, inputs, graph)
data = numpy.random.rand(2, 28, 28, 1).astype("float32")
output = extractor(data)
assert extractor(data).shape == (2, 10)
del extractor
# Testing flatten
inputs = tf.placeholder(tf.float32, shape=(None, 28, 28, 1))
graph = scratch_network(inputs, end_point="flatten1")
extractor = bob.ip.tensorflow_extractor.Extractor(filename, inputs, graph)
data = numpy.random.rand(2, 28, 28, 1).astype("float32")
output = extractor(data)
assert extractor(data).shape == (2, 1690)
assert output.shape == (2, 1690)
del extractor
def test_facenet():
from bob.ip.tensorflow_extractor import FaceNet
extractor = FaceNet()
data = numpy.random.rand(3, 160, 160).astype("uint8")
output = extractor(data)
assert output.size == 128, output.shape
"""
def test_output_from_meta():
......@@ -48,19 +57,19 @@ def test_output_from_meta():
# Testing the last output
graph = scratch_network(inputs)
extractor = bob.ip.tensorflow_extractor.Extractor(filename, inputs, graph)
data = numpy.random.rand(2, 28, 28, 1).astype("float32")
output = extractor(data)
assert extractor(data).shape == (2, 10)
del extractor
# Testing flatten
inputs = tf.placeholder(tf.float32, shape=(None, 28, 28, 1))
graph = scratch_network(inputs, end_point="flatten1")
extractor = bob.ip.tensorflow_extractor.Extractor(filename, inputs, graph)
data = numpy.random.rand(2, 28, 28, 1).astype("float32")
output = extractor(data)
assert extractor(data).shape == (2, 1690)
del extractor
del extractor
"""
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment