diff --git a/bob/bio/face/config/baseline/arcface_insightface.py b/bob/bio/face/config/baseline/arcface_insightface.py index bec8d2a8225056006814e91c5576f7bbf24696a9..06f071c592a841b6b2ca66ad1997dc86cda518be 100644 --- a/bob/bio/face/config/baseline/arcface_insightface.py +++ b/bob/bio/face/config/baseline/arcface_insightface.py @@ -9,6 +9,9 @@ from bob.bio.base.pipelines.vanilla_biometrics import ( if "database" in locals(): annotation_type = database.annotation_type fixed_positions = database.fixed_positions + memory_demanding = ( + database.memory_demanding if hasattr(database, "memory_demanding") else False + ) else: annotation_type = None fixed_positions = None @@ -16,7 +19,7 @@ else: def load(annotation_type, fixed_positions=None): transformer = embedding_transformer_112x112( - ArcFaceInsightFace(), annotation_type, fixed_positions, color_channel="rgb" + ArcFaceInsightFace(memory_demanding=memory_demanding), annotation_type, fixed_positions, color_channel="rgb" ) algorithm = Distance() diff --git a/bob/bio/face/embeddings/mxnet_models.py b/bob/bio/face/embeddings/mxnet_models.py index 7eca24648284343b71cb0550ad1a25286936e39d..d8eb786491356ca1bfc0b6893d895bb49200125a 100644 --- a/bob/bio/face/embeddings/mxnet_models.py +++ b/bob/bio/face/embeddings/mxnet_models.py @@ -21,10 +21,11 @@ class ArcFaceInsightFace(TransformerMixin, BaseEstimator): """ - def __init__(self, use_gpu=False, **kwargs): + def __init__(self, use_gpu=False, memory_demanding=False, **kwargs): super().__init__(**kwargs) self.model = None self.use_gpu = use_gpu + self.memory_demanding = memory_demanding internal_path = pkg_resources.resource_filename( __name__, os.path.join("data", "arcface_insightface"), @@ -76,10 +77,17 @@ class ArcFaceInsightFace(TransformerMixin, BaseEstimator): self.load_model() X = check_array(X, allow_nd=True) - X = mx.nd.array(X) - db = mx.io.DataBatch(data=(X,)) - self.model.forward(db, is_train=False) - return self.model.get_outputs()[0].asnumpy() + + def _transform(X): + X = mx.nd.array(X) + db = mx.io.DataBatch(data=(X,)) + self.model.forward(db, is_train=False) + return self.model.get_outputs()[0].asnumpy() + + if self.memory_demanding: + return np.array([_transform(x[None, ...]) for x in X]) + else: + return _transform(X) def __getstate__(self): # Handling unpicklable objects