Commit ff821c8a authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Merge branch 'facenet' into 'master'

Improve graph and session handling in facenet class

See merge request !13
parents 2aa5ca35 a34446f5
Pipeline #30799 passed with stages
in 17 minutes and 48 seconds
......@@ -115,23 +115,24 @@ class FaceNet(object):
# code from https://github.com/davidsandberg/facenet
model_exp = os.path.expanduser(self.model_path)
if (os.path.isfile(model_exp)):
logger.info('Model filename: %s' % model_exp)
with tf.gfile.FastGFile(model_exp, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
else:
logger.info('Model directory: %s' % model_exp)
meta_file, ckpt_file = get_model_filenames(model_exp)
logger.info('Metagraph file: %s' % meta_file)
logger.info('Checkpoint file: %s' % ckpt_file)
saver = tf.train.import_meta_graph(
os.path.join(model_exp, meta_file))
saver.restore(tf.get_default_session(),
os.path.join(model_exp, ckpt_file))
with self.graph.as_default():
if (os.path.isfile(model_exp)):
logger.info('Model filename: %s' % model_exp)
with tf.gfile.FastGFile(model_exp, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
else:
logger.info('Model directory: %s' % model_exp)
meta_file, ckpt_file = get_model_filenames(model_exp)
logger.info('Metagraph file: %s' % meta_file)
logger.info('Checkpoint file: %s' % ckpt_file)
saver = tf.train.import_meta_graph(
os.path.join(model_exp, meta_file))
saver.restore(self.session,
os.path.join(model_exp, ckpt_file))
# Get input and output tensors
self.images_placeholder = self.graph.get_tensor_by_name("input:0")
self.embeddings = self.graph.get_tensor_by_name(self.layer_name)
......@@ -142,8 +143,8 @@ class FaceNet(object):
def __call__(self, img):
images = self._check_feature(img)
if self.session is None:
self.session = tf.InteractiveSession()
self.graph = tf.get_default_graph()
self.graph = tf.Graph()
self.session = tf.Session(graph=self.graph)
if self.embeddings is None:
self.load_model()
feed_dict = {self.images_placeholder: images,
......@@ -152,9 +153,6 @@ class FaceNet(object):
self.embeddings, feed_dict=feed_dict)
return features.flatten()
def __del__(self):
tf.reset_default_graph()
@staticmethod
def get_rcvariable():
"""
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment