From b894b8299d42af643e0be75d68b30d7d30450163 Mon Sep 17 00:00:00 2001
From: Tiago Freitas Pereira <tiagofrepereira@gmail.com>
Date: Wed, 9 Jun 2021 13:53:22 +0200
Subject: [PATCH] Patched pytorch models

---
 bob/bio/face/embeddings/pytorch.py | 46 +++++++++++++++++++++---------
 1 file changed, 32 insertions(+), 14 deletions(-)

diff --git a/bob/bio/face/embeddings/pytorch.py b/bob/bio/face/embeddings/pytorch.py
index 601ac06a..2007865d 100644
--- a/bob/bio/face/embeddings/pytorch.py
+++ b/bob/bio/face/embeddings/pytorch.py
@@ -44,13 +44,16 @@ class PyTorchModel(TransformerMixin, BaseEstimator):
         checkpoint_path=None,
         config=None,
         preprocessor=lambda x: x / 255,
+        memory_demanding=False,
         **kwargs
     ):
+
         super().__init__(**kwargs)
         self.checkpoint_path = checkpoint_path
         self.config = config
         self.model = None
         self.preprocessor = preprocessor
+        self.memory_demanding = memory_demanding
 
     def transform(self, X):
         """__call__(image) -> feature
@@ -74,7 +77,14 @@ class PyTorchModel(TransformerMixin, BaseEstimator):
         X = check_array(X, allow_nd=True)
         X = torch.Tensor(X)
         X = self.preprocessor(X)
-        return self.model(X).detach().numpy()
+
+        def _transform(X):
+            return self.model(X).detach().numpy()
+
+        if self.memory_demanding:
+            return np.array([_transform(x[None, ...]) for x in X])
+        else:
+            return _transform(X)
 
     def __getstate__(self):
         # Handling unpicklable objects
@@ -93,7 +103,7 @@ class AFFFE_2021(PyTorchModel):
 
     """
 
-    def __init__(self):
+    def __init__(self, memory_demanding=False):
 
         urls = [
             "https://www.idiap.ch/software/bob/data/bob/bob.bio.face/master/pytorch/AFFFE-42a53f19.tar.gz",
@@ -111,7 +121,9 @@ class AFFFE_2021(PyTorchModel):
         config = os.path.join(path, "AFFFE.py")
         checkpoint_path = os.path.join(path, "AFFFE.pth")
 
-        super(AFFFE_2021, self).__init__(checkpoint_path, config)
+        super(AFFFE_2021, self).__init__(
+            checkpoint_path, config, memory_demanding=memory_demanding
+        )
 
     def _load_model(self):
 
@@ -148,7 +160,7 @@ class IResnet34(PyTorchModel):
     ArcFace model (RESNET 34) from Insightface ported to pytorch
     """
 
-    def __init__(self):
+    def __init__(self, memory_demanding=False):
 
         urls = [
             "https://www.idiap.ch/software/bob/data/bob/bob.bio.face/master/pytorch/iresnet-91a5de61.tar.gz",
@@ -161,7 +173,9 @@ class IResnet34(PyTorchModel):
         config = os.path.join(path, "iresnet.py")
         checkpoint_path = os.path.join(path, "iresnet34-5b0d0e90.pth")
 
-        super(IResnet34, self).__init__(checkpoint_path, config)
+        super(IResnet34, self).__init__(
+            checkpoint_path, config, memory_demanding=memory_demanding
+        )
 
     def _load_model(self):
 
@@ -174,7 +188,7 @@ class IResnet50(PyTorchModel):
     ArcFace model (RESNET 50) from Insightface ported to pytorch
     """
 
-    def __init__(self):
+    def __init__(self, memory_demanding=False):
 
         filename = _get_iresnet_file()
 
@@ -182,7 +196,9 @@ class IResnet50(PyTorchModel):
         config = os.path.join(path, "iresnet.py")
         checkpoint_path = os.path.join(path, "iresnet50-7f187506.pth")
 
-        super(IResnet50, self).__init__(checkpoint_path, config)
+        super(IResnet50, self).__init__(
+            checkpoint_path, config, memory_demanding=memory_demanding
+        )
 
     def _load_model(self):
 
@@ -195,7 +211,7 @@ class IResnet100(PyTorchModel):
     ArcFace model (RESNET 100) from Insightface ported to pytorch
     """
 
-    def __init__(self):
+    def __init__(self, memory_demanding=False):
 
         filename = _get_iresnet_file()
 
@@ -203,7 +219,9 @@ class IResnet100(PyTorchModel):
         config = os.path.join(path, "iresnet.py")
         checkpoint_path = os.path.join(path, "iresnet100-73e07ba7.pth")
 
-        super(IResnet100, self).__init__(checkpoint_path, config)
+        super(IResnet100, self).__init__(
+            checkpoint_path, config, memory_demanding=memory_demanding
+        )
 
     def _load_model(self):
 
@@ -261,7 +279,7 @@ def iresnet34(annotation_type, fixed_positions=None, memory_demanding=False):
     """
 
     return iresnet_template(
-        embedding=IResnet34(),
+        embedding=IResnet34(memory_demanding=memory_demanding),
         annotation_type=annotation_type,
         fixed_positions=fixed_positions,
     )
@@ -291,7 +309,7 @@ def iresnet50(annotation_type, fixed_positions=None, memory_demanding=False):
     """
 
     return iresnet_template(
-        embedding=IResnet50(),
+        embedding=IResnet50(memory_demanding=memory_demanding),
         annotation_type=annotation_type,
         fixed_positions=fixed_positions,
     )
@@ -321,13 +339,13 @@ def iresnet100(annotation_type, fixed_positions=None, memory_demanding=False):
     """
 
     return iresnet_template(
-        embedding=IResnet100(),
+        embedding=IResnet100(memory_demanding=memory_demanding),
         annotation_type=annotation_type,
         fixed_positions=fixed_positions,
     )
 
 
-def afffe_baseline(annotation_type, fixed_positions=None):
+def afffe_baseline(annotation_type, fixed_positions=None, memory_demanding=False):
     """
     Get the AFFFE pipeline which will crop the face :math:`224 \times 224`
     use the :py:class:`AFFFE_2021`
@@ -353,7 +371,7 @@ def afffe_baseline(annotation_type, fixed_positions=None):
 
     transformer = embedding_transformer(
         cropped_image_size=cropped_image_size,
-        embedding=AFFFE_2021(),
+        embedding=AFFFE_2021(memory_demanding=memory_demanding),
         cropped_positions=cropped_positions,
         fixed_positions=fixed_positions,
         color_channel="rgb",
-- 
GitLab