From 99113f1f3874905e50d262a6c2df49a451b43728 Mon Sep 17 00:00:00 2001
From: Tiago Freitas Pereira <tiagofrepereira@gmail.com>
Date: Tue, 8 Jun 2021 17:53:22 +0200
Subject: [PATCH] Added some arcface pytorch models

---
 bob/bio/face/config/baseline/iresnet100.py |  15 ++
 bob/bio/face/config/baseline/iresnet34.py  |  15 ++
 bob/bio/face/config/baseline/iresnet50.py  |  15 ++
 bob/bio/face/embeddings/pytorch.py         | 200 ++++++++++++++++++++-
 bob/bio/face/test/test_baselines.py        |  27 ++-
 setup.py                                   |   6 +
 6 files changed, 275 insertions(+), 3 deletions(-)
 create mode 100644 bob/bio/face/config/baseline/iresnet100.py
 create mode 100644 bob/bio/face/config/baseline/iresnet34.py
 create mode 100644 bob/bio/face/config/baseline/iresnet50.py

diff --git a/bob/bio/face/config/baseline/iresnet100.py b/bob/bio/face/config/baseline/iresnet100.py
new file mode 100644
index 00000000..6148d162
--- /dev/null
+++ b/bob/bio/face/config/baseline/iresnet100.py
@@ -0,0 +1,15 @@
+from bob.bio.face.embeddings.pytorch import iresnet100
+from bob.bio.face.utils import lookup_config_from_database
+
+
+annotation_type, fixed_positions, _ = lookup_config_from_database(
+    locals().get("database")
+)
+
+
+def load(annotation_type, fixed_positions=None):
+    return iresnet100(annotation_type, fixed_positions)
+
+
+pipeline = load(annotation_type, fixed_positions)
+
diff --git a/bob/bio/face/config/baseline/iresnet34.py b/bob/bio/face/config/baseline/iresnet34.py
new file mode 100644
index 00000000..f66b2b88
--- /dev/null
+++ b/bob/bio/face/config/baseline/iresnet34.py
@@ -0,0 +1,15 @@
+from bob.bio.face.embeddings.pytorch import iresnet34
+from bob.bio.face.utils import lookup_config_from_database
+
+
+annotation_type, fixed_positions, _ = lookup_config_from_database(
+    locals().get("database")
+)
+
+
+def load(annotation_type, fixed_positions=None):
+    return iresnet34(annotation_type, fixed_positions)
+
+
+pipeline = load(annotation_type, fixed_positions)
+
diff --git a/bob/bio/face/config/baseline/iresnet50.py b/bob/bio/face/config/baseline/iresnet50.py
new file mode 100644
index 00000000..ae256242
--- /dev/null
+++ b/bob/bio/face/config/baseline/iresnet50.py
@@ -0,0 +1,15 @@
+from bob.bio.face.embeddings.pytorch import iresnet50
+from bob.bio.face.utils import lookup_config_from_database
+
+
+annotation_type, fixed_positions, _ = lookup_config_from_database(
+    locals().get("database")
+)
+
+
+def load(annotation_type, fixed_positions=None):
+    return iresnet50(annotation_type, fixed_positions)
+
+
+pipeline = load(annotation_type, fixed_positions)
+
diff --git a/bob/bio/face/embeddings/pytorch.py b/bob/bio/face/embeddings/pytorch.py
index 3e016b4a..601ac06a 100644
--- a/bob/bio/face/embeddings/pytorch.py
+++ b/bob/bio/face/embeddings/pytorch.py
@@ -74,7 +74,6 @@ class PyTorchModel(TransformerMixin, BaseEstimator):
         X = check_array(X, allow_nd=True)
         X = torch.Tensor(X)
         X = self.preprocessor(X)
-
         return self.model(X).detach().numpy()
 
     def __getstate__(self):
@@ -129,6 +128,205 @@ class AFFFE_2021(PyTorchModel):
         self.model = network
 
 
+def _get_iresnet_file():
+    urls = [
+        "https://www.idiap.ch/software/bob/data/bob/bob.bio.face/master/pytorch/iresnet-91a5de61.tar.gz",
+        "http://www.idiap.ch/software/bob/data/bob/bob.bio.face/master/pytorch/iresnet-91a5de61.tar.gz",
+    ]
+
+    return get_file(
+        "iresnet-91a5de61.tar.gz",
+        urls,
+        cache_subdir="data/pytorch/iresnet-91a5de61/",
+        file_hash="3976c0a539811d888ef5b6217e5de425",
+        extract=True,
+    )
+
+
+class IResnet34(PyTorchModel):
+    """
+    ArcFace model (RESNET 34) from Insightface ported to pytorch
+    """
+
+    def __init__(self):
+
+        urls = [
+            "https://www.idiap.ch/software/bob/data/bob/bob.bio.face/master/pytorch/iresnet-91a5de61.tar.gz",
+            "http://www.idiap.ch/software/bob/data/bob/bob.bio.face/master/pytorch/iresnet-91a5de61.tar.gz",
+        ]
+
+        filename = _get_iresnet_file()
+
+        path = os.path.dirname(filename)
+        config = os.path.join(path, "iresnet.py")
+        checkpoint_path = os.path.join(path, "iresnet34-5b0d0e90.pth")
+
+        super(IResnet34, self).__init__(checkpoint_path, config)
+
+    def _load_model(self):
+
+        model = imp.load_source("module", self.config).iresnet34(self.checkpoint_path)
+        self.model = model
+
+
+class IResnet50(PyTorchModel):
+    """
+    ArcFace model (RESNET 50) from Insightface ported to pytorch
+    """
+
+    def __init__(self):
+
+        filename = _get_iresnet_file()
+
+        path = os.path.dirname(filename)
+        config = os.path.join(path, "iresnet.py")
+        checkpoint_path = os.path.join(path, "iresnet50-7f187506.pth")
+
+        super(IResnet50, self).__init__(checkpoint_path, config)
+
+    def _load_model(self):
+
+        model = imp.load_source("module", self.config).iresnet50(self.checkpoint_path)
+        self.model = model
+
+
+class IResnet100(PyTorchModel):
+    """
+    ArcFace model (RESNET 100) from Insightface ported to pytorch
+    """
+
+    def __init__(self):
+
+        filename = _get_iresnet_file()
+
+        path = os.path.dirname(filename)
+        config = os.path.join(path, "iresnet.py")
+        checkpoint_path = os.path.join(path, "iresnet100-73e07ba7.pth")
+
+        super(IResnet100, self).__init__(checkpoint_path, config)
+
+    def _load_model(self):
+
+        model = imp.load_source("module", self.config).iresnet100(self.checkpoint_path)
+        self.model = model
+
+
+def iresnet_template(embedding, annotation_type, fixed_positions=None):
+    # DEFINE CROPPING
+    cropped_image_size = (112, 112)
+    if annotation_type == "eyes-center":
+        # Hard coding eye positions for backward consistency
+        cropped_positions = {
+            "leye": (55, 81),
+            "reye": (55, 42),
+        }
+    else:
+        cropped_positions = dnn_default_cropping(cropped_image_size, annotation_type)
+
+    transformer = embedding_transformer(
+        cropped_image_size=cropped_image_size,
+        embedding=embedding,
+        cropped_positions=cropped_positions,
+        fixed_positions=fixed_positions,
+        color_channel="rgb",
+        annotator="mtcnn",
+    )
+
+    algorithm = Distance()
+
+    return VanillaBiometricsPipeline(transformer, algorithm)
+
+
+def iresnet34(annotation_type, fixed_positions=None, memory_demanding=False):
+    r"""
+    Get the Resnet34 pipeline which will crop the face :math:`112 \times 112` and
+    use the :py:class:`IResnet34` to extract the features
+
+
+    code referenced from https://raw.githubusercontent.com/nizhib/pytorch-insightface/master/insightface/iresnet.py
+    https://github.com/nizhib/pytorch-insightface
+
+
+    Parameters
+    ----------
+
+      annotation_type: str
+         Type of the annotations (e.g. ``eyes-center``)
+
+      fixed_positions: dict
+         Set this if the face images are registered to a fixed position in the image
+
+      memory_demanding: bool (note: accepted for API consistency, but currently not forwarded to the pipeline)
+
+    """
+
+    return iresnet_template(
+        embedding=IResnet34(),
+        annotation_type=annotation_type,
+        fixed_positions=fixed_positions,
+    )
+
+
+def iresnet50(annotation_type, fixed_positions=None, memory_demanding=False):
+    r"""
+    Get the Resnet50 pipeline which will crop the face :math:`112 \times 112` and
+    use the :py:class:`IResnet50` to extract the features
+
+
+    code referenced from https://raw.githubusercontent.com/nizhib/pytorch-insightface/master/insightface/iresnet.py
+    https://github.com/nizhib/pytorch-insightface
+
+
+    Parameters
+    ----------
+
+      annotation_type: str
+         Type of the annotations (e.g. ``eyes-center``)
+
+      fixed_positions: dict
+         Set this if the face images are registered to a fixed position in the image
+
+      memory_demanding: bool (note: accepted for API consistency, but currently not forwarded to the pipeline)
+
+    """
+
+    return iresnet_template(
+        embedding=IResnet50(),
+        annotation_type=annotation_type,
+        fixed_positions=fixed_positions,
+    )
+
+
+def iresnet100(annotation_type, fixed_positions=None, memory_demanding=False):
+    r"""
+    Get the Resnet100 pipeline which will crop the face :math:`112 \times 112` and
+    use the :py:class:`IResnet100` to extract the features
+
+
+    code referenced from https://raw.githubusercontent.com/nizhib/pytorch-insightface/master/insightface/iresnet.py
+    https://github.com/nizhib/pytorch-insightface
+
+
+    Parameters
+    ----------
+
+      annotation_type: str
+         Type of the annotations (e.g. ``eyes-center``)
+
+      fixed_positions: dict
+         Set this if the face images are registered to a fixed position in the image
+
+      memory_demanding: bool (note: accepted for API consistency, but currently not forwarded to the pipeline)
+
+    """
+
+    return iresnet_template(
+        embedding=IResnet100(),
+        annotation_type=annotation_type,
+        fixed_positions=fixed_positions,
+    )
+
+
 def afffe_baseline(annotation_type, fixed_positions=None):
     """
     Get the AFFFE pipeline which will crop the face :math:`224 \times 224`
diff --git a/bob/bio/face/test/test_baselines.py b/bob/bio/face/test/test_baselines.py
index aa2fdf98..fca4678a 100644
--- a/bob/bio/face/test/test_baselines.py
+++ b/bob/bio/face/test/test_baselines.py
@@ -66,7 +66,6 @@ def run_baseline(baseline, samples_for_training=[], target_scores=None):
     # Regular pipeline
     pipeline = load_resource(baseline, "pipeline")
     scores = pipeline(samples_for_training, biometric_references, probes)
-
     assert len(scores) == 1
     assert len(scores[0]) == 1
 
@@ -81,7 +80,7 @@ def run_baseline(baseline, samples_for_training=[], target_scores=None):
         assert len(checkpoint_scores[0]) == 1
 
         if target_scores is not None:
-            assert np.allclose(target_scores, scores[0][0].data, atol=10e-3, rtol=10e-3)
+            assert np.allclose(target_scores, scores[0][0].data, atol=10e-5, rtol=10e-5)
 
         assert np.isclose(scores[0][0].data, checkpoint_scores[0][0].data)
 
@@ -175,6 +174,30 @@ def test_afffe():
     )
 
 
+@pytest.mark.slow
+@is_library_available("torch")
+def test_iresnet34():
+    run_baseline(
+        "iresnet34", target_scores=-0.0003085132478504171,
+    )
+
+
+@pytest.mark.slow
+@is_library_available("torch")
+def test_iresnet50():
+    run_baseline(
+        "iresnet50", target_scores=-0.0013965432856760662,
+    )
+
+
+@pytest.mark.slow
+@is_library_available("torch")
+def test_iresnet100():
+    run_baseline(
+        "iresnet100", target_scores=-0.0002386926047015514,
+    )
+
+
 @pytest.mark.slow
 @is_library_available("cv2")
 def test_vgg16_oxford():
diff --git a/setup.py b/setup.py
index e877cae0..12eb3529 100644
--- a/setup.py
+++ b/setup.py
@@ -139,6 +139,9 @@ setup(
             "mxnet-tinyface = bob.bio.face.config.baseline.mxnet_tinyface:pipeline",
             "afffe = bob.bio.face.config.baseline.afffe:pipeline",
             "vgg16-oxford = bob.bio.face.config.baseline.vgg16_oxford:pipeline",
+            "iresnet34 = bob.bio.face.config.baseline.iresnet34:pipeline",
+            "iresnet50 = bob.bio.face.config.baseline.iresnet50:pipeline",
+            "iresnet100 = bob.bio.face.config.baseline.iresnet100:pipeline",
         ],
         "bob.bio.config": [
             "facenet-sanderberg = bob.bio.face.config.baseline.facenet_sanderberg",
@@ -175,6 +178,9 @@ setup(
             "resnet50-msceleb-arcface-2021 = bob.bio.face.config.baseline.resnet50_msceleb_arcface_2021",
             "resnet50-vgg2-arcface-2021 = bob.bio.face.config.baseline.resnet50_vgg2_arcface_2021",
             "mobilenetv2-msceleb-arcface-2021 = bob.bio.face.config.baseline.mobilenetv2_msceleb_arcface_2021",
+            "iresnet34 = bob.bio.face.config.baseline.iresnet34",
+            "iresnet50 = bob.bio.face.config.baseline.iresnet50",
+            "iresnet100 = bob.bio.face.config.baseline.iresnet100",
         ],
         "bob.bio.cli": [
             "display-face-annotations          = bob.bio.face.script.display_face_annotations:display_face_annotations",
-- 
GitLab