Skip to content
Snippets Groups Projects
Commit 3fecd074 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Attempt to make the __call__/serialization work

parent bc881775
Branches
No related tags found
1 merge request!17WIP: Added ArcFace model
Pipeline #39969 failed
...@@ -9,7 +9,6 @@ from bob.io.image import to_matplotlib ...@@ -9,7 +9,6 @@ from bob.io.image import to_matplotlib
from bob.extension import rc from bob.extension import rc
import bob.extension.download import bob.extension.download
import bob.io.base import bob.io.base
import multiprocessing
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -46,7 +45,7 @@ def get_model_filenames(model_dir): ...@@ -46,7 +45,7 @@ def get_model_filenames(model_dir):
ckpt_file = step_str.groups()[0] ckpt_file = step_str.groups()[0]
return meta_file, ckpt_file return meta_file, ckpt_file
_semaphore = multiprocessing.Semaphore()
class FaceNet(object): class FaceNet(object):
"""Wrapper for the free FaceNet variant: """Wrapper for the free FaceNet variant:
https://github.com/davidsandberg/facenet https://github.com/davidsandberg/facenet
...@@ -86,19 +85,16 @@ class FaceNet(object): ...@@ -86,19 +85,16 @@ class FaceNet(object):
super(FaceNet, self).__init__() super(FaceNet, self).__init__()
self.model_path = model_path self.model_path = model_path
self.image_size = image_size self.image_size = image_size
self.layer_name = layer_name self.layer_name = layer_name
self.loaded = False
self._clean_unpicklables() self._clean_unpicklables()
def _clean_unpicklables(self): def _clean_unpicklables(self):
self.session = None self.session = None
self.embeddings = None self.embeddings = None
self.graph = None self.graph = None
self.images_placeholder = None self.images_placeholder = None
self.embeddings = None
self.phase_train_placeholder = None self.phase_train_placeholder = None
self.session = None
def _check_feature(self, img): def _check_feature(self, img):
img = numpy.ascontiguousarray(img) img = numpy.ascontiguousarray(img)
...@@ -111,6 +107,9 @@ class FaceNet(object): ...@@ -111,6 +107,9 @@ class FaceNet(object):
return img[None, ...] return img[None, ...]
def load_model(self): def load_model(self):
self.graph = tf.Graph()
self.session = tf.compat.v1.Session(graph=self.graph)
if self.model_path is None: if self.model_path is None:
self.model_path = self.get_modelpath() self.model_path = self.get_modelpath()
if not os.path.exists(self.model_path): if not os.path.exists(self.model_path):
...@@ -151,20 +150,19 @@ class FaceNet(object): ...@@ -151,20 +150,19 @@ class FaceNet(object):
self.embeddings = self.graph.get_tensor_by_name(self.layer_name) self.embeddings = self.graph.get_tensor_by_name(self.layer_name)
self.phase_train_placeholder = self.graph.get_tensor_by_name("phase_train:0") self.phase_train_placeholder = self.graph.get_tensor_by_name("phase_train:0")
logger.info("Successfully loaded the model.") logger.info("Successfully loaded the model.")
self.loaded = True
def __call__(self, img): def __call__(self, img):
with _semaphore: # with _semaphore:
images = self._check_feature(img) images = self._check_feature(img)
if self.session is None: if not self.loaded:
self.graph = tf.Graph() self.load_model()
self.session = tf.compat.v1.Session(graph=self.graph)
if self.embeddings is None: feed_dict = {
self.load_model() self.images_placeholder: images,
feed_dict = { self.phase_train_placeholder: False,
self.images_placeholder: images, }
self.phase_train_placeholder: False, features = self.session.run(self.embeddings, feed_dict=feed_dict)
}
features = self.session.run(self.embeddings, feed_dict=feed_dict)
return features.flatten() return features.flatten()
@staticmethod @staticmethod
...@@ -192,9 +190,18 @@ class FaceNet(object): ...@@ -192,9 +190,18 @@ class FaceNet(object):
def __setstate__(self, d): def __setstate__(self, d):
# Handling unpicklable objects # Handling unpicklable objects
self.__dict__ = d self.__dict__ = d
self.load_model()
def __getstate__(self): def __getstate__(self):
# Handling unpicklable objects # Handling unpicklable objects
with _semaphore: # with _semaphore:
self._clean_unpicklables() # self._clean_unpicklables()
return self.__dict__ self.loaded = False
d = self.__dict__
d.pop("session") if "session" in d else None
d.pop("embeddings") if "embeddings" in d else None
d.pop("graph") if "graph" in d else None
d.pop("images_placeholder") if "images_placeholder" in d else None
d.pop("phase_train_placeholder") if "phase_train_placeholder" in d else None
return d
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment