Commit c592bf92 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

Merge branch 'dask-pipelines' into 'master'

Make FaceNet pickalable

See merge request !16
parents 00042f9a f5ecbee7
Pipeline #39261 passed with stages
in 6 minutes and 35 seconds
......@@ -9,6 +9,7 @@ from bob.io.image import to_matplotlib
from bob.extension import rc
import bob.extension.download
import bob.io.base
import multiprocessing
logger = logging.getLogger(__name__)
......@@ -45,7 +46,7 @@ def get_model_filenames(model_dir):
ckpt_file = step_str.groups()[0]
return meta_file, ckpt_file
_semaphore = multiprocessing.Semaphore()
class FaceNet(object):
"""Wrapper for the free FaceNet variant:
https://github.com/davidsandberg/facenet
......@@ -85,9 +86,19 @@ class FaceNet(object):
super(FaceNet, self).__init__()
self.model_path = model_path
self.image_size = image_size
self.layer_name = layer_name
self._clean_unpicklables()
def _clean_unpicklables(self):
self.session = None
self.embeddings = None
self.layer_name = layer_name
self.graph = None
self.images_placeholder = None
self.embeddings = None
self.phase_train_placeholder = None
self.session = None
def _check_feature(self, img):
img = numpy.ascontiguousarray(img)
......@@ -142,17 +153,18 @@ class FaceNet(object):
logger.info("Successfully loaded the model.")
def __call__(self, img):
images = self._check_feature(img)
if self.session is None:
self.graph = tf.Graph()
self.session = tf.compat.v1.Session(graph=self.graph)
if self.embeddings is None:
self.load_model()
feed_dict = {
self.images_placeholder: images,
self.phase_train_placeholder: False,
}
features = self.session.run(self.embeddings, feed_dict=feed_dict)
with _semaphore:
images = self._check_feature(img)
if self.session is None:
self.graph = tf.Graph()
self.session = tf.compat.v1.Session(graph=self.graph)
if self.embeddings is None:
self.load_model()
feed_dict = {
self.images_placeholder: images,
self.phase_train_placeholder: False,
}
features = self.session.run(self.embeddings, feed_dict=feed_dict)
return features.flatten()
@staticmethod
......@@ -176,3 +188,13 @@ class FaceNet(object):
)
return model_path
def __setstate__(self, d):
# Handling unpicklable objects
self.__dict__ = d
def __getstate__(self):
# Handling unpicklable objects
with _semaphore:
self._clean_unpicklables()
return self.__dict__
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment