diff --git a/bob/ip/tensorflow_extractor/FaceNet.py b/bob/ip/tensorflow_extractor/FaceNet.py index d7a58e591fbe788e41df3ca9d0cdee1a90ad8ec6..316b0f0854e85ae810de0bed2736141adc2f49d8 100644 --- a/bob/ip/tensorflow_extractor/FaceNet.py +++ b/bob/ip/tensorflow_extractor/FaceNet.py @@ -9,6 +9,7 @@ from bob.io.image import to_matplotlib from bob.extension import rc import bob.extension.download import bob.io.base +import multiprocessing logger = logging.getLogger(__name__) @@ -45,7 +46,7 @@ def get_model_filenames(model_dir): ckpt_file = step_str.groups()[0] return meta_file, ckpt_file - +_semaphore = multiprocessing.Semaphore() class FaceNet(object): """Wrapper for the free FaceNet variant: https://github.com/davidsandberg/facenet @@ -85,9 +86,19 @@ class FaceNet(object): super(FaceNet, self).__init__() self.model_path = model_path self.image_size = image_size + self.layer_name = layer_name + self._clean_unpicklables() + + + def _clean_unpicklables(self): self.session = None self.embeddings = None - self.layer_name = layer_name + self.graph = None + self.images_placeholder = None + self.embeddings = None + self.phase_train_placeholder = None + self.session = None + def _check_feature(self, img): img = numpy.ascontiguousarray(img) @@ -142,17 +153,18 @@ class FaceNet(object): logger.info("Successfully loaded the model.") def __call__(self, img): - images = self._check_feature(img) - if self.session is None: - self.graph = tf.Graph() - self.session = tf.compat.v1.Session(graph=self.graph) - if self.embeddings is None: - self.load_model() - feed_dict = { - self.images_placeholder: images, - self.phase_train_placeholder: False, - } - features = self.session.run(self.embeddings, feed_dict=feed_dict) + with _semaphore: + images = self._check_feature(img) + if self.session is None: + self.graph = tf.Graph() + self.session = tf.compat.v1.Session(graph=self.graph) + if self.embeddings is None: + self.load_model() + feed_dict = { + self.images_placeholder: images, + self.phase_train_placeholder: False, + } + features = self.session.run(self.embeddings, feed_dict=feed_dict) return features.flatten() @staticmethod @@ -176,3 +188,13 @@ class FaceNet(object): ) return model_path + + def __setstate__(self, d): + # Handling unpicklable objects + self.__dict__ = d + + def __getstate__(self): + # Handling unpicklable objects + with _semaphore: + self._clean_unpicklables() + return self.__dict__