Commit a34446f5 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

Improve graph and session handling in facenet class

parent 2aa5ca35
...@@ -115,23 +115,24 @@ class FaceNet(object): ...@@ -115,23 +115,24 @@ class FaceNet(object):
# code from https://github.com/davidsandberg/facenet # code from https://github.com/davidsandberg/facenet
model_exp = os.path.expanduser(self.model_path) model_exp = os.path.expanduser(self.model_path)
if (os.path.isfile(model_exp)): with self.graph.as_default():
logger.info('Model filename: %s' % model_exp) if (os.path.isfile(model_exp)):
with tf.gfile.FastGFile(model_exp, 'rb') as f: logger.info('Model filename: %s' % model_exp)
graph_def = tf.GraphDef() with tf.gfile.FastGFile(model_exp, 'rb') as f:
graph_def.ParseFromString(f.read()) graph_def = tf.GraphDef()
tf.import_graph_def(graph_def, name='') graph_def.ParseFromString(f.read())
else: tf.import_graph_def(graph_def, name='')
logger.info('Model directory: %s' % model_exp) else:
meta_file, ckpt_file = get_model_filenames(model_exp) logger.info('Model directory: %s' % model_exp)
meta_file, ckpt_file = get_model_filenames(model_exp)
logger.info('Metagraph file: %s' % meta_file)
logger.info('Checkpoint file: %s' % ckpt_file) logger.info('Metagraph file: %s' % meta_file)
logger.info('Checkpoint file: %s' % ckpt_file)
saver = tf.train.import_meta_graph(
os.path.join(model_exp, meta_file)) saver = tf.train.import_meta_graph(
saver.restore(tf.get_default_session(), os.path.join(model_exp, meta_file))
os.path.join(model_exp, ckpt_file)) saver.restore(self.session,
os.path.join(model_exp, ckpt_file))
# Get input and output tensors # Get input and output tensors
self.images_placeholder = self.graph.get_tensor_by_name("input:0") self.images_placeholder = self.graph.get_tensor_by_name("input:0")
self.embeddings = self.graph.get_tensor_by_name(self.layer_name) self.embeddings = self.graph.get_tensor_by_name(self.layer_name)
...@@ -142,8 +143,8 @@ class FaceNet(object): ...@@ -142,8 +143,8 @@ class FaceNet(object):
def __call__(self, img): def __call__(self, img):
images = self._check_feature(img) images = self._check_feature(img)
if self.session is None: if self.session is None:
self.session = tf.InteractiveSession() self.graph = tf.Graph()
self.graph = tf.get_default_graph() self.session = tf.Session(graph=self.graph)
if self.embeddings is None: if self.embeddings is None:
self.load_model() self.load_model()
feed_dict = {self.images_placeholder: images, feed_dict = {self.images_placeholder: images,
...@@ -152,9 +153,6 @@ class FaceNet(object): ...@@ -152,9 +153,6 @@ class FaceNet(object):
self.embeddings, feed_dict=feed_dict) self.embeddings, feed_dict=feed_dict)
return features.flatten() return features.flatten()
def __del__(self):
tf.reset_default_graph()
@staticmethod @staticmethod
def get_rcvariable(): def get_rcvariable():
""" """
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment