Skip to content
Snippets Groups Projects
Commit a351280b authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

[facenet] Allow for layer name to be provided as well

parent dca9c583
No related branches found
No related tags found
1 merge request!10[facenet] Allow for layer name to be provided as well
Pipeline #
...@@ -73,15 +73,18 @@ class FaceNet(object): ...@@ -73,15 +73,18 @@ class FaceNet(object):
""" """
def __init__(self, def __init__(
model_path=rc["bob.ip.tensorflow_extractor.facenet_modelpath"], self,
image_size=160, model_path=rc["bob.ip.tensorflow_extractor.facenet_modelpath"],
**kwargs): image_size=160,
layer_name='embeddings:0',
**kwargs):
super(FaceNet, self).__init__() super(FaceNet, self).__init__()
self.model_path = model_path self.model_path = model_path
self.image_size = image_size self.image_size = image_size
self.session = None self.session = None
self.embeddings = None self.embeddings = None
self.layer_name = layer_name
def _check_feature(self, img): def _check_feature(self, img):
img = numpy.ascontiguousarray(img) img = numpy.ascontiguousarray(img)
...@@ -99,7 +102,7 @@ class FaceNet(object): ...@@ -99,7 +102,7 @@ class FaceNet(object):
if not os.path.exists(self.model_path): if not os.path.exists(self.model_path):
bob.io.base.create_directories_safe(FaceNet.get_modelpath()) bob.io.base.create_directories_safe(FaceNet.get_modelpath())
zip_file = os.path.join(FaceNet.get_modelpath(), zip_file = os.path.join(FaceNet.get_modelpath(),
"20170512-110547.zip") "20170512-110547.zip")
urls = [ urls = [
# This is a private link at Idiap to save bandwidth. # This is a private link at Idiap to save bandwidth.
"http://beatubulatest.lab.idiap.ch/private/wheels/gitlab/" "http://beatubulatest.lab.idiap.ch/private/wheels/gitlab/"
...@@ -107,9 +110,9 @@ class FaceNet(object): ...@@ -107,9 +110,9 @@ class FaceNet(object):
# this works for everybody # this works for everybody
"https://drive.google.com/uc?export=download&id=" "https://drive.google.com/uc?export=download&id="
"0B5MzpY9kBtDVZ2RpVDYwWmxoSUk", "0B5MzpY9kBtDVZ2RpVDYwWmxoSUk",
] ]
bob.extension.download.download_and_unzip(urls, zip_file) bob.extension.download.download_and_unzip(urls, zip_file)
# code from https://github.com/davidsandberg/facenet # code from https://github.com/davidsandberg/facenet
model_exp = os.path.expanduser(self.model_path) model_exp = os.path.expanduser(self.model_path)
if (os.path.isfile(model_exp)): if (os.path.isfile(model_exp)):
...@@ -131,7 +134,7 @@ class FaceNet(object): ...@@ -131,7 +134,7 @@ class FaceNet(object):
os.path.join(model_exp, ckpt_file)) os.path.join(model_exp, ckpt_file))
# Get input and output tensors # Get input and output tensors
self.images_placeholder = self.graph.get_tensor_by_name("input:0") self.images_placeholder = self.graph.get_tensor_by_name("input:0")
self.embeddings = self.graph.get_tensor_by_name("embeddings:0") self.embeddings = self.graph.get_tensor_by_name(self.layer_name)
self.phase_train_placeholder = self.graph.get_tensor_by_name( self.phase_train_placeholder = self.graph.get_tensor_by_name(
"phase_train:0") "phase_train:0")
logger.info("Successfully loaded the model.") logger.info("Successfully loaded the model.")
...@@ -156,26 +159,26 @@ class FaceNet(object): ...@@ -156,26 +159,26 @@ class FaceNet(object):
def get_rcvariable(): def get_rcvariable():
""" """
Variable name used in the Bob Global Configuration System Variable name used in the Bob Global Configuration System
https://www.idiap.ch/software/bob/docs/bob/bob.extension/stable/rc.html#global-configuration-system https://www.idiap.ch/software/bob/docs/bob/bob.extension/stable/rc.html
""" """
return "bob.ip.tensorflow_extractor.facenet_modelpath" return "bob.ip.tensorflow_extractor.facenet_modelpath"
@staticmethod @staticmethod
def get_modelpath(): def get_modelpath():
""" """
Get default model path. Get default model path.
First we try the to search this path via Global Configuration System. First we try the to search this path via Global Configuration System.
If we can not find it, we set the path in the directory `<project>/data` If we can not find it, we set the path in the directory
`<project>/data`
""" """
# Priority to the RC path # Priority to the RC path
model_path = rc[FaceNet.get_rcvariable()] model_path = rc[FaceNet.get_rcvariable()]
if model_path is None: if model_path is None:
import pkg_resources import pkg_resources
model_path = pkg_resources.resource_filename(__name__, model_path = pkg_resources.resource_filename(
'data/FaceNet/20170512-110547') __name__, 'data/FaceNet/20170512-110547')
return model_path return model_path
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment