Commit 105cf73b authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Finished the transformers

parent 6f0f913a
Pipeline #40240 failed with stage
in 4 minutes and 31 seconds
......@@ -22,7 +22,7 @@ def test_facenet_pipeline():
#transformed_sample = transformer.transform([fake_sample])[0].data
#import ipdb; ipdb.set_trace()
#import ipdb; ipdb.set_trace()
transformed_sample = transformer.transform([fake_sample])[0]
assert transformed_sample.data.size == 160
......
......@@ -31,6 +31,13 @@ def test_idiap_inceptionv2_msceleb():
output = transformer.transform(data)
assert output.size == 128, output.shape
# Sample Batch
sample = Sample(data)
transformer_sample = wrap(["sample"], transformer)
output = [s.data for s in transformer_sample.transform([sample])][0]
assert output.size == 128, output.shape
def test_idiap_inceptionv2_casia():
from bob.bio.face.transformers import InceptionResnetv2_CasiaWebFace
......@@ -42,6 +49,14 @@ def test_idiap_inceptionv2_casia():
assert output.size == 128, output.shape
# Sample Batch
sample = Sample(data)
transformer_sample = wrap(["sample"], transformer)
output = [s.data for s in transformer_sample.transform([sample])][0]
assert output.size == 128, output.shape
def test_idiap_inceptionv1_msceleb():
from bob.bio.face.transformers import InceptionResnetv1_MsCeleb
......@@ -51,6 +66,13 @@ def test_idiap_inceptionv1_msceleb():
output = transformer.transform(data)
assert output.size == 128, output.shape
# Sample Batch
sample = Sample(data)
transformer_sample = wrap(["sample"], transformer)
output = [s.data for s in transformer_sample.transform([sample])][0]
assert output.size == 128, output.shape
def test_idiap_inceptionv1_casia():
from bob.bio.face.transformers import InceptionResnetv1_CasiaWebFace
......@@ -61,6 +83,13 @@ def test_idiap_inceptionv1_casia():
output = transformer.transform(data)
assert output.size == 128, output.shape
# Sample Batch
sample = Sample(data)
transformer_sample = wrap(["sample"], transformer)
output = [s.data for s in transformer_sample.transform([sample])][0]
assert output.size == 128, output.shape
def test_arface_insight_tf():
from bob.bio.face.transformers import ArcFace_InsightFaceTF
......@@ -69,4 +98,10 @@ def test_arface_insight_tf():
transformer = ArcFace_InsightFaceTF()
data = np.random.rand(3, 112, 112).astype("uint8")
output = transformer.transform(data)
assert output.size == 512, output.shape
\ No newline at end of file
assert output.size == 512, output.shape
# Sample Batch
sample = Sample(data)
transformer_sample = wrap(["sample"], transformer)
output = [s.data for s in transformer_sample.transform([sample])][0]
assert output.size == 512, output.shape
......@@ -5,6 +5,8 @@ import os
from sklearn.base import TransformerMixin, BaseEstimator
from .tensorflow_compat_v1 import TensorflowCompatV1
from bob.io.image import to_matplotlib
import numpy as np
class ArcFace_InsightFaceTF(TensorflowCompatV1):
"""
......@@ -19,10 +21,11 @@ class ArcFace_InsightFaceTF(TensorflowCompatV1):
def __init__(self):
    """Locate the ArcFace InsightFace TF checkpoint, downloading it if absent.

    The checkpoint path is resolved via the bob rc variable
    ``bob.bio.face.arcface_tf_path`` (falling back to the default model
    directory), and the model archive is fetched from the Idiap server
    when it is not already available locally.
    """
    bob_rc_variable = "bob.bio.face.arcface_tf_path"
    # NOTE: the diff had left a duplicate single-line `urls` assignment
    # here; only one assignment is needed.
    urls = [
        "https://www.idiap.ch/software/bob/data/bob/bob.bio.face/master/arcface_insight_tf.tar.gz"
    ]
    model_subdirectory = "arcface_tf_path"
    # Resolve the local checkpoint location, then download the model
    # archive if needed (both helpers come from TensorflowCompatV1).
    checkpoint_filename = self.get_modelpath(bob_rc_variable, model_subdirectory)
    self.download_model(checkpoint_filename, urls)
......@@ -34,11 +37,42 @@ class ArcFace_InsightFaceTF(TensorflowCompatV1):
def transform(self, data):
    """Normalize the input image to ``[-1, 1]`` and run the parent transform.

    Follows the preprocessing used by the reference evaluation script:
    https://github.com/luckycallor/InsightFace-tensorflow/blob/master/evaluate.py#L42
    """
    # Convert to matplotlib (channels-last) ordering, then to an ndarray.
    image = np.asarray(to_matplotlib(data))
    # Map the uint8 pixel range [0, 255] onto [-1.0, 1.0].
    scaled = image / 127.5 - 1.0
    return super().transform(scaled)
def load_model(self):
    """Build the TF1 graph, start a session, and restore checkpoint weights.

    Side effects: sets ``self.input_tensor``, ``self.embedding``,
    ``self.session`` and flips ``self.loaded`` to True so subsequent
    ``transform`` calls reuse the loaded graph.
    """
    # Placeholder for the input batch; shape comes from the instance
    # configuration (presumably N x H x W x C — TODO confirm upstream).
    self.input_tensor = tf.compat.v1.placeholder(
        dtype=tf.float32,
        shape=self.input_shape,
        name="input_image",
    )
    # Build the network; the pre-logits output is used directly as the
    # embedding (no extra projection layer).
    prelogits = self.architecture_fn(self.input_tensor)
    self.embedding = prelogits
    # Initializing the variables of the current graph
    self.session = tf.compat.v1.Session()
    self.session.run(tf.compat.v1.global_variables_initializer())
    # Loading the last checkpoint and overwriting the current variables
    saver = tf.compat.v1.train.Saver()
    if os.path.splitext(self.checkpoint_filename)[1] == ".meta":
        # A .meta file was given: restore the latest checkpoint found in
        # that file's directory.
        saver.restore(
            self.session,
            tf.train.latest_checkpoint(os.path.dirname(self.checkpoint_filename)),
        )
    elif os.path.isdir(self.checkpoint_filename):
        # A directory was given: let TF pick the most recent checkpoint in it.
        saver.restore(self.session, tf.train.latest_checkpoint(self.checkpoint_filename))
    else:
        # Otherwise treat the path as an explicit checkpoint prefix.
        saver.restore(self.session, self.checkpoint_filename)
    self.loaded = True
###########################
# CODE COPIED FROM
......@@ -49,7 +83,8 @@ import tensorflow as tf
import tensorflow.contrib.slim as slim
from collections import namedtuple
def init_network(input_tensor, model=None):
def init_network(input_tensor):
with tf.variable_scope("embd_extractor", reuse=False):
arg_sc = resnet_arg_scope()
......
......@@ -8,7 +8,7 @@ import pkg_resources
import bob.extension.download
from bob.extension import rc
from sklearn.base import TransformerMixin, BaseEstimator
import numpy as np
import logging
logger = logging.getLogger(__name__)
......@@ -54,14 +54,36 @@ class TensorflowCompatV1(TransformerMixin, BaseEstimator):
The features.
"""
data = np.asarray(data)
# THE INPUT SHAPE FOR THESE MODELS
# ARE `N x C x H x W`
# If ndim==3 we add another axis
if data.ndim==3:
data = data[None, ...]
# Making sure it's channels-last and has three channels
if data.ndim==4:
# Just swapping the second dimension
if data.shape[1] == 3:
data = np.moveaxis(data, 1, -1)
if data.shape != self.input_shape:
raise ValueError(f"Image shape {data.shape} not supported. Expected {self.input_shape}")
if not self.loaded:
self.load_model()
return self.session.run(
self.embedding,
feed_dict={self.input_tensor: data.reshape(self.input_shape)},
feed_dict={self.input_tensor: data},
)
def load_model(self):
logger.info(f"Loading model `{self.checkpoint_filename}`")
......@@ -115,7 +137,7 @@ class TensorflowCompatV1(TransformerMixin, BaseEstimator):
tf.compat.v1.reset_default_graph()
return d
# def __del__(self):
#def __del__(self):
# tf.compat.v1.reset_default_graph()
def get_modelpath(self, bob_rc_variable, model_subdirectory):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment