Commit a2758b49 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

bring back Extractor

parent f1345743
Pipeline #38023 passed with stage
in 17 minutes and 31 seconds
# see https://docs.python.org/3/library/pkgutil.html
from pkgutil import extend_path
__path__ = extend_path(__path__, __name__)
# see https://docs.python.org/3/library/pkgutil.html
from pkgutil import extend_path
__path__ = extend_path(__path__, __name__)
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Fri 17 Jun 2016 10:41:36 CEST
import tensorflow as tf
import os
from tensorflow.python import debug as tf_debug
class Extractor(object):
    """Feature extractor using tensorflow.

    Restores variables from a checkpoint into a ``tf.compat.v1.Session`` and
    evaluates a target tensor (``graph``) for each input fed to
    ``input_tensor``.
    """

    def __init__(self, checkpoint_filename, input_tensor, graph, debug=False):
        """Loads the tensorflow model.

        Parameters
        ----------
        checkpoint_filename : str
            Path of your checkpoint. If the ``.meta`` file is provided, the
            last checkpoint of its directory will be loaded. A checkpoint
            directory or a plain checkpoint prefix is accepted as well.
        input_tensor : tf.Tensor
            Tensor used as a data entrypoint. It can be a ``tf.placeholder``,
            the result of ``tf.train.string_input_producer``, etc.
        graph : tf.Tensor
            A tensor containing the operations to be executed.
        debug : bool
            If True, wrap the session in the tfdbg local CLI debugger.
        """
        self.input_tensor = input_tensor
        self.graph = graph

        # Initializing the variables of the current graph
        self.session = tf.compat.v1.Session()
        self.session.run(tf.compat.v1.global_variables_initializer())

        # Loading the last checkpoint and overwriting the current variables
        saver = tf.compat.v1.train.Saver()
        if os.path.splitext(checkpoint_filename)[1] == ".meta":
            # A .meta file was given: restore the latest checkpoint found
            # next to it.
            saver.restore(
                self.session,
                tf.train.latest_checkpoint(os.path.dirname(checkpoint_filename)),
            )
        elif os.path.isdir(checkpoint_filename):
            saver.restore(
                self.session, tf.train.latest_checkpoint(checkpoint_filename)
            )
        else:
            saver.restore(self.session, checkpoint_filename)

        # Activating the debug
        if debug:
            self.session = tf_debug.LocalCLIDebugWrapperSession(self.session)

    def __del__(self):
        # Close the session to release its resources (the original version
        # leaked it) and reset the default graph. Both are guarded because
        # __del__ may run at interpreter shutdown, when module globals such
        # as ``tf`` can already be gone.
        try:
            self.session.close()
        except Exception:
            pass
        try:
            tf.compat.v1.reset_default_graph()
        except Exception:
            pass

    def __call__(self, data):
        """Forward the data with the loaded neural network.

        Parameters
        ----------
        data : numpy.ndarray
            Input data fed to ``input_tensor``.

        Returns
        -------
        numpy.ndarray
            The features obtained by evaluating ``graph``.
        """
        return self.session.run(self.graph, feed_dict={self.input_tensor: data})
......@@ -12,7 +12,7 @@ import bob.io.base
logger = logging.getLogger(__name__)
FACENET_MODELPATH_KEY ="bob.ip.tensorflow_extractor.facenet_modelpath"
FACENET_MODELPATH_KEY = "bob.ip.tensorflow_extractor.facenet_modelpath"
def prewhiten(img):
......@@ -26,18 +26,18 @@ def prewhiten(img):
def get_model_filenames(model_dir):
# code from https://github.com/davidsandberg/facenet
files = os.listdir(model_dir)
meta_files = [s for s in files if s.endswith('.meta')]
meta_files = [s for s in files if s.endswith(".meta")]
if len(meta_files) == 0:
raise ValueError(
'No meta file found in the model directory (%s)' % model_dir)
raise ValueError("No meta file found in the model directory (%s)" % model_dir)
elif len(meta_files) > 1:
raise ValueError(
'There should not be more than one meta file in the model '
'directory (%s)' % model_dir)
"There should not be more than one meta file in the model "
"directory (%s)" % model_dir
)
meta_file = meta_files[0]
max_step = -1
for f in files:
step_str = re.match(r'(^model-[\w\- ]+.ckpt-(\d+))', f)
step_str = re.match(r"(^model-[\w\- ]+.ckpt-(\d+))", f)
if step_str is not None and len(step_str.groups()) >= 2:
step = int(step_str.groups()[1])
if step > max_step:
......@@ -76,11 +76,12 @@ class FaceNet(object):
"""
def __init__(
self,
model_path=rc[FACENET_MODELPATH_KEY],
image_size=160,
layer_name='embeddings:0',
**kwargs):
self,
model_path=rc[FACENET_MODELPATH_KEY],
image_size=160,
layer_name="embeddings:0",
**kwargs
):
super(FaceNet, self).__init__()
self.model_path = model_path
self.image_size = image_size
......@@ -103,8 +104,7 @@ class FaceNet(object):
self.model_path = self.get_modelpath()
if not os.path.exists(self.model_path):
bob.io.base.create_directories_safe(FaceNet.get_modelpath())
zip_file = os.path.join(FaceNet.get_modelpath(),
"20170512-110547.zip")
zip_file = os.path.join(FaceNet.get_modelpath(), "20170512-110547.zip")
urls = [
# This link only works in Idiap CI to save bandwidth.
"http://www.idiap.ch/private/wheels/gitlab/"
......@@ -118,28 +118,27 @@ class FaceNet(object):
# code from https://github.com/davidsandberg/facenet
model_exp = os.path.expanduser(self.model_path)
with self.graph.as_default():
if (os.path.isfile(model_exp)):
logger.info('Model filename: %s' % model_exp)
with tf.compat.v1.gfile.FastGFile(model_exp, 'rb') as f:
if os.path.isfile(model_exp):
logger.info("Model filename: %s" % model_exp)
with tf.compat.v1.gfile.FastGFile(model_exp, "rb") as f:
graph_def = tf.compat.v1.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
tf.import_graph_def(graph_def, name="")
else:
logger.info('Model directory: %s' % model_exp)
logger.info("Model directory: %s" % model_exp)
meta_file, ckpt_file = get_model_filenames(model_exp)
logger.info('Metagraph file: %s' % meta_file)
logger.info('Checkpoint file: %s' % ckpt_file)
logger.info("Metagraph file: %s" % meta_file)
logger.info("Checkpoint file: %s" % ckpt_file)
saver = tf.compat.v1.train.import_meta_graph(
os.path.join(model_exp, meta_file))
saver.restore(self.session,
os.path.join(model_exp, ckpt_file))
os.path.join(model_exp, meta_file)
)
saver.restore(self.session, os.path.join(model_exp, ckpt_file))
# Get input and output tensors
self.images_placeholder = self.graph.get_tensor_by_name("input:0")
self.embeddings = self.graph.get_tensor_by_name(self.layer_name)
self.phase_train_placeholder = self.graph.get_tensor_by_name(
"phase_train:0")
self.phase_train_placeholder = self.graph.get_tensor_by_name("phase_train:0")
logger.info("Successfully loaded the model.")
def __call__(self, img):
......@@ -149,10 +148,11 @@ class FaceNet(object):
self.session = tf.compat.v1.Session(graph=self.graph)
if self.embeddings is None:
self.load_model()
feed_dict = {self.images_placeholder: images,
self.phase_train_placeholder: False}
features = self.session.run(
self.embeddings, feed_dict=feed_dict)
feed_dict = {
self.images_placeholder: images,
self.phase_train_placeholder: False,
}
features = self.session.run(self.embeddings, feed_dict=feed_dict)
return features.flatten()
@staticmethod
......@@ -170,7 +170,9 @@ class FaceNet(object):
if model_path is None:
import pkg_resources
model_path = pkg_resources.resource_filename(
__name__, 'data/FaceNet/20170512-110547')
__name__, "data/FaceNet/20170512-110547"
)
return model_path
#!/usr/bin/env python
def get_config():
    """Returns a string containing the configuration information."""
    # Deferred import keeps module import cheap and avoids a hard
    # dependency at import time.
    from bob.extension import get_config as _bob_get_config

    return _bob_get_config(__name__)
from .FaceNet import FaceNet
from .MTCNN import MTCNN
from .Extractor import Extractor
# gets sphinx autodoc done right - don't remove it
......@@ -27,10 +30,7 @@ def __appropriate__(*args):
obj.__module__ = __name__
__appropriate__(
FaceNet,
MTCNN,
)
__appropriate__(FaceNet, MTCNN, Extractor)
# gets sphinx autodoc done right - don't remove it
__all__ = [_ for _ in dir() if not _.startswith('_')]
__all__ = [_ for _ in dir() if not _.startswith("_")]
This diff is collapsed.
......@@ -7,6 +7,7 @@ Classes
-------
.. autosummary::
bob.ip.tensorflow_extractor.Extractor
bob.ip.tensorflow_extractor.FaceNet
bob.ip.tensorflow_extractor.MTCNN
......
......@@ -34,42 +34,38 @@
# administrative interventions.
from setuptools import setup, dist
dist.Distribution(dict(setup_requires=['bob.extension']))
dist.Distribution(dict(setup_requires=["bob.extension"]))
from bob.extension.utils import load_requirements, find_packages
install_requires = load_requirements()
# The only thing we do in this file is to call the setup() function with all
# parameters that define our package.
setup(
# This is the basic information about your project. Modify all this
# information before releasing code publicly.
name = 'bob.ip.tensorflow_extractor',
version = open("version.txt").read().rstrip(),
description = 'Feature extractor using tensorflow CNNs',
url = 'https://gitlab.idiap.ch/tiago.pereira/bob.ip.caffe_extractor',
license = 'BSD',
author = 'Tiago de Freitas Pereira',
author_email = 'tiago.pereira@idiap.ch',
keywords = 'bob, biometric recognition, evaluation',
name="bob.ip.tensorflow_extractor",
version=open("version.txt").read().rstrip(),
description="Feature extractor using tensorflow CNNs",
url="https://gitlab.idiap.ch/tiago.pereira/bob.ip.caffe_extractor",
license="BSD",
author="Tiago de Freitas Pereira",
author_email="tiago.pereira@idiap.ch",
keywords="bob, biometric recognition, evaluation",
# If you have a better, long description of your package, place it on the
# 'doc' directory and then hook it here
long_description = open('README.rst').read(),
long_description=open("README.rst").read(),
# This line is required for any distutils based packaging.
packages = find_packages(),
include_package_data = True,
packages=find_packages(),
include_package_data=True,
# This line defines which packages should be installed when you "install"
# this package. All packages that are mentioned here, but are not installed
# on the current system will be installed locally and only visible to the
# scripts of this package. Don't worry - You won't need administrative
# privileges when using buildout.
install_requires = install_requires,
install_requires=install_requires,
# Your project should be called something like 'bob.<foo>' or
# 'bob.<foo>.<bar>'. To implement this correctly and still get all your
# packages to be imported w/o problems, you need to implement namespaces
......@@ -80,8 +76,6 @@ setup(
# Our database packages are good examples of namespace implementations
# using several layers. You can check them out here:
# https://github.com/idiap/bob/wiki/Satellite-Packages
# This entry defines which scripts you will have inside the 'bin' directory
# once you install the package (or run 'bin/buildout'). The order of each
# entry under 'console_scripts' is like this:
......@@ -93,17 +87,16 @@ setup(
# installed under 'example/foo.py' that contains a function which
# implements the 'main()' function of particular script you want to have
# should be referred as 'example.foo:main'.
# Classifiers are important if you plan to distribute this package through
# PyPI. You can find the complete list of classifiers that are valid and
# useful here (http://pypi.python.org/pypi?%3Aaction=list_classifiers).
classifiers = [
'Framework :: Bob',
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'License :: OSI Approved :: BSD License',
'Natural Language :: English',
'Programming Language :: Python',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
classifiers=[
"Framework :: Bob",
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"License :: OSI Approved :: BSD License",
"Natural Language :: English",
"Programming Language :: Python",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
],
)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment