Commit a351280b authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

[facenet] Allow for layer name to be provided as well

parent dca9c583
Pipeline #22824 passed with stage
in 26 minutes and 51 seconds
......@@ -73,15 +73,18 @@ class FaceNet(object):
"""
def __init__(self,
model_path=rc["bob.ip.tensorflow_extractor.facenet_modelpath"],
image_size=160,
**kwargs):
def __init__(
self,
model_path=rc["bob.ip.tensorflow_extractor.facenet_modelpath"],
image_size=160,
layer_name='embeddings:0',
**kwargs):
super(FaceNet, self).__init__()
self.model_path = model_path
self.image_size = image_size
self.session = None
self.embeddings = None
self.layer_name = layer_name
def _check_feature(self, img):
img = numpy.ascontiguousarray(img)
......@@ -99,7 +102,7 @@ class FaceNet(object):
if not os.path.exists(self.model_path):
bob.io.base.create_directories_safe(FaceNet.get_modelpath())
zip_file = os.path.join(FaceNet.get_modelpath(),
"20170512-110547.zip")
"20170512-110547.zip")
urls = [
# This is a private link at Idiap to save bandwidth.
"http://beatubulatest.lab.idiap.ch/private/wheels/gitlab/"
......@@ -107,9 +110,9 @@ class FaceNet(object):
# this works for everybody
"https://drive.google.com/uc?export=download&id="
"0B5MzpY9kBtDVZ2RpVDYwWmxoSUk",
]
]
bob.extension.download.download_and_unzip(urls, zip_file)
# code from https://github.com/davidsandberg/facenet
model_exp = os.path.expanduser(self.model_path)
if (os.path.isfile(model_exp)):
......@@ -131,7 +134,7 @@ class FaceNet(object):
os.path.join(model_exp, ckpt_file))
# Get input and output tensors
self.images_placeholder = self.graph.get_tensor_by_name("input:0")
self.embeddings = self.graph.get_tensor_by_name("embeddings:0")
self.embeddings = self.graph.get_tensor_by_name(self.layer_name)
self.phase_train_placeholder = self.graph.get_tensor_by_name(
"phase_train:0")
logger.info("Successfully loaded the model.")
......@@ -156,26 +159,26 @@ class FaceNet(object):
def get_rcvariable():
"""
Variable name used in the Bob Global Configuration System
https://www.idiap.ch/software/bob/docs/bob/bob.extension/stable/rc.html#global-configuration-system
https://www.idiap.ch/software/bob/docs/bob/bob.extension/stable/rc.html
"""
return "bob.ip.tensorflow_extractor.facenet_modelpath"
@staticmethod
def get_modelpath():
"""
Get default model path.
Get default model path.
First we try the to search this path via Global Configuration System.
If we can not find it, we set the path in the directory `<project>/data`
If we can not find it, we set the path in the directory
`<project>/data`
"""
# Priority to the RC path
model_path = rc[FaceNet.get_rcvariable()]
if model_path is None:
import pkg_resources
model_path = pkg_resources.resource_filename(__name__,
'data/FaceNet/20170512-110547')
model_path = pkg_resources.resource_filename(
__name__, 'data/FaceNet/20170512-110547')
return model_path
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment