Skip to content
Snippets Groups Projects
Commit ae1afa0d authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Integrated Arcface from InsightFace

parent d40ff779
No related branches found
No related tags found
1 merge request!79Arcface from InsightFace
Pipeline #45370 failed
from bob.bio.face.embeddings import ArcFaceInsightFace
from bob.bio.face.config.baseline.helpers import embedding_transformer_112x112
from bob.bio.base.pipelines.vanilla_biometrics import (
Distance,
VanillaBiometricsPipeline,
)
if "database" in locals():
annotation_type = database.annotation_type
fixed_positions = database.fixed_positions
else:
annotation_type = None
fixed_positions = None
def load(annotation_type, fixed_positions=None):
transformer = embedding_transformer_112x112(
ArcFaceInsightFace(), annotation_type, fixed_positions, color_channel="rgb"
)
algorithm = Distance()
return VanillaBiometricsPipeline(transformer, algorithm)
pipeline = load(annotation_type, fixed_positions)
transformer = pipeline.transformer
\ No newline at end of file
......@@ -34,6 +34,8 @@ from .tf2_inception_resnet import (
FaceNetSanderberg_20170512_110547
)
from .mxnet_models import ArcFaceInsightFace
# gets sphinx autodoc done right - don't remove it
def __appropriate__(*args):
"""Says object was actually declared here, and not in the import module.
......@@ -56,6 +58,7 @@ __appropriate__(
InceptionResnetv1_MsCeleb_CenterLoss_2018,
InceptionResnetv2_Casia_CenterLoss_2018,
InceptionResnetv1_Casia_CenterLoss_2018,
FaceNetSanderberg_20170512_110547
FaceNetSanderberg_20170512_110547,
ArcFaceInsightFace
)
__all__ = [_ for _ in dir() if not _.startswith("_")]
"""
Load and predict using checkpoints based on mxnet
"""
from sklearn.base import TransformerMixin, BaseEstimator
from sklearn.utils import check_array
import numpy as np
from bob.bio.face.embeddings import download_model
import pkg_resources
import os
from bob.extension import rc
class ArcFaceInsightFace(TransformerMixin, BaseEstimator):
"""
ArcFace from Insight Face.
Model and source code taken from the repository
https://github.com/deepinsight/insightface/blob/master/python-package/insightface/model_zoo/face_recognition.py
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.model = None
internal_path = pkg_resources.resource_filename(
__name__, os.path.join("data", "arcface_insightface"),
)
checkpoint_path = (
internal_path
if rc["bob.bio.face.models.ArcFaceInsightFace"] is None
else rc["bob.bio.face.models.ArcFaceInsightFace"]
)
urls = [
"https://www.idiap.ch/software/bob/data/bob/bob.bio.face/master/mxnet/arcface_r100_v1_mxnet.tar.gz"
]
download_model(checkpoint_path, urls, "arcface_r100_v1_mxnet.tar.gz")
self.checkpoint_path = checkpoint_path
def load_model(self):
import mxnet as mx
sym, arg_params, aux_params = mx.model.load_checkpoint(
os.path.join(self.checkpoint_path, "model"), 0
)
all_layers = sym.get_internals()
sym = all_layers["fc1_output"]
# LOADING CHECKPOINT
ctx = mx.cpu()
model = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
data_shape = (1, 3, 112, 112)
model.bind(data_shapes=[("data", data_shape)])
model.set_params(arg_params, aux_params)
# warmup
data = mx.nd.zeros(shape=data_shape)
db = mx.io.DataBatch(data=(data,))
model.forward(db, is_train=False)
embedding = model.get_outputs()[0].asnumpy()
self.model = model
def transform(self, X):
import mxnet as mx
if self.model is None:
self.load_model()
X = check_array(X, allow_nd=True)
X = mx.nd.array(X)
db = mx.io.DataBatch(data=(X,))
self.model.forward(db, is_train=False)
return self.model.get_outputs()[0].asnumpy()
def __setstate__(self, d):
self.__dict__ = d
def __getstate__(self):
# Handling unpicklable objects
d = self.__dict__.copy()
d["model"] = None
return d
def _more_tags(self):
return {"stateless": True, "requires_fit": False}
......@@ -49,7 +49,7 @@ class InceptionResnet(TransformerMixin, BaseEstimator):
if self.preprocessor is not None:
X = self.preprocessor(tf.cast(X, "float32"))
prelogits = self.model(X, training=False)
prelogits = self.model.predict_on_batch(X)
embeddings = tf.math.l2_normalize(prelogits, axis=-1)
return embeddings
......@@ -74,7 +74,7 @@ class InceptionResnet(TransformerMixin, BaseEstimator):
def __getstate__(self):
# Handling unpicklable objects
d = self.__dict__.copy()
d["model"] = None
d["model"] = None
return d
def _more_tags(self):
......
......@@ -125,14 +125,8 @@ def test_inception_resnetv1_msceleb():
def test_inception_resnetv1_casiawebface():
run_baseline("inception-resnetv1-casiawebface")
"""
def test_arcface_insight_tf():
import tensorflow as tf
tf.compat.v1.reset_default_graph()
run_baseline("arcface-insight-tf")
"""
def test_arcface_insightface():
run_baseline("arcface-insightface")
def test_gabor_graph():
run_baseline("gabor_graph")
......
......@@ -94,6 +94,22 @@ def test_facenet_sanderberg():
assert output.size == 128, output.shape
def test_arcface_insight_face():
from bob.bio.face.embeddings import ArcFaceInsightFace
transformer = ArcFaceInsightFace()
data = np.random.rand(3, 112, 112)*255
data = data.astype("uint8")
output = transformer.transform([data])
assert output.size == 512, output.shape
# Sample Batch
sample = Sample(data)
transformer_sample = wrap(["sample"], transformer)
output = [s.data for s in transformer_sample.transform([sample])][0]
assert output.size == 512, output.shape
"""
def test_arface_insight_tf():
import tensorflow as tf
......
......@@ -139,7 +139,7 @@ setup(
'inception-resnetv2-casiawebface = bob.bio.face.config.baseline.inception_resnetv2_casiawebface:transformer',
'inception-resnetv1-msceleb = bob.bio.face.config.baseline.inception_resnetv1_msceleb:transformer',
'inception-resnetv2-msceleb = bob.bio.face.config.baseline.inception_resnetv2_msceleb:transformer',
'arcface-insight-tf = bob.bio.face.config.baseline.arcface_insight_tf:transformer',
'arcface-insightface = bob.bio.face.config.baseline.arcface_insightface:transformer',
'gabor-graph = bob.bio.face.config.baseline.gabor_graph:transformer',
'lgbphs = bob.bio.face.config.baseline.lgbphs:transformer',
'dummy = bob.bio.face.config.baseline.dummy:transformer',
......@@ -153,7 +153,7 @@ setup(
'inception-resnetv1-msceleb = bob.bio.face.config.baseline.inception_resnetv1_msceleb:pipeline',
'inception-resnetv2-msceleb = bob.bio.face.config.baseline.inception_resnetv2_msceleb:pipeline',
'gabor_graph = bob.bio.face.config.baseline.gabor_graph:pipeline',
'arcface-insight-tf = bob.bio.face.config.baseline.arcface_insight_tf:pipeline',
'arcface-insightface = bob.bio.face.config.baseline.arcface_insightface:pipeline',
'lgbphs = bob.bio.face.config.baseline.lgbphs:pipeline',
'lda = bob.bio.face.config.baseline.lda:pipeline',
'dummy = bob.bio.face.config.baseline.dummy:pipeline',
......@@ -166,7 +166,7 @@ setup(
'inception-resnetv1-msceleb = bob.bio.face.config.baseline.inception_resnetv1_msceleb',
'inception-resnetv2-msceleb = bob.bio.face.config.baseline.inception_resnetv2_msceleb',
'gabor_graph = bob.bio.face.config.baseline.gabor_graph',
'arcface-insight-tf = bob.bio.face.config.baseline.arcface_insight_tf',
'arcface-insightface = bob.bio.face.config.baseline.arcface_insightface',
'lgbphs = bob.bio.face.config.baseline.lgbphs',
'lda = bob.bio.face.config.baseline.lda',
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment