diff --git a/bob/ip/tensorflow_extractor/FaceNet.py b/bob/ip/tensorflow_extractor/FaceNet.py index 169573c43d1d164b83a052f9e6526c6a45a8cbfe..510b00a6924e6b80e2d438f4c9f593efa707765a 100644 --- a/bob/ip/tensorflow_extractor/FaceNet.py +++ b/bob/ip/tensorflow_extractor/FaceNet.py @@ -73,15 +73,18 @@ class FaceNet(object): """ - def __init__(self, - model_path=rc["bob.ip.tensorflow_extractor.facenet_modelpath"], - image_size=160, - **kwargs): + def __init__( + self, + model_path=rc["bob.ip.tensorflow_extractor.facenet_modelpath"], + image_size=160, + layer_name='embeddings:0', + **kwargs): super(FaceNet, self).__init__() self.model_path = model_path self.image_size = image_size self.session = None self.embeddings = None + self.layer_name = layer_name def _check_feature(self, img): img = numpy.ascontiguousarray(img) @@ -99,7 +102,7 @@ class FaceNet(object): if not os.path.exists(self.model_path): bob.io.base.create_directories_safe(FaceNet.get_modelpath()) zip_file = os.path.join(FaceNet.get_modelpath(), - "20170512-110547.zip") + "20170512-110547.zip") urls = [ # This is a private link at Idiap to save bandwidth. "http://beatubulatest.lab.idiap.ch/private/wheels/gitlab/" @@ -107,9 +110,9 @@ class FaceNet(object): # this works for everybody "https://drive.google.com/uc?export=download&id=" "0B5MzpY9kBtDVZ2RpVDYwWmxoSUk", - ] + ] bob.extension.download.download_and_unzip(urls, zip_file) - + # code from https://github.com/davidsandberg/facenet model_exp = os.path.expanduser(self.model_path) if (os.path.isfile(model_exp)): @@ -131,7 +134,7 @@ class FaceNet(object): os.path.join(model_exp, ckpt_file)) # Get input and output tensors self.images_placeholder = self.graph.get_tensor_by_name("input:0") - self.embeddings = self.graph.get_tensor_by_name("embeddings:0") + self.embeddings = self.graph.get_tensor_by_name(self.layer_name) self.phase_train_placeholder = self.graph.get_tensor_by_name( "phase_train:0") logger.info("Successfully loaded the model.") @@ -156,26 +159,26 @@ class FaceNet(object): def get_rcvariable(): """ Variable name used in the Bob Global Configuration System - https://www.idiap.ch/software/bob/docs/bob/bob.extension/stable/rc.html#global-configuration-system + https://www.idiap.ch/software/bob/docs/bob/bob.extension/stable/rc.html """ return "bob.ip.tensorflow_extractor.facenet_modelpath" @staticmethod def get_modelpath(): """ - Get default model path. + Get default model path. First we try the to search this path via Global Configuration System. - If we can not find it, we set the path in the directory `<project>/data` + If we can not find it, we set the path in the directory + `<project>/data` """ - + # Priority to the RC path model_path = rc[FaceNet.get_rcvariable()] if model_path is None: import pkg_resources - model_path = pkg_resources.resource_filename(__name__, - 'data/FaceNet/20170512-110547') + model_path = pkg_resources.resource_filename( + __name__, 'data/FaceNet/20170512-110547') return model_path -