Commit a34446f5 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

Improve graph and session handling in facenet class

parent 2aa5ca35
......@@ -115,23 +115,24 @@ class FaceNet(object):
# code from https://github.com/davidsandberg/facenet
model_exp = os.path.expanduser(self.model_path)
if (os.path.isfile(model_exp)):
logger.info('Model filename: %s' % model_exp)
with tf.gfile.FastGFile(model_exp, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
else:
logger.info('Model directory: %s' % model_exp)
meta_file, ckpt_file = get_model_filenames(model_exp)
logger.info('Metagraph file: %s' % meta_file)
logger.info('Checkpoint file: %s' % ckpt_file)
saver = tf.train.import_meta_graph(
os.path.join(model_exp, meta_file))
saver.restore(tf.get_default_session(),
os.path.join(model_exp, ckpt_file))
with self.graph.as_default():
if (os.path.isfile(model_exp)):
logger.info('Model filename: %s' % model_exp)
with tf.gfile.FastGFile(model_exp, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
else:
logger.info('Model directory: %s' % model_exp)
meta_file, ckpt_file = get_model_filenames(model_exp)
logger.info('Metagraph file: %s' % meta_file)
logger.info('Checkpoint file: %s' % ckpt_file)
saver = tf.train.import_meta_graph(
os.path.join(model_exp, meta_file))
saver.restore(self.session,
os.path.join(model_exp, ckpt_file))
# Get input and output tensors
self.images_placeholder = self.graph.get_tensor_by_name("input:0")
self.embeddings = self.graph.get_tensor_by_name(self.layer_name)
......@@ -142,8 +143,8 @@ class FaceNet(object):
def __call__(self, img):
images = self._check_feature(img)
if self.session is None:
self.session = tf.InteractiveSession()
self.graph = tf.get_default_graph()
self.graph = tf.Graph()
self.session = tf.Session(graph=self.graph)
if self.embeddings is None:
self.load_model()
feed_dict = {self.images_placeholder: images,
......@@ -152,9 +153,6 @@ class FaceNet(object):
self.embeddings, feed_dict=feed_dict)
return features.flatten()
def __del__(self):
tf.reset_default_graph()
@staticmethod
def get_rcvariable():
"""
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment