Add a memory-demanding option to the TF-based transformers

parent d4fbeef1
Pipeline #46033 passed in 36 minutes and 22 seconds
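In short, each TF-based embedding extractor gains a `memory_demanding` flag. When it is set, `transform` no longer forwards the whole batch through the network at once; it loops over the samples and runs them one at a time, trading throughput for a much smaller peak memory footprint. Below is a minimal, self-contained sketch of the idea; the free function `transform` and the toy `embed` callable are illustrative stand-ins, not part of the library. Note the extra singleton axis per sample in the memory-demanding path, which the real implementation also produces (hence the `output[0]` indexing in the test at the bottom of this diff).

import numpy as np


def transform(X, embed, memory_demanding=False):
    """Forward a batch X through `embed`, one sample at a time if requested."""
    if memory_demanding:
        # x[None, ...] keeps the leading batch dimension the model expects,
        # but with a batch size of one.
        return np.array([embed(x[None, ...]) for x in X])
    return embed(X)


def embed(batch):
    # Toy "embedding": flatten each image and keep the first 128 values.
    return batch.reshape(batch.shape[0], -1)[:, :128]


X = np.random.rand(4, 160, 160, 3).astype("float32")
print(transform(X, embed).shape)                         # (4, 128)
print(transform(X, embed, memory_demanding=True).shape)  # (4, 1, 128)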
@@ -7,10 +7,14 @@ from bob.bio.base.pipelines.vanilla_biometrics import (
     VanillaBiometricsPipeline,
 )
 
 
+memory_demanding = False
 if "database" in locals():
     annotation_type = database.annotation_type
     fixed_positions = database.fixed_positions
+    memory_demanding = (
+        database.memory_demanding if hasattr(database, "memory_demanding") else False
+    )
 else:
     annotation_type = None
     fixed_positions = None
@@ -19,7 +23,9 @@ else:
 
 def load(annotation_type, fixed_positions=None):
     transformer = embedding_transformer_160x160(
-        FaceNetSanderberg_20170512_110547(), annotation_type, fixed_positions
+        FaceNetSanderberg_20170512_110547(memory_demanding=memory_demanding),
+        annotation_type,
+        fixed_positions,
     )
 
     algorithm = Distance()
@@ -7,10 +7,14 @@ from bob.bio.base.pipelines.vanilla_biometrics import (
     VanillaBiometricsPipeline,
 )
 
 
+memory_demanding = False
 if "database" in locals():
     annotation_type = database.annotation_type
     fixed_positions = database.fixed_positions
+    memory_demanding = (
+        database.memory_demanding if hasattr(database, "memory_demanding") else False
+    )
 else:
     annotation_type = None
     fixed_positions = None
@@ -18,7 +22,9 @@ else:
 
 def load(annotation_type, fixed_positions=None):
     transformer = embedding_transformer_160x160(
-        InceptionResnetv1_Casia_CenterLoss_2018(), annotation_type, fixed_positions
+        InceptionResnetv1_Casia_CenterLoss_2018(memory_demanding=memory_demanding),
+        annotation_type,
+        fixed_positions,
     )
 
     algorithm = Distance()
@@ -7,10 +7,14 @@ from bob.bio.base.pipelines.vanilla_biometrics import (
     VanillaBiometricsPipeline,
 )
 
 
+memory_demanding = False
 if "database" in locals():
     annotation_type = database.annotation_type
     fixed_positions = database.fixed_positions
+    memory_demanding = (
+        database.memory_demanding if hasattr(database, "memory_demanding") else False
+    )
 else:
     annotation_type = None
     fixed_positions = None
@@ -18,7 +22,9 @@ else:
 
 def load(annotation_type, fixed_positions=None):
     transformer = embedding_transformer_160x160(
-        InceptionResnetv1_MsCeleb_CenterLoss_2018(), annotation_type, fixed_positions
+        InceptionResnetv1_MsCeleb_CenterLoss_2018(memory_demanding=memory_demanding),
+        annotation_type,
+        fixed_positions,
    )
 
     algorithm = Distance()
@@ -7,10 +7,14 @@ from bob.bio.base.pipelines.vanilla_biometrics import (
     VanillaBiometricsPipeline,
 )
 
 
+memory_demanding = False
 if "database" in locals():
     annotation_type = database.annotation_type
     fixed_positions = database.fixed_positions
+    memory_demanding = (
+        database.memory_demanding if hasattr(database, "memory_demanding") else False
+    )
 else:
     annotation_type = None
     fixed_positions = None
@@ -18,7 +22,9 @@ else:
 
 def load(annotation_type, fixed_positions=None):
     transformer = embedding_transformer_160x160(
-        InceptionResnetv2_Casia_CenterLoss_2018(), annotation_type, fixed_positions
+        InceptionResnetv2_Casia_CenterLoss_2018(memory_demanding=memory_demanding),
+        annotation_type,
+        fixed_positions,
     )
 
     algorithm = Distance()
@@ -7,10 +7,14 @@ from bob.bio.base.pipelines.vanilla_biometrics import (
     VanillaBiometricsPipeline,
 )
 
 
+memory_demanding = False
 if "database" in locals():
     annotation_type = database.annotation_type
     fixed_positions = database.fixed_positions
+    memory_demanding = (
+        database.memory_demanding if hasattr(database, "memory_demanding") else False
+    )
 else:
     annotation_type = None
     fixed_positions = None
@@ -18,7 +22,9 @@ else:
 
 def load(annotation_type, fixed_positions=None):
     transformer = embedding_transformer_160x160(
-        InceptionResnetv2_MsCeleb_CenterLoss_2018(), annotation_type, fixed_positions
+        InceptionResnetv2_MsCeleb_CenterLoss_2018(memory_demanding=memory_demanding),
+        annotation_type,
+        fixed_positions,
     )
 
     algorithm = Distance()
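The five baseline configs above are all patched the same way: `memory_demanding` defaults to `False` and is only switched on when a `database` object placed in the config's namespace advertises a `memory_demanding` attribute, which is then forwarded to the embedding extractor. A hedged sketch of how such a database object toggles the flag; `DummyDatabase` is invented here for illustration, and `getattr(..., False)` is equivalent to the `hasattr` check used in the configs.

# Hypothetical illustration: any object named `database`, present in the
# config's namespace before it is loaded, can switch on one-sample-at-a-time
# inference via a `memory_demanding` attribute.
class DummyDatabase:
    annotation_type = "eyes-center"
    fixed_positions = None
    memory_demanding = True


database = DummyDatabase()

memory_demanding = False
if "database" in locals():
    annotation_type = database.annotation_type
    fixed_positions = database.fixed_positions
    # Same effect as the hasattr()-based expression in the configs.
    memory_demanding = getattr(database, "memory_demanding", False)
else:
    annotation_type = None
    fixed_positions = None

print(memory_demanding)  # True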
@@ -3,7 +3,7 @@ from bob.bio.face.embeddings.tf2_inception_resnet import InceptionResnetv2
 from bob.bio.face.preprocessor import FaceCrop
 from bob.bio.face.config.baseline.helpers import (
     embedding_transformer_default_cropping,
-    embedding_transformer
+    embedding_transformer,
 )
 
 from sklearn.pipeline import make_pipeline
@@ -13,32 +13,43 @@ from bob.bio.base.pipelines.vanilla_biometrics import (
     VanillaBiometricsPipeline,
 )
 
 
+memory_demanding = False
 if "database" in locals():
     annotation_type = database.annotation_type
     fixed_positions = database.fixed_positions
+    memory_demanding = (
+        database.memory_demanding if hasattr(database, "memory_demanding") else False
+    )
 else:
     annotation_type = None
     fixed_positions = None
 
 
 def load(annotation_type, fixed_positions=None):
     CROPPED_IMAGE_SIZE = (160, 160)
-    CROPPED_POSITIONS = embedding_transformer_default_cropping(CROPPED_IMAGE_SIZE,
-                                                               annotation_type=annotation_type)
-    extractor_path = rc['bob.bio.face.tf2.casia-webface-inception-v2']
-    embedding = InceptionResnetv2(checkpoint_path=extractor_path)
-    transformer = embedding_transformer(CROPPED_IMAGE_SIZE,
-                                        embedding,
-                                        annotation_type,
-                                        CROPPED_POSITIONS,
-                                        fixed_positions)
+    CROPPED_POSITIONS = embedding_transformer_default_cropping(
+        CROPPED_IMAGE_SIZE, annotation_type=annotation_type
+    )
+
+    extractor_path = rc["bob.bio.face.tf2.casia-webface-inception-v2"]
+
+    embedding = InceptionResnetv2(
+        checkpoint_path=extractor_path, memory_demanding=memory_demanding
+    )
+
+    transformer = embedding_transformer(
+        CROPPED_IMAGE_SIZE,
+        embedding,
+        annotation_type,
+        CROPPED_POSITIONS,
+        fixed_positions,
+    )
 
     algorithm = Distance()
 
     return VanillaBiometricsPipeline(transformer, algorithm)
 
 
 pipeline = load(annotation_type, fixed_positions)
 transformer = pipeline.transformer
@@ -10,6 +10,7 @@ from functools import partial
 import pkg_resources
 import os
 from bob.bio.face.embeddings import download_model
+import numpy as np
 
 
 def sanderberg_rescaling():
@@ -34,13 +35,19 @@ class InceptionResnet(TransformerMixin, BaseEstimator):
     preprocessor:
         Preprocessor function
 
+    memory_demanding: bool
+        If `True`, the `transform` method will run one sample at a time.
+        This is useful when there is not enough memory available to forward big chunks of data.
     """
 
-    def __init__(self, checkpoint_path, preprocessor=None, **kwargs):
+    def __init__(
+        self, checkpoint_path, preprocessor=None, memory_demanding=False, **kwargs
+    ):
         super().__init__(**kwargs)
         self.checkpoint_path = checkpoint_path
         self.model = None
         self.preprocessor = preprocessor
+        self.memory_demanding = memory_demanding
 
     def load_model(self):
         self.model = tf.keras.models.load_model(self.checkpoint_path)
@@ -54,19 +61,26 @@ class InceptionResnet(TransformerMixin, BaseEstimator):
         return embeddings
 
     def transform(self, X):
+        def _transform(X):
+            X = tf.convert_to_tensor(X)
+            X = to_channels_last(X)
+
+            if X.shape[-3:] != self.model.input_shape[-3:]:
+                raise ValueError(
+                    f"Image shape {X.shape} not supported. Expected {self.model.input_shape}"
+                )
+
+            return self.inference(X).numpy()
+
         if self.model is None:
             self.load_model()
 
         X = check_array(X, allow_nd=True)
-        X = tf.convert_to_tensor(X)
-        X = to_channels_last(X)
 
-        if X.shape[-3:] != self.model.input_shape[-3:]:
-            raise ValueError(
-                f"Image shape {X.shape} not supported. Expected {self.model.input_shape}"
-            )
-
-        return self.inference(X).numpy()
+        if self.memory_demanding:
+            return np.array([_transform(x[None, ...]) for x in X])
+        else:
+            return _transform(X)
 
     def __getstate__(self):
         # Handling unpicklable objects
@@ -89,7 +103,7 @@ class InceptionResnetv2_MsCeleb_CenterLoss_2018(InceptionResnet):
 
     """
 
-    def __init__(self):
+    def __init__(self, memory_demanding=False):
         internal_path = pkg_resources.resource_filename(
             __name__, os.path.join("data", "inceptionresnetv2_msceleb_centerloss_2018"),
         )
@@ -111,7 +125,9 @@ class InceptionResnetv2_MsCeleb_CenterLoss_2018(InceptionResnet):
         )
 
         super(InceptionResnetv2_MsCeleb_CenterLoss_2018, self).__init__(
-            checkpoint_path, preprocessor=tf.image.per_image_standardization,
+            checkpoint_path,
+            preprocessor=tf.image.per_image_standardization,
+            memory_demanding=memory_demanding,
         )
@@ -123,7 +139,7 @@ class InceptionResnetv2_Casia_CenterLoss_2018(InceptionResnet):
 
     """
 
-    def __init__(self):
+    def __init__(self, memory_demanding=False):
         internal_path = pkg_resources.resource_filename(
             __name__, os.path.join("data", "inceptionresnetv2_casia_centerloss_2018"),
        )
@@ -144,7 +160,9 @@ class InceptionResnetv2_Casia_CenterLoss_2018(InceptionResnet):
         )
 
         super(InceptionResnetv2_Casia_CenterLoss_2018, self).__init__(
-            checkpoint_path, preprocessor=tf.image.per_image_standardization,
+            checkpoint_path,
+            preprocessor=tf.image.per_image_standardization,
+            memory_demanding=memory_demanding,
        )
@@ -156,7 +174,7 @@ class InceptionResnetv1_Casia_CenterLoss_2018(InceptionResnet):
 
     """
 
-    def __init__(self):
+    def __init__(self, memory_demanding=False):
         internal_path = pkg_resources.resource_filename(
             __name__, os.path.join("data", "inceptionresnetv1_casia_centerloss_2018"),
         )
@@ -177,7 +195,9 @@ class InceptionResnetv1_Casia_CenterLoss_2018(InceptionResnet):
         )
 
         super(InceptionResnetv1_Casia_CenterLoss_2018, self).__init__(
-            checkpoint_path, preprocessor=tf.image.per_image_standardization,
+            checkpoint_path,
+            preprocessor=tf.image.per_image_standardization,
+            memory_demanding=memory_demanding,
         )
@@ -189,7 +209,7 @@ class InceptionResnetv1_MsCeleb_CenterLoss_2018(InceptionResnet):
 
     """
 
-    def __init__(self):
+    def __init__(self, memory_demanding=False):
         internal_path = pkg_resources.resource_filename(
             __name__, os.path.join("data", "inceptionresnetv1_msceleb_centerloss_2018"),
         )
@@ -211,7 +231,9 @@ class InceptionResnetv1_MsCeleb_CenterLoss_2018(InceptionResnet):
         )
 
         super(InceptionResnetv1_MsCeleb_CenterLoss_2018, self).__init__(
-            checkpoint_path, preprocessor=tf.image.per_image_standardization,
+            checkpoint_path,
+            preprocessor=tf.image.per_image_standardization,
+            memory_demanding=memory_demanding,
        )
@@ -237,7 +259,7 @@ class FaceNetSanderberg_20170512_110547(InceptionResnet):
     )
     """
 
-    def __init__(self):
+    def __init__(self, memory_demanding=False):
         internal_path = pkg_resources.resource_filename(
             __name__, os.path.join("data", "facenet_sanderberg_20170512_110547"),
         )
@@ -257,5 +279,8 @@
         )
 
         super(FaceNetSanderberg_20170512_110547, self).__init__(
-            checkpoint_path, tf.image.per_image_standardization,
+            checkpoint_path,
+            tf.image.per_image_standardization,
+            memory_demanding=memory_demanding,
        )
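For reference, enabling the flag directly on one of the pretrained extractors looks like the snippet below. It mirrors the new test that follows; actually running it requires TensorFlow and the pretrained checkpoint for this model.

import numpy as np
from bob.bio.face.embeddings.tf2_inception_resnet import (
    InceptionResnetv2_MsCeleb_CenterLoss_2018,
)

# With memory_demanding=True each image is forwarded through the network
# individually, so peak memory stays close to the single-sample cost.
transformer = InceptionResnetv2_MsCeleb_CenterLoss_2018(memory_demanding=True)

data = (np.random.rand(3, 160, 160) * 255).astype("uint8")  # one CxHxW image
embedding = transformer.transform([data])[0]
print(embedding.size)  # 128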
@@ -32,6 +32,33 @@ def test_idiap_inceptionv2_msceleb():
     assert output.size == 128, output.shape
 
 
+@is_library_available("tensorflow")
+def test_idiap_inceptionv2_msceleb_memory_demanding():
+    from bob.bio.face.embeddings.tf2_inception_resnet import (
+        InceptionResnetv2_MsCeleb_CenterLoss_2018,
+    )
+
+    reference = bob.io.base.load(
+        pkg_resources.resource_filename(
+            "bob.bio.face.test", "data/inception_resnet_v2_rgb.hdf5"
+        )
+    )
+    np.random.seed(10)
+
+    transformer = InceptionResnetv2_MsCeleb_CenterLoss_2018(memory_demanding=True)
+    data = (np.random.rand(3, 160, 160) * 255).astype("uint8")
+    output = transformer.transform([data])[0]
+    assert output.size == 128, output.shape
+
+    # Sample Batch
+    sample = Sample(data)
+    transformer_sample = wrap(["sample"], transformer)
+    output = [s.data for s in transformer_sample.transform([sample])][0]
+
+    np.testing.assert_allclose(output[0], reference.flatten(), rtol=1e-5, atol=1e-4)
+    assert output.size == 128, output.shape
+
+
 @is_library_available("tensorflow")
 def test_idiap_inceptionv2_casia():
     from bob.bio.face.embeddings import InceptionResnetv2_Casia_CenterLoss_2018