Commit 83643ce4 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Fixed some issues with SampleBatch

parent 105cf73b
......@@ -92,6 +92,9 @@ def test_idiap_inceptionv1_casia():
def test_arface_insight_tf():
import tensorflow as tf
tf.compat.v1.reset_default_graph()
from bob.bio.face.transformers import ArcFace_InsightFaceTF
np.random.seed(10)
......
......@@ -6,6 +6,7 @@ from sklearn.base import TransformerMixin, BaseEstimator
from .tensorflow_compat_v1 import TensorflowCompatV1
from bob.io.image import to_matplotlib
import numpy as np
from sklearn.utils import check_array
class ArcFace_InsightFaceTF(TensorflowCompatV1):
......@@ -37,7 +38,7 @@ class ArcFace_InsightFaceTF(TensorflowCompatV1):
def transform(self, data):
# https://github.com/luckycallor/InsightFace-tensorflow/blob/master/evaluate.py#L42
data = np.asarray(data)
data = check_array(data, allow_nd=True)
data = data / 127.5 - 1.0
return super().transform(data)
......@@ -45,15 +46,12 @@ class ArcFace_InsightFaceTF(TensorflowCompatV1):
def load_model(self):
self.input_tensor = tf.compat.v1.placeholder(
dtype=tf.float32,
shape=self.input_shape,
name="input_image",
dtype=tf.float32, shape=self.input_shape, name="input_image",
)
prelogits = self.architecture_fn(self.input_tensor)
self.embedding = prelogits
# Initializing the variables of the current graph
self.session = tf.compat.v1.Session()
self.session.run(tf.compat.v1.global_variables_initializer())
......@@ -67,7 +65,9 @@ class ArcFace_InsightFaceTF(TensorflowCompatV1):
tf.train.latest_checkpoint(os.path.dirname(self.checkpoint_filename)),
)
elif os.path.isdir(self.checkpoint_filename):
saver.restore(self.session, tf.train.latest_checkpoint(self.checkpoint_filename))
saver.restore(
self.session, tf.train.latest_checkpoint(self.checkpoint_filename)
)
else:
saver.restore(self.session, self.checkpoint_filename)
......
......@@ -17,12 +17,12 @@ import os
import re
import logging
import numpy as np
import tensorflow as tf
from bob.ip.color import gray_to_rgb
from bob.io.image import to_matplotlib
from bob.extension import rc
import bob.extension.download
import bob.io.base
from sklearn.utils import check_array
logger = logging.getLogger(__name__)
......@@ -104,7 +104,7 @@ class FaceNetSanderberg(TransformerMixin, BaseEstimator):
self.phase_train_placeholder = None
def _check_feature(self, img):
img = np.asarray(img)
img = check_array(img, allow_nd=True)
def _convert(img):
assert img.shape[-2] == self.image_size
......@@ -122,6 +122,8 @@ class FaceNetSanderberg(TransformerMixin, BaseEstimator):
raise ValueError(f"Image shape {img.shape} not supported")
def load_model(self):
import tensorflow as tf
tf.compat.v1.reset_default_graph()
session_conf = tf.compat.v1.ConfigProto(
......@@ -222,6 +224,8 @@ class FaceNetSanderberg(TransformerMixin, BaseEstimator):
self.loaded = False
def __getstate__(self):
import tensorflow as tf
# Handling unpicklable objects
d = self.__dict__
d.pop("session") if "session" in d else None
......
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
import tensorflow as tf
import os
from tensorflow.python import debug as tf_debug
import pkg_resources
import bob.extension.download
from bob.extension import rc
from sklearn.base import TransformerMixin, BaseEstimator
import numpy as np
import logging
from sklearn.utils import check_array
logger = logging.getLogger(__name__)
......@@ -55,36 +54,34 @@ class TensorflowCompatV1(TransformerMixin, BaseEstimator):
"""
data = np.asarray(data)
data = check_array(data, allow_nd=True)
# THE INPUT SHAPE FOR THESE MODELS
# ARE `N x C x H x W`
# If ndim==3 we add another axis
if data.ndim==3:
if data.ndim == 3:
data = data[None, ...]
# Making sure it's channels last and has three chanbels
if data.ndim==4:
if data.ndim == 4:
# Just swiping the second dimention
if data.shape[1] == 3:
data = np.moveaxis(data, 1, -1)
if data.shape != self.input_shape:
raise ValueError(f"Image shape {data.shape} not supported. Expected {self.input_shape}")
raise ValueError(
f"Image shape {data.shape} not supported. Expected {self.input_shape}"
)
if not self.loaded:
self.load_model()
return self.session.run(
self.embedding,
feed_dict={self.input_tensor: data},
)
return self.session.run(self.embedding, feed_dict={self.input_tensor: data},)
def load_model(self):
import tensorflow as tf
logger.info(f"Loading model `{self.checkpoint_filename}`")
tf.compat.v1.reset_default_graph()
......@@ -129,6 +126,8 @@ class TensorflowCompatV1(TransformerMixin, BaseEstimator):
self.loaded = False
def __getstate__(self):
import tensorflow as tf
# Handling unpicklable objects
d = self.__dict__
d.pop("session", None)
......@@ -137,7 +136,7 @@ class TensorflowCompatV1(TransformerMixin, BaseEstimator):
tf.compat.v1.reset_default_graph()
return d
#def __del__(self):
# def __del__(self):
# tf.compat.v1.reset_default_graph()
def get_modelpath(self, bob_rc_variable, model_subdirectory):
......@@ -183,7 +182,6 @@ class TensorflowCompatV1(TransformerMixin, BaseEstimator):
zip_file = os.path.join(model_path, zip_file)
bob.extension.download.download_and_unzip(urls, zip_file)
def fit(self, X, y=None):
return self
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment