diff --git a/bob/bio/face/config/baseline/arcface_insightface.py b/bob/bio/face/config/baseline/arcface_insightface.py
index bec8d2a8225056006814e91c5576f7bbf24696a9..06f071c592a841b6b2ca66ad1997dc86cda518be 100644
--- a/bob/bio/face/config/baseline/arcface_insightface.py
+++ b/bob/bio/face/config/baseline/arcface_insightface.py
@@ -9,6 +9,9 @@ from bob.bio.base.pipelines.vanilla_biometrics import (
 if "database" in locals():
     annotation_type = database.annotation_type
     fixed_positions = database.fixed_positions
+    memory_demanding = (
+        database.memory_demanding if hasattr(database, "memory_demanding") else False
+    )
 else:
     annotation_type = None
     fixed_positions = None
@@ -16,7 +19,7 @@ else:
 
 def load(annotation_type, fixed_positions=None):
     transformer = embedding_transformer_112x112(
-        ArcFaceInsightFace(), annotation_type, fixed_positions, color_channel="rgb"
+        ArcFaceInsightFace(memory_demanding=memory_demanding), annotation_type, fixed_positions, color_channel="rgb"
     )
 
     algorithm = Distance()
diff --git a/bob/bio/face/embeddings/mxnet_models.py b/bob/bio/face/embeddings/mxnet_models.py
index 7eca24648284343b71cb0550ad1a25286936e39d..d8eb786491356ca1bfc0b6893d895bb49200125a 100644
--- a/bob/bio/face/embeddings/mxnet_models.py
+++ b/bob/bio/face/embeddings/mxnet_models.py
@@ -21,10 +21,11 @@ class ArcFaceInsightFace(TransformerMixin, BaseEstimator):
 
     """
 
-    def __init__(self, use_gpu=False, **kwargs):
+    def __init__(self, use_gpu=False, memory_demanding=False, **kwargs):
         super().__init__(**kwargs)
         self.model = None
         self.use_gpu = use_gpu
+        self.memory_demanding = memory_demanding
 
         internal_path = pkg_resources.resource_filename(
             __name__, os.path.join("data", "arcface_insightface"),
@@ -76,10 +77,17 @@ class ArcFaceInsightFace(TransformerMixin, BaseEstimator):
             self.load_model()
 
         X = check_array(X, allow_nd=True)
-        X = mx.nd.array(X)
-        db = mx.io.DataBatch(data=(X,))
-        self.model.forward(db, is_train=False)
-        return self.model.get_outputs()[0].asnumpy()
+
+        def _transform(X):
+            X = mx.nd.array(X)
+            db = mx.io.DataBatch(data=(X,))
+            self.model.forward(db, is_train=False)
+            return self.model.get_outputs()[0].asnumpy()
+
+        if self.memory_demanding:
+            return np.array([_transform(x[None, ...]) for x in X])
+        else:
+            return _transform(X)
 
     def __getstate__(self):
         # Handling unpicklable objects