diff --git a/bob/ip/tensorflow_extractor/FaceNet.py b/bob/ip/tensorflow_extractor/FaceNet.py index 510b00a6924e6b80e2d438f4c9f593efa707765a..4b5a05bfbf001c57f3be73c864969f50c4354270 100644 --- a/bob/ip/tensorflow_extractor/FaceNet.py +++ b/bob/ip/tensorflow_extractor/FaceNet.py @@ -115,23 +115,24 @@ class FaceNet(object): # code from https://github.com/davidsandberg/facenet model_exp = os.path.expanduser(self.model_path) - if (os.path.isfile(model_exp)): - logger.info('Model filename: %s' % model_exp) - with tf.gfile.FastGFile(model_exp, 'rb') as f: - graph_def = tf.GraphDef() - graph_def.ParseFromString(f.read()) - tf.import_graph_def(graph_def, name='') - else: - logger.info('Model directory: %s' % model_exp) - meta_file, ckpt_file = get_model_filenames(model_exp) - - logger.info('Metagraph file: %s' % meta_file) - logger.info('Checkpoint file: %s' % ckpt_file) - - saver = tf.train.import_meta_graph( - os.path.join(model_exp, meta_file)) - saver.restore(tf.get_default_session(), - os.path.join(model_exp, ckpt_file)) + with self.graph.as_default(): + if (os.path.isfile(model_exp)): + logger.info('Model filename: %s' % model_exp) + with tf.gfile.FastGFile(model_exp, 'rb') as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + tf.import_graph_def(graph_def, name='') + else: + logger.info('Model directory: %s' % model_exp) + meta_file, ckpt_file = get_model_filenames(model_exp) + + logger.info('Metagraph file: %s' % meta_file) + logger.info('Checkpoint file: %s' % ckpt_file) + + saver = tf.train.import_meta_graph( + os.path.join(model_exp, meta_file)) + saver.restore(self.session, + os.path.join(model_exp, ckpt_file)) # Get input and output tensors self.images_placeholder = self.graph.get_tensor_by_name("input:0") self.embeddings = self.graph.get_tensor_by_name(self.layer_name) @@ -142,8 +143,8 @@ class FaceNet(object): def __call__(self, img): images = self._check_feature(img) if self.session is None: - self.session = tf.InteractiveSession() - self.graph = tf.get_default_graph() + self.graph = tf.Graph() + self.session = tf.Session(graph=self.graph) if self.embeddings is None: self.load_model() feed_dict = {self.images_placeholder: images, @@ -152,9 +153,6 @@ class FaceNet(object): self.embeddings, feed_dict=feed_dict) return features.flatten() - def __del__(self): - tf.reset_default_graph() - @staticmethod def get_rcvariable(): """