Commit a079dc70 authored by Tiago de Freitas Pereira

Add a memory demanding feature in the TF based transformers

parent d4fbeef1
Merge request !83: `memory_demanding` for TF based transformers
Pipeline #46033 passed
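
In short: every TF-based embedding transformer gains a `memory_demanding` flag, and the baseline configurations forward it from the database object when present. When the flag is set, `transform` runs the network on one sample at a time instead of on the whole batch. A minimal usage sketch (illustrative only, not part of the commit; the dummy batch shape follows the tests further down, channels-first uint8 images of 160x160):

import numpy as np
from bob.bio.face.embeddings.tf2_inception_resnet import (
    InceptionResnetv2_MsCeleb_CenterLoss_2018,
)

# memory_demanding=True trades throughput for a smaller peak memory footprint:
# each image is expanded to a batch of one and forwarded individually.
embedding = InceptionResnetv2_MsCeleb_CenterLoss_2018(memory_demanding=True)

images = (np.random.rand(4, 3, 160, 160) * 255).astype("uint8")  # dummy RGB batch
features = embedding.transform(images)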
@@ -7,10 +7,14 @@ from bob.bio.base.pipelines.vanilla_biometrics import (
     VanillaBiometricsPipeline,
 )

+memory_demanding = False
 if "database" in locals():
     annotation_type = database.annotation_type
     fixed_positions = database.fixed_positions
+    memory_demanding = (
+        database.memory_demanding if hasattr(database, "memory_demanding") else False
+    )
 else:
     annotation_type = None
     fixed_positions = None
@@ -19,7 +23,9 @@ else:

 def load(annotation_type, fixed_positions=None):
     transformer = embedding_transformer_160x160(
-        FaceNetSanderberg_20170512_110547(), annotation_type, fixed_positions
+        FaceNetSanderberg_20170512_110547(memory_demanding=memory_demanding),
+        annotation_type,
+        fixed_positions,
     )
     algorithm = Distance()
...
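
This guard is repeated in each of the baseline configurations in this commit: if the `database` object defined earlier in the configuration chain exposes a `memory_demanding` attribute, its value is forwarded to the embedding, otherwise the flag stays `False`. A small illustration of the guard with a stand-in database object (the class below is hypothetical, not a real bob.bio database):

class _DummyDatabase:
    # stand-in exposing only the attributes the configuration reads
    annotation_type = "eyes-center"
    fixed_positions = None
    memory_demanding = True


database = _DummyDatabase()

# the same guard used in the configurations
memory_demanding = (
    database.memory_demanding if hasattr(database, "memory_demanding") else False
)
assert memory_demanding is True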
@@ -7,10 +7,14 @@ from bob.bio.base.pipelines.vanilla_biometrics import (
     VanillaBiometricsPipeline,
 )

+memory_demanding = False
 if "database" in locals():
     annotation_type = database.annotation_type
     fixed_positions = database.fixed_positions
+    memory_demanding = (
+        database.memory_demanding if hasattr(database, "memory_demanding") else False
+    )
 else:
     annotation_type = None
     fixed_positions = None
@@ -18,7 +22,9 @@ else:

 def load(annotation_type, fixed_positions=None):
     transformer = embedding_transformer_160x160(
-        InceptionResnetv1_Casia_CenterLoss_2018(), annotation_type, fixed_positions
+        InceptionResnetv1_Casia_CenterLoss_2018(memory_demanding=memory_demanding),
+        annotation_type,
+        fixed_positions,
     )
     algorithm = Distance()
...
@@ -7,10 +7,14 @@ from bob.bio.base.pipelines.vanilla_biometrics import (
     VanillaBiometricsPipeline,
 )

+memory_demanding = False
 if "database" in locals():
     annotation_type = database.annotation_type
     fixed_positions = database.fixed_positions
+    memory_demanding = (
+        database.memory_demanding if hasattr(database, "memory_demanding") else False
+    )
 else:
     annotation_type = None
     fixed_positions = None
@@ -18,7 +22,9 @@ else:

 def load(annotation_type, fixed_positions=None):
     transformer = embedding_transformer_160x160(
-        InceptionResnetv1_MsCeleb_CenterLoss_2018(), annotation_type, fixed_positions
+        InceptionResnetv1_MsCeleb_CenterLoss_2018(memory_demanding=memory_demanding),
+        annotation_type,
+        fixed_positions,
     )
     algorithm = Distance()
...
@@ -7,10 +7,14 @@ from bob.bio.base.pipelines.vanilla_biometrics import (
     VanillaBiometricsPipeline,
 )

+memory_demanding = False
 if "database" in locals():
     annotation_type = database.annotation_type
     fixed_positions = database.fixed_positions
+    memory_demanding = (
+        database.memory_demanding if hasattr(database, "memory_demanding") else False
+    )
 else:
     annotation_type = None
     fixed_positions = None
@@ -18,7 +22,9 @@ else:

 def load(annotation_type, fixed_positions=None):
     transformer = embedding_transformer_160x160(
-        InceptionResnetv2_Casia_CenterLoss_2018(), annotation_type, fixed_positions
+        InceptionResnetv2_Casia_CenterLoss_2018(memory_demanding=memory_demanding),
+        annotation_type,
+        fixed_positions,
     )
     algorithm = Distance()
...
@@ -7,10 +7,14 @@ from bob.bio.base.pipelines.vanilla_biometrics import (
     VanillaBiometricsPipeline,
 )

+memory_demanding = False
 if "database" in locals():
     annotation_type = database.annotation_type
     fixed_positions = database.fixed_positions
+    memory_demanding = (
+        database.memory_demanding if hasattr(database, "memory_demanding") else False
+    )
 else:
     annotation_type = None
     fixed_positions = None
@@ -18,7 +22,9 @@ else:

 def load(annotation_type, fixed_positions=None):
     transformer = embedding_transformer_160x160(
-        InceptionResnetv2_MsCeleb_CenterLoss_2018(), annotation_type, fixed_positions
+        InceptionResnetv2_MsCeleb_CenterLoss_2018(memory_demanding=memory_demanding),
+        annotation_type,
+        fixed_positions,
     )
     algorithm = Distance()
...
@@ -3,7 +3,7 @@ from bob.bio.face.embeddings.tf2_inception_resnet import InceptionResnetv2
 from bob.bio.face.preprocessor import FaceCrop
 from bob.bio.face.config.baseline.helpers import (
     embedding_transformer_default_cropping,
-    embedding_transformer
+    embedding_transformer,
 )
 from sklearn.pipeline import make_pipeline
@@ -13,32 +13,43 @@ from bob.bio.base.pipelines.vanilla_biometrics import (
     VanillaBiometricsPipeline,
 )

+memory_demanding = False
 if "database" in locals():
     annotation_type = database.annotation_type
     fixed_positions = database.fixed_positions
+    memory_demanding = (
+        database.memory_demanding if hasattr(database, "memory_demanding") else False
+    )
 else:
     annotation_type = None
     fixed_positions = None


 def load(annotation_type, fixed_positions=None):
     CROPPED_IMAGE_SIZE = (160, 160)
-    CROPPED_POSITIONS = embedding_transformer_default_cropping(CROPPED_IMAGE_SIZE,
-                                                               annotation_type=annotation_type)
+    CROPPED_POSITIONS = embedding_transformer_default_cropping(
+        CROPPED_IMAGE_SIZE, annotation_type=annotation_type
+    )

-    extractor_path = rc['bob.bio.face.tf2.casia-webface-inception-v2']
+    extractor_path = rc["bob.bio.face.tf2.casia-webface-inception-v2"]

-    embedding = InceptionResnetv2(checkpoint_path=extractor_path)
+    embedding = InceptionResnetv2(
+        checkpoint_path=extractor_path, memory_demanding=memory_demanding
+    )

-    transformer = embedding_transformer(CROPPED_IMAGE_SIZE,
-                                        embedding,
-                                        annotation_type,
-                                        CROPPED_POSITIONS,
-                                        fixed_positions)
+    transformer = embedding_transformer(
+        CROPPED_IMAGE_SIZE,
+        embedding,
+        annotation_type,
+        CROPPED_POSITIONS,
+        fixed_positions,
+    )

     algorithm = Distance()

     return VanillaBiometricsPipeline(transformer, algorithm)


 pipeline = load(annotation_type, fixed_positions)
 transformer = pipeline.transformer
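
Unlike the packaged baselines above, this configuration loads its checkpoint from the global bob configuration (`rc`). A hedged sketch of how that key is typically provided, assuming the standard `bob.extension` configuration mechanism:

from bob.extension import rc

# returns None when the key has not been configured yet
print(rc["bob.bio.face.tf2.casia-webface-inception-v2"])

# the key is usually set once from the command line, e.g.:
#   bob config set bob.bio.face.tf2.casia-webface-inception-v2 /path/to/checkpoint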
@@ -10,6 +10,7 @@ from functools import partial
 import pkg_resources
 import os
 from bob.bio.face.embeddings import download_model
+import numpy as np


 def sanderberg_rescaling():
@@ -34,13 +35,19 @@ class InceptionResnet(TransformerMixin, BaseEstimator):
     preprocessor:
         Preprocessor function

+    memory_demanding: bool
+        If `True`, the `transform` method will run one sample at a time.
+        This is useful when there is not enough memory available to forward big chunks of data.
     """

-    def __init__(self, checkpoint_path, preprocessor=None, **kwargs):
+    def __init__(
+        self, checkpoint_path, preprocessor=None, memory_demanding=False, **kwargs
+    ):
         super().__init__(**kwargs)
         self.checkpoint_path = checkpoint_path
         self.model = None
         self.preprocessor = preprocessor
+        self.memory_demanding = memory_demanding

     def load_model(self):
         self.model = tf.keras.models.load_model(self.checkpoint_path)
@@ -54,19 +61,26 @@ class InceptionResnet(TransformerMixin, BaseEstimator):
         return embeddings

     def transform(self, X):
+        def _transform(X):
+            X = tf.convert_to_tensor(X)
+            X = to_channels_last(X)
+
+            if X.shape[-3:] != self.model.input_shape[-3:]:
+                raise ValueError(
+                    f"Image shape {X.shape} not supported. Expected {self.model.input_shape}"
+                )
+
+            return self.inference(X).numpy()
+
         if self.model is None:
             self.load_model()

         X = check_array(X, allow_nd=True)
-        X = tf.convert_to_tensor(X)
-        X = to_channels_last(X)
-
-        if X.shape[-3:] != self.model.input_shape[-3:]:
-            raise ValueError(
-                f"Image shape {X.shape} not supported. Expected {self.model.input_shape}"
-            )
-
-        return self.inference(X).numpy()
+
+        if self.memory_demanding:
+            return np.array([_transform(x[None, ...]) for x in X])
+        else:
+            return _transform(X)

     def __getstate__(self):
         # Handling unpicklable objects
@@ -89,7 +103,7 @@ class InceptionResnetv2_MsCeleb_CenterLoss_2018(InceptionResnet):
     """

-    def __init__(self):
+    def __init__(self, memory_demanding=False):
         internal_path = pkg_resources.resource_filename(
             __name__, os.path.join("data", "inceptionresnetv2_msceleb_centerloss_2018"),
         )
@@ -111,7 +125,9 @@ class InceptionResnetv2_MsCeleb_CenterLoss_2018(InceptionResnet):
         )

         super(InceptionResnetv2_MsCeleb_CenterLoss_2018, self).__init__(
-            checkpoint_path, preprocessor=tf.image.per_image_standardization,
+            checkpoint_path,
+            preprocessor=tf.image.per_image_standardization,
+            memory_demanding=memory_demanding,
         )
@@ -123,7 +139,7 @@ class InceptionResnetv2_Casia_CenterLoss_2018(InceptionResnet):
     """

-    def __init__(self):
+    def __init__(self, memory_demanding=False):
         internal_path = pkg_resources.resource_filename(
             __name__, os.path.join("data", "inceptionresnetv2_casia_centerloss_2018"),
         )
@@ -144,7 +160,9 @@ class InceptionResnetv2_Casia_CenterLoss_2018(InceptionResnet):
         )

         super(InceptionResnetv2_Casia_CenterLoss_2018, self).__init__(
-            checkpoint_path, preprocessor=tf.image.per_image_standardization,
+            checkpoint_path,
+            preprocessor=tf.image.per_image_standardization,
+            memory_demanding=memory_demanding,
         )
@@ -156,7 +174,7 @@ class InceptionResnetv1_Casia_CenterLoss_2018(InceptionResnet):
     """

-    def __init__(self):
+    def __init__(self, memory_demanding=False):
         internal_path = pkg_resources.resource_filename(
             __name__, os.path.join("data", "inceptionresnetv1_casia_centerloss_2018"),
         )
@@ -177,7 +195,9 @@ class InceptionResnetv1_Casia_CenterLoss_2018(InceptionResnet):
         )

         super(InceptionResnetv1_Casia_CenterLoss_2018, self).__init__(
-            checkpoint_path, preprocessor=tf.image.per_image_standardization,
+            checkpoint_path,
+            preprocessor=tf.image.per_image_standardization,
+            memory_demanding=memory_demanding,
         )
@@ -189,7 +209,7 @@ class InceptionResnetv1_MsCeleb_CenterLoss_2018(InceptionResnet):
     """

-    def __init__(self):
+    def __init__(self, memory_demanding=False):
         internal_path = pkg_resources.resource_filename(
             __name__, os.path.join("data", "inceptionresnetv1_msceleb_centerloss_2018"),
         )
@@ -211,7 +231,9 @@ class InceptionResnetv1_MsCeleb_CenterLoss_2018(InceptionResnet):
         )

         super(InceptionResnetv1_MsCeleb_CenterLoss_2018, self).__init__(
-            checkpoint_path, preprocessor=tf.image.per_image_standardization,
+            checkpoint_path,
+            preprocessor=tf.image.per_image_standardization,
+            memory_demanding=memory_demanding,
         )
@@ -237,7 +259,7 @@ class FaceNetSanderberg_20170512_110547(InceptionResnet):
     )
     """

-    def __init__(self):
+    def __init__(self, memory_demanding=False):
         internal_path = pkg_resources.resource_filename(
             __name__, os.path.join("data", "facenet_sanderberg_20170512_110547"),
         )
@@ -257,5 +279,8 @@ class FaceNetSanderberg_20170512_110547(InceptionResnet):
         )
         super(FaceNetSanderberg_20170512_110547, self).__init__(
-            checkpoint_path, tf.image.per_image_standardization,
+            checkpoint_path,
+            tf.image.per_image_standardization,
+            memory_demanding=memory_demanding,
         )
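
The heart of the change is the branch at the end of `transform`: in memory-demanding mode every sample is expanded back to a batch of one and forwarded on its own, so peak memory no longer scales with the batch size. A self-contained sketch of that pattern with a stand-in embedding function (names are illustrative, not part of bob.bio.face); note the extra singleton axis in the per-sample output:

import numpy as np


def fake_embed(batch):
    # stand-in for self.inference(X).numpy(): maps (N, H, W, C) -> (N, 128)
    return batch.reshape(batch.shape[0], -1)[:, :128].astype("float32")


def transform(X, memory_demanding=False):
    X = np.asarray(X)
    if memory_demanding:
        # one forward pass per sample; x[None, ...] restores the batch axis
        return np.array([fake_embed(x[None, ...]) for x in X])
    return fake_embed(X)


X = np.zeros((4, 160, 160, 3), dtype="float32")
print(transform(X).shape)                          # (4, 128)
print(transform(X, memory_demanding=True).shape)   # (4, 1, 128)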
@@ -32,6 +32,33 @@ def test_idiap_inceptionv2_msceleb():
     assert output.size == 128, output.shape


+@is_library_available("tensorflow")
+def test_idiap_inceptionv2_msceleb_memory_demanding():
+    from bob.bio.face.embeddings.tf2_inception_resnet import (
+        InceptionResnetv2_MsCeleb_CenterLoss_2018,
+    )
+
+    reference = bob.io.base.load(
+        pkg_resources.resource_filename(
+            "bob.bio.face.test", "data/inception_resnet_v2_rgb.hdf5"
+        )
+    )
+    np.random.seed(10)
+
+    transformer = InceptionResnetv2_MsCeleb_CenterLoss_2018(memory_demanding=True)
+    data = (np.random.rand(3, 160, 160) * 255).astype("uint8")
+    output = transformer.transform([data])[0]
+    assert output.size == 128, output.shape
+
+    # Sample Batch
+    sample = Sample(data)
+    transformer_sample = wrap(["sample"], transformer)
+    output = [s.data for s in transformer_sample.transform([sample])][0]
+
+    np.testing.assert_allclose(output[0], reference.flatten(), rtol=1e-5, atol=1e-4)
+    assert output.size == 128, output.shape
+
+
 @is_library_available("tensorflow")
 def test_idiap_inceptionv2_casia():
     from bob.bio.face.embeddings import InceptionResnetv2_Casia_CenterLoss_2018
...
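
The new test mirrors the existing batch test but with `memory_demanding=True`, and it also exercises the sample interface from bob.pipelines (`Sample`/`wrap`, imported at the top of the test module). A hedged sketch of that sample-based usage outside the test harness:

import numpy as np
from bob.pipelines import Sample, wrap
from bob.bio.face.embeddings.tf2_inception_resnet import (
    InceptionResnetv2_MsCeleb_CenterLoss_2018,
)

transformer = InceptionResnetv2_MsCeleb_CenterLoss_2018(memory_demanding=True)
# wrap() makes transform() accept and return Sample objects instead of raw arrays
transformer = wrap(["sample"], transformer)

data = (np.random.rand(3, 160, 160) * 255).astype("uint8")
embeddings = transformer.transform([Sample(data)])
print(embeddings[0].data.shape)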