diff --git a/bob/bio/face/annotator/__init__.py b/bob/bio/face/annotator/__init__.py
index 600daa739fc6ccb8e34cdc3c2f4e23516d6e2093..8e778f616503380d8d88c266d036e08d1355a96b 100644
--- a/bob/bio/face/annotator/__init__.py
+++ b/bob/bio/face/annotator/__init__.py
@@ -58,6 +58,7 @@ from .Base import Base
 from .bobipfacedetect import BobIpFacedetect
 from .bobipflandmark import BobIpFlandmark
 from .bobipmtcnn import BobIpMTCNN
+from .bobiptinyface import BobIpTinyface
 
 
 # gets sphinx autodoc done right - don't remove it
@@ -84,6 +85,7 @@ __appropriate__(
     BobIpFacedetect,
     BobIpFlandmark,
     BobIpMTCNN,
+    BobIpTinyface,
 )
 
 __all__ = [_ for _ in dir() if not _.startswith('_')]
diff --git a/bob/bio/face/annotator/bobiptinyface.py b/bob/bio/face/annotator/bobiptinyface.py
new file mode 100644
index 0000000000000000000000000000000000000000..595f784986dd9aa2ad7444980c3aa863429ecfc5
--- /dev/null
+++ b/bob/bio/face/annotator/bobiptinyface.py
@@ -0,0 +1,35 @@
+import bob.ip.facedetect.tinyface
+
+from . import Base
+
+
+class BobIpTinyface(Base):
+    """Annotator using the tinyface detector in :py:mod:`bob.ip.facedetect`."""
+
+    def __init__(self, **kwargs):
+        super(BobIpTinyface, self).__init__(**kwargs)
+        self.tinyface = bob.ip.facedetect.tinyface.TinyFacesDetector(prob_thresh=0.5)
+
+    def annotate(self, image, **kwargs):
+        """Annotates an image using tinyface.
+
+        Parameters
+        ----------
+        image : numpy.ndarray
+            An RGB image in Bob format.
+        **kwargs
+            Ignored.
+
+        Returns
+        -------
+        dict
+            Annotations with ``topleft`` and ``bottomright`` keys, or ``None``
+            if no face was detected.
+        """
+        annotations = self.tinyface.detect(image)
+
+        if annotations is not None:
+            r = annotations[0]
+            return {"topleft": (r[0], r[1]), "bottomright": (r[2], r[3])}
+        else:
+            return None
diff --git a/bob/bio/face/config/annotator/tinyface.py b/bob/bio/face/config/annotator/tinyface.py
new file mode 100644
index 0000000000000000000000000000000000000000..2274bfeff38e836c34e34cb1846915b58a6d769e
--- /dev/null
+++ b/bob/bio/face/config/annotator/tinyface.py
@@ -0,0 +1,8 @@
+from bob.bio.face.annotator import BobIpTinyface
+
+annotator = BobIpTinyface()
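+
+# A minimal usage sketch (``image`` is hypothetical: an RGB image in Bob
+# format):
+#
+#   annotations = annotator.annotate(image)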
diff --git a/bob/bio/face/config/baseline/mxnet_pipe.py b/bob/bio/face/config/baseline/mxnet_pipe.py
new file mode 100644
index 0000000000000000000000000000000000000000..b458dc70f8bd61c4d75cdee26e2227cc5737cc63
--- /dev/null
+++ b/bob/bio/face/config/baseline/mxnet_pipe.py
@@ -0,0 +1,74 @@
+import scipy.spatial
+from sklearn.pipeline import make_pipeline
+
+from bob.bio.face.extractor import mxnet_model
+from bob.bio.face.preprocessor import FaceCrop
+from bob.bio.base.pipelines.vanilla_biometrics import (
+    Distance,
+    VanillaBiometricsPipeline,
+)
+from bob.pipelines import wrap
+
+
+memory_demanding = False
+if "database" in locals():
+    annotation_type = database.annotation_type
+    fixed_positions = database.fixed_positions
+    memory_demanding = (
+        database.memory_demanding if hasattr(database, "memory_demanding") else False
+    )
+
+else:
+    annotation_type = None
+    fixed_positions = None
+
+# Preprocessor: crop the face to 112x112 using the eye annotations
+cropped_positions = {"leye": (49, 72), "reye": (49, 38)}
+
+preprocessor_transformer = FaceCrop(
+    cropped_image_size=(112, 112),
+    cropped_positions=cropped_positions,
+    color_channel="rgb",
+    fixed_positions=fixed_positions,
+)
+
+transform_extra_arguments = (
+    None
+    if (cropped_positions is None or fixed_positions is not None)
+    else (("annotations", "annotations"),)
+)
+
+# Extractor
+extractor_transformer = mxnet_model()
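+
+# mxnet_model reads the model and weights from the paths stored in the
+# ``bob.extractor_model.mxnet`` and ``bob.extractor_weights.mxnet`` config
+# keys; set them beforehand, e.g.:
+#
+#   $ bob config set bob.extractor_model.mxnet /PATH/TO/MODEL/
+#   $ bob config set bob.extractor_weights.mxnet /PATH/TO/WEIGHTS/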
+
+# Algorithm
+algorithm = Distance(distance_function=scipy.spatial.distance.cosine)
+
+
+# Chain the Transformers together
+transformer = make_pipeline(
+    wrap(["sample"], preprocessor_transformer,transform_extra_arguments=transform_extra_arguments),
+    wrap(["sample"], extractor_transformer)
+    # Add more transformers here if needed
+)
+
+
+# Assemble the vanilla biometrics pipeline
+pipeline = VanillaBiometricsPipeline(transformer, algorithm)
+transformer = pipeline.transformer
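+
+# A usage sketch (assuming this file is registered as the ``mxnet-pipe``
+# pipeline entry point): run it through the vanilla-biometrics CLI, e.g.
+#
+#   $ bob bio pipelines vanilla-biometrics -d <database> -p mxnet-pipe -o <folder>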
diff --git a/bob/bio/face/config/baseline/opencv_pipe.py b/bob/bio/face/config/baseline/opencv_pipe.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7f8cd8cf005f81900f56ee354086065ede667e4
--- /dev/null
+++ b/bob/bio/face/config/baseline/opencv_pipe.py
@@ -0,0 +1,71 @@
+import scipy.spatial
+from sklearn.pipeline import make_pipeline
+
+from bob.bio.face.extractor import opencv_model
+from bob.bio.face.preprocessor import FaceCrop
+from bob.bio.base.pipelines.vanilla_biometrics import (
+    Distance,
+    VanillaBiometricsPipeline,
+)
+from bob.pipelines import wrap
+
+
+memory_demanding = False
+if "database" in locals():
+    annotation_type = database.annotation_type
+    fixed_positions = database.fixed_positions
+    memory_demanding = (
+        database.memory_demanding if hasattr(database, "memory_demanding") else False
+    )
+
+else:
+    annotation_type = None
+    fixed_positions = None
+
+
+# Preprocessor: crop the face to 224x224 using the eye annotations
+cropped_positions = {"leye": (98, 144), "reye": (98, 76)}
+
+preprocessor_transformer = FaceCrop(
+    cropped_image_size=(224, 224),
+    cropped_positions=cropped_positions,
+    color_channel="rgb",
+    fixed_positions=fixed_positions,
+)
+
+transform_extra_arguments = (
+    None
+    if (cropped_positions is None or fixed_positions is not None)
+    else (("annotations", "annotations"),)
+)
+
+
+# Extractor
+extractor_transformer = opencv_model()
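+
+# opencv_model reads the Caffe model from the paths stored in the
+# ``bob.extractor_model.opencv`` and ``bob.extractor_weights.opencv`` config
+# keys; set them beforehand, e.g.:
+#
+#   $ bob config set bob.extractor_model.opencv /PATH/TO/MODEL/
+#   $ bob config set bob.extractor_weights.opencv /PATH/TO/WEIGHTS/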
+
+# Algorithm
+algorithm = Distance(distance_function=scipy.spatial.distance.cosine)
+
+
+# Chain the Transformers together
+transformer = make_pipeline(
+    wrap(["sample"], preprocessor_transformer,transform_extra_arguments=transform_extra_arguments),
+    wrap(["sample"], extractor_transformer)
+    # Add more transformers here if needed
+)
+
+
+# Assemble the vanilla biometrics pipeline
+pipeline = VanillaBiometricsPipeline(transformer, algorithm)
+transformer = pipeline.transformer
diff --git a/bob/bio/face/config/baseline/pytorch_pipe_v1.py b/bob/bio/face/config/baseline/pytorch_pipe_v1.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b03a70dd1934073e3541f423d649c4c492af138
--- /dev/null
+++ b/bob/bio/face/config/baseline/pytorch_pipe_v1.py
@@ -0,0 +1,69 @@
+import scipy.spatial
+from sklearn.pipeline import make_pipeline
+
+from bob.bio.face.extractor import pytorch_loaded_model
+from bob.bio.face.preprocessor import FaceCrop
+from bob.bio.base.pipelines.vanilla_biometrics import (
+    Distance,
+    VanillaBiometricsPipeline,
+)
+from bob.pipelines import wrap
+
+memory_demanding = False
+if "database" in locals():
+    annotation_type = database.annotation_type
+    fixed_positions = database.fixed_positions
+    memory_demanding = (
+        database.memory_demanding if hasattr(database, "memory_demanding") else False
+    )
+
+else:
+    annotation_type = None
+    fixed_positions = None
+
+
+# Preprocessor: crop the face to 224x224 using the eye annotations
+cropped_positions = {"leye": (49, 72), "reye": (49, 38)}
+
+preprocessor_transformer = FaceCrop(
+    cropped_image_size=(224, 224),
+    cropped_positions=cropped_positions,
+    color_channel="rgb",
+    fixed_positions=fixed_positions,
+)
+
+transform_extra_arguments = (
+    None
+    if (cropped_positions is None or fixed_positions is not None)
+    else (("annotations", "annotations"),)
+)
+
+# Extractor
+extractor_transformer = pytorch_loaded_model()
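+
+# pytorch_loaded_model reads the model definition and weights from the paths
+# stored in the ``bob.extractor_model.pytorch`` and
+# ``bob.extractor_weights.pytorch`` config keys; set them beforehand, e.g.:
+#
+#   $ bob config set bob.extractor_model.pytorch /PATH/TO/MODEL/
+#   $ bob config set bob.extractor_weights.pytorch /PATH/TO/WEIGHTS/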
+
+# Algorithm
+algorithm = Distance(distance_function=scipy.spatial.distance.cosine)
+
+
+# Chain the Transformers together
+transformer = make_pipeline(
+    wrap(["sample"], preprocessor_transformer,transform_extra_arguments=transform_extra_arguments),
+    wrap(["sample"], extractor_transformer)
+    # Add more transformers here if needed
+)
+
+
+# Assemble the vanilla biometrics pipeline
+pipeline = VanillaBiometricsPipeline(transformer, algorithm)
+transformer = pipeline.transformer
diff --git a/bob/bio/face/config/baseline/pytorch_pipe_v2.py b/bob/bio/face/config/baseline/pytorch_pipe_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a01f46784eac8f0b890efde0a291b505e42f092
--- /dev/null
+++ b/bob/bio/face/config/baseline/pytorch_pipe_v2.py
@@ -0,0 +1,70 @@
+import scipy.spatial
+from facenet_pytorch import InceptionResnetV1
+from sklearn.pipeline import make_pipeline
+
+from bob.bio.face.extractor import pytorch_library_model
+from bob.bio.face.preprocessor import FaceCrop
+from bob.bio.base.pipelines.vanilla_biometrics import (
+    Distance,
+    VanillaBiometricsPipeline,
+)
+from bob.pipelines import wrap
+
+
+memory_demanding = False
+if "database" in locals():
+    annotation_type = database.annotation_type
+    fixed_positions = database.fixed_positions
+    memory_demanding = (
+        database.memory_demanding if hasattr(database, "memory_demanding") else False
+    )
+
+else:
+    annotation_type = None
+    fixed_positions = None
+
+
+# Preprocessor: crop the face to 224x224 using the eye annotations
+cropped_positions = {"leye": (49, 72), "reye": (49, 38)}
+
+preprocessor_transformer = FaceCrop(
+    cropped_image_size=(224, 224),
+    cropped_positions=cropped_positions,
+    color_channel="rgb",
+    fixed_positions=fixed_positions,
+)
+
+transform_extra_arguments = (
+    None
+    if (cropped_positions is None or fixed_positions is not None)
+    else (("annotations", "annotations"),)
+)
+
+# Extractor: InceptionResnetV1 pretrained on VGGFace2, via facenet_pytorch
+model = InceptionResnetV1(pretrained="vggface2").eval()
+extractor_transformer = pytorch_library_model(model=model)
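+
+# Any pretrained facenet_pytorch variant can be swapped in here, e.g.
+# (a sketch):
+#
+#   model = InceptionResnetV1(pretrained="casia-webface").eval()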
+
+# Algorithm
+algorithm = Distance(distance_function=scipy.spatial.distance.cosine)
+
+
+# Chain the Transformers together
+transformer = make_pipeline(
+    wrap(["sample"], preprocessor_transformer,transform_extra_arguments=transform_extra_arguments),
+    wrap(["sample"], extractor_transformer)
+    # Add more transformers here if needed
+)
+
+
+# Assemble the vanilla biometrics pipeline
+pipeline = VanillaBiometricsPipeline(transformer, algorithm)
+transformer = pipeline.transformer
diff --git a/bob/bio/face/config/baseline/tf_pipe.py b/bob/bio/face/config/baseline/tf_pipe.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e0a8ae2af52fcc85ac9e1a49e8bc3cf0cd2772a
--- /dev/null
+++ b/bob/bio/face/config/baseline/tf_pipe.py
@@ -0,0 +1,69 @@
+import scipy.spatial
+from sklearn.pipeline import make_pipeline
+
+from bob.bio.face.extractor import tf_model
+from bob.bio.face.preprocessor import FaceCrop
+from bob.bio.base.pipelines.vanilla_biometrics import (
+    Distance,
+    VanillaBiometricsPipeline,
+)
+from bob.pipelines import wrap
+
+
+memory_demanding = False
+if "database" in locals():
+    annotation_type = database.annotation_type
+    fixed_positions = database.fixed_positions
+    memory_demanding = (
+        database.memory_demanding if hasattr(database, "memory_demanding") else False
+    )
+
+else:
+    annotation_type = None
+    fixed_positions = None
+
+
+# Preprocessor: crop the face to 160x160 using the eye annotations
+cropped_positions = {"leye": (49, 72), "reye": (49, 38)}
+
+preprocessor_transformer = FaceCrop(
+    cropped_image_size=(160, 160),
+    cropped_positions=cropped_positions,
+    color_channel="rgb",
+    fixed_positions=fixed_positions,
+)
+
+transform_extra_arguments = (
+    None
+    if (cropped_positions is None or fixed_positions is not None)
+    else (("annotations", "annotations"),)
+)
+
+
+# Extractor
+extractor_transformer = tf_model()
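+
+# tf_model loads a SavedModel from the path stored in the
+# ``bob.extractor_model.tf`` config key; set it beforehand, e.g.:
+#
+#   $ bob config set bob.extractor_model.tf /PATH/TO/MODEL/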
+
+# Algorithm
+algorithm = Distance(distance_function=scipy.spatial.distance.cosine)
+
+
+# Chain the Transformers together
+transformer = make_pipeline(
+    wrap(["sample"], preprocessor_transformer,transform_extra_arguments=transform_extra_arguments),
+    wrap(["sample"], extractor_transformer)
+    # Add more transformers here if needed
+)
+
+
+# Assemble the vanilla biometrics pipeline
+pipeline = VanillaBiometricsPipeline(transformer, algorithm)
+transformer = pipeline.transformer
diff --git a/bob/bio/face/extractor/__init__.py b/bob/bio/face/extractor/__init__.py
index 45727cb29b6d443ce490e94ba148c57f7165f152..a5484ea167a1afedc102965674fe8ed2ae9f5cdb 100644
--- a/bob/bio/face/extractor/__init__.py
+++ b/bob/bio/face/extractor/__init__.py
@@ -1,6 +1,11 @@
 from .DCTBlocks import DCTBlocks
 from .GridGraph import GridGraph
 from .LGBPHS import LGBPHS
+from .mxnet_resnet import mxnet_model
+from .pytorch_model import pytorch_loaded_model
+from .pytorch_model import pytorch_library_model
+from .tf_model import tf_model
+from .opencv_caffe import opencv_model
 
 # gets sphinx autodoc done right - don't remove it
 def __appropriate__(*args):
@@ -20,5 +25,10 @@ __appropriate__(
     DCTBlocks,
     GridGraph,
     LGBPHS,
+    mxnet_model,
+    pytorch_loaded_model,
+    pytorch_library_model,
+    tf_model,
+    opencv_model,
     )
 __all__ = [_ for _ in dir() if not _.startswith('_')]
diff --git a/bob/bio/face/extractor/mxnet_resnet.py b/bob/bio/face/extractor/mxnet_resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..aab050bfadd851e015e8b19e736a52aa854807e9
--- /dev/null
+++ b/bob/bio/face/extractor/mxnet_resnet.py
@@ -0,0 +1,107 @@
+#!/usr/bin/env python
+# vim: set fileencoding=utf-8 :
+# Yu Linghu & Xinyi Zhang <yu.linghu@uzh.ch, xinyi.zhang@uzh.ch>
+
+"""Feature extraction resnet models using mxnet interface"""
+from sklearn.base import TransformerMixin, BaseEstimator
+from sklearn.utils import check_array
+import pkg_resources
+import os
+import mxnet as mx
+from mxnet import gluon
+import warnings
+from bob.extension import rc
+mxnet_resnet_directory = rc["bob.extractor_model.mxnet"]
+mxnet_weight_directory = rc["bob.extractor_weights.mxnet"]
+
+class mxnet_model(TransformerMixin, BaseEstimator):
+    """Extracts features using deep face recognition models under the MxNet interface.
+
+    Users can download pretrained face recognition models with an MxNet
+    interface. The paths to the downloaded model and weights must be
+    specified before running the extractor (usually before running the
+    pipeline file that includes it), by setting the configuration keys
+    ``bob.extractor_model.mxnet`` and ``bob.extractor_weights.mxnet``:
+
+    .. code-block:: sh
+
+      $ bob config set bob.extractor_model.mxnet /PATH/TO/MODEL/
+      $ bob config set bob.extractor_weights.mxnet /PATH/TO/WEIGHTS/
+
+    Example (pretrained ResNet models):
+    `LResNet100E-IR,ArcFace@ms1m-refine-v2 <https://github.com/deepinsight/insightface>`_
+
+    The extracted features can be combined with different algorithms.
+
+    **Parameters:**
+
+    use_gpu : bool
+      Whether to run the model on GPU (default: ``False``).
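+
+    **Example:** (a sketch; assumes the config keys above are set and
+    ``images`` is a batch of pre-cropped face images)
+
+    .. code-block:: py
+
+      from bob.bio.face.extractor import mxnet_model
+      features = mxnet_model().transform(images)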
+    """
+
+    def __init__(self, use_gpu=False, **kwargs):
+        super().__init__(**kwargs)
+        self.model = None
+        self.use_gpu = use_gpu
+
+        internal_path = pkg_resources.resource_filename(
+            __name__, os.path.join("data", "resnet"),
+        )
+
+        checkpoint_path = (
+            internal_path
+            if rc["bob.bio.face.models.mxnet_resnet"] is None
+            else rc["bob.bio.face.models.mxnet_resnet"]
+        )
+
+        self.checkpoint_path = checkpoint_path
+
+    def _load_model(self):
+
+        ctx = mx.gpu() if self.use_gpu else mx.cpu()
+
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            deserialized_net = gluon.nn.SymbolBlock.imports(mxnet_resnet_directory, ['data'], mxnet_weight_directory, ctx=ctx)
+
+        self.model = deserialized_net
+
+    def transform(self, X):
+        """Extracts features from the given batch of images.
+
+        **Parameters:**
+
+        X : :py:class:`numpy.ndarray` (floats)
+          The batch of pre-cropped images to extract features from.
+
+        **Returns:**
+
+        feature : :py:class:`numpy.ndarray` (floats)
+          The features extracted from the images.
+        """
+        if self.model is None:
+            self._load_model()
+
+        X = check_array(X, allow_nd=True)
+        X = mx.nd.array(X)
+
+        return self.model(X).asnumpy()
+
+
+    def __getstate__(self):
+        # Handling unpicklable objects
+
+        d = self.__dict__.copy()
+        d["model"] = None
+        return d
+
+    def _more_tags(self):
+        return {"stateless": True, "requires_fit": False}
diff --git a/bob/bio/face/extractor/opencv_caffe.py b/bob/bio/face/extractor/opencv_caffe.py
new file mode 100644
index 0000000000000000000000000000000000000000..647fc724fe4e1087e488f89e24ed42de750552fa
--- /dev/null
+++ b/bob/bio/face/extractor/opencv_caffe.py
@@ -0,0 +1,109 @@
+#!/usr/bin/env python
+# vim: set fileencoding=utf-8 :
+# Yu Linghu & Xinyi Zhang <yu.linghu@uzh.ch, xinyi.zhang@uzh.ch>
+
+import os
+
+import cv2
+import numpy as np
+import pkg_resources
+from sklearn.base import TransformerMixin, BaseEstimator
+
+from bob.extension import rc
+
+opencv_model_directory = rc["bob.extractor_model.opencv"]
+opencv_model_prototxt = rc["bob.extractor_weights.opencv"]
+
+
+class opencv_model(TransformerMixin, BaseEstimator):
+    """Extracts features using deep face recognition models under the OpenCV interface.
+
+    Users can download pretrained face recognition models usable through the
+    OpenCV DNN interface. The paths to the downloaded model and weights must
+    be specified before running the extractor (usually before running the
+    pipeline file that includes it), by setting the configuration keys
+    ``bob.extractor_model.opencv`` and ``bob.extractor_weights.opencv``:
+
+    .. code-block:: sh
+
+      $ bob config set bob.extractor_model.opencv /PATH/TO/MODEL/
+      $ bob config set bob.extractor_weights.opencv /PATH/TO/WEIGHTS/
+
+    The extracted features can be combined with different algorithms.
+
+    .. note::
+       This class can only be used with Caffe pretrained models.
+
+    **Parameters:**
+
+    use_gpu : bool
+      Whether to run the model on GPU (default: ``False``).
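+
+    **Example:** (a sketch; assumes the config keys above are set and
+    ``blob`` is a 4D input blob, e.g. from :py:func:`cv2.dnn.blobFromImage`)
+
+    .. code-block:: py
+
+      from bob.bio.face.extractor import opencv_model
+      features = opencv_model().transform(blob)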
+    """
+
+
+    def __init__(self, use_gpu=False, **kwargs):
+        super().__init__(**kwargs)
+        self.model = None
+        self.use_gpu = use_gpu
+
+        internal_path = pkg_resources.resource_filename(
+            __name__, os.path.join("data", "opencv_model"),
+        )
+
+        checkpoint_path = (
+            internal_path
+            if rc["bob.bio.face.models.opencv"] is None
+            else rc["bob.bio.face.models.opencv"]
+        )
+
+        self.checkpoint_path = checkpoint_path
+
+    def _load_model(self):
+
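+        # cv2.dnn.readNetFromCaffe expects (prototxt, caffemodel): note that
+        # the prototxt path comes from the ``bob.extractor_weights.opencv``
+        # key and the caffemodel path from ``bob.extractor_model.opencv``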
+        net = cv2.dnn.readNetFromCaffe(opencv_model_prototxt,opencv_model_directory)
+
+        self.model = net
+
+    def transform(self, X):
+        """Extracts features from the given batch of images.
+
+        **Parameters:**
+
+        X : :py:class:`numpy.ndarray` (floats)
+          The 4D input blob (batch of images) to extract features from.
+
+        **Returns:**
+
+        feature : :py:class:`numpy.ndarray` (floats)
+          The features extracted from the images.
+        """
+        if self.model is None:
+            self._load_model()
+
+        img = np.array(X)
+
+        self.model.setInput(img)
+
+        return self.model.forward()
+
+
+    def __getstate__(self):
+        # Handling unpicklable objects
+
+        d = self.__dict__.copy()
+        d["model"] = None
+        return d
+
+    def _more_tags(self):
+        return {"stateless": True, "requires_fit": False}
diff --git a/bob/bio/face/extractor/pytorch_model.py b/bob/bio/face/extractor/pytorch_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..52c4f172d1c74f7018429261a072d8185bd69289
--- /dev/null
+++ b/bob/bio/face/extractor/pytorch_model.py
@@ -0,0 +1,184 @@
+#!/usr/bin/env python
+# vim: set fileencoding=utf-8 :
+# Yu Linghu & Xinyi Zhang <yu.linghu@uzh.ch, xinyi.zhang@uzh.ch>
+
+import imp
+import os
+
+import pkg_resources
+import torch
+from sklearn.base import TransformerMixin, BaseEstimator
+
+from bob.extension import rc
+
+pytorch_model_directory = rc["bob.extractor_model.pytorch"]
+pytorch_weight_directory = rc["bob.extractor_weights.pytorch"]
+
+class pytorch_loaded_model(TransformerMixin, BaseEstimator):
+    """Extracts features using deep face recognition models under the PyTorch
+    interface, especially for models and weights that need to be loaded by hand.
+
+    Users can download pretrained face recognition models with a PyTorch
+    interface. The paths to the downloaded model definition and weights must
+    be specified before running the extractor (usually before running the
+    pipeline file that includes it), by setting the configuration keys
+    ``bob.extractor_model.pytorch`` and ``bob.extractor_weights.pytorch``:
+
+    .. code-block:: sh
+
+      $ bob config set bob.extractor_model.pytorch /PATH/TO/MODEL/
+      $ bob config set bob.extractor_weights.pytorch /PATH/TO/WEIGHTS/
+
+    The extracted features can be combined with different algorithms.
+
+    **Parameters:**
+
+    use_gpu : bool
+      Whether to run the model on GPU (default: ``False``).
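+
+    **Example:** (a sketch; assumes the config keys above are set and
+    ``images`` is a batch of pre-cropped face images)
+
+    .. code-block:: py
+
+      from bob.bio.face.extractor import pytorch_loaded_model
+      features = pytorch_loaded_model().transform(images)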
+    """ 
+
+    def __init__(self, use_gpu=False, **kwargs):
+        super().__init__(**kwargs)
+        self.model = None
+        self.use_gpu = use_gpu
+
+        internal_path = pkg_resources.resource_filename(
+            __name__, os.path.join("data", "resnet"),
+        )
+
+        checkpoint_path = (
+            internal_path
+            if rc["bob.bio.face.models.pytorchmodel"] is None
+            else rc["bob.bio.face.models.pytorchmodel"]
+        )
+
+        self.checkpoint_path = checkpoint_path
+        self.device = None
+
+    def _load_model(self):
+
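+        # Import the module that defines the network architecture first, so
+        # that ``torch.load`` below can unpickle the saved model object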
+        MainModel = imp.load_source('MainModel', pytorch_model_directory)
+        network = torch.load(pytorch_weight_directory)
+        network.eval()
+        
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        
+        network.to(self.device)
+
+        self.model = network
+
+    def transform(self, X):
+        """Extracts features from the given batch of images.
+
+        **Parameters:**
+
+        X : :py:class:`numpy.ndarray` (floats)
+          The batch of pre-cropped images to extract features from.
+
+        **Returns:**
+
+        feature : :py:class:`numpy.ndarray` (floats)
+          The features extracted from the images.
+        """
+        if self.model is None:
+            self._load_model()
+
+        X = torch.Tensor(X)
+
+        return self.model(X).detach().numpy()
+
+
+    def __getstate__(self):
+        # Handling unpicklable objects
+
+        d = self.__dict__.copy()
+        d["model"] = None
+        return d
+
+    def _more_tags(self):
+        return {"stateless": True, "requires_fit": False}
+
+
+class pytorch_library_model(TransformerMixin, BaseEstimator):
+    """Extracts features using deep face recognition models registered in the
+    PyTorch library.
+
+    Users can import pretrained face recognition models from a PyTorch
+    library. The model should be instantiated in the pipeline file and passed
+    in here. Example: `facenet_pytorch <https://github.com/timesler/facenet-pytorch>`_
+
+    The extracted features can be combined with different algorithms.
+
+    **Parameters:**
+
+    model : :py:class:`torch.nn.Module`
+      The PyTorch model to use, instantiated from a library.
+    use_gpu : bool
+      Whether to run the model on GPU (default: ``False``).
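+
+    **Example:** (a sketch; assumes ``images`` is a batch of pre-cropped
+    face images)
+
+    .. code-block:: py
+
+      from facenet_pytorch import InceptionResnetV1
+      model = InceptionResnetV1(pretrained="vggface2").eval()
+      features = pytorch_library_model(model=model).transform(images)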
+    """
+
+    def __init__(self, model=None, use_gpu=False, **kwargs):
+        super().__init__(**kwargs)
+        self.model = model
+        self.use_gpu = use_gpu
+
+        internal_path = pkg_resources.resource_filename(
+            __name__, os.path.join("data", "resnet"),
+        )
+
+        checkpoint_path = (
+            internal_path
+            if rc["bob.bio.face.models.pytorchmodel"] is None
+            else rc["bob.bio.face.models.pytorchmodel"]
+        )
+
+        self.checkpoint_path = checkpoint_path
+        self.device = None
+
+    def transform(self, X):
+        """Extracts features from the given batch of images.
+
+        **Parameters:**
+
+        X : :py:class:`numpy.ndarray` (floats)
+          The batch of pre-cropped images to extract features from.
+
+        **Returns:**
+
+        feature : :py:class:`numpy.ndarray` (floats)
+          The features extracted from the images.
+        """
+
+        X = torch.Tensor(X)
+
+        return self.model(X).detach().numpy()
+
+
+    def __getstate__(self):
+        # Handling unpicklable objects
+
+        d = self.__dict__.copy()
+        d["model"] = None
+        return d
+
+    def _more_tags(self):
+        return {"stateless": True, "requires_fit": False}
diff --git a/bob/bio/face/extractor/tf_model.py b/bob/bio/face/extractor/tf_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e83ec389dafb05b1d9031c52338e235a9086ef8
--- /dev/null
+++ b/bob/bio/face/extractor/tf_model.py
@@ -0,0 +1,100 @@
+#!/usr/bin/env python
+# vim: set fileencoding=utf-8 :
+# Yu Linghu & Xinyi Zhang <yu.linghu@uzh.ch, xinyi.zhang@uzh.ch>
+
+import os
+
+import pkg_resources
+import tensorflow as tf
+from sklearn.base import TransformerMixin, BaseEstimator
+from sklearn.utils import check_array
+
+from bob.extension import rc
+from bob.learn.tensorflow.utils.image import to_channels_last
+
+tf_model_directory = rc["bob.extractor_model.tf"]
+
+class tf_model(TransformerMixin, BaseEstimator):
+    """Extracts features using deep face recognition models under the
+    TensorFlow interface.
+
+    Users can download pretrained face recognition models with a TensorFlow
+    interface. The path to the downloaded model must be specified before
+    running the extractor (usually before running the pipeline file that
+    includes it), by setting the configuration key ``bob.extractor_model.tf``:
+
+    .. code-block:: sh
+
+      $ bob config set bob.extractor_model.tf /PATH/TO/MODEL/
+
+    The extracted features can be combined with different algorithms.
+
+    **Parameters:**
+
+    use_gpu : bool
+      Whether to run the model on GPU (default: ``False``).
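+
+    **Example:** (a sketch; assumes the config key above is set and
+    ``images`` is a batch of pre-cropped face images)
+
+    .. code-block:: py
+
+      from bob.bio.face.extractor import tf_model
+      features = tf_model().transform(images)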
+    """
+
+    def __init__(self, use_gpu=False, **kwargs):
+        super().__init__(**kwargs)
+        self.model = None
+        self.use_gpu = use_gpu
+
+        internal_path = pkg_resources.resource_filename(
+            __name__, os.path.join("data", "resnet"),
+        )
+
+        checkpoint_path = (
+            internal_path
+            if rc["bob.bio.face.models.tfmodel"] is None
+            else rc["bob.bio.face.models.tfmodel"]
+        )
+
+        self.checkpoint_path = checkpoint_path
+
+    def _load_model(self):
+
+        model = tf.keras.models.load_model(tf_model_directory)
+
+        self.model = model
+
+    def transform(self, X):
+        """Extracts features from the given batch of images.
+
+        **Parameters:**
+
+        X : :py:class:`numpy.ndarray` (floats)
+          The batch of pre-cropped images to extract features from.
+
+        **Returns:**
+
+        feature : :py:class:`numpy.ndarray` (floats)
+          The features extracted from the images.
+        """
+        if self.model is None:
+            self._load_model()
+
+        X = check_array(X, allow_nd=True)
+        X = tf.convert_to_tensor(X)
+        X = to_channels_last(X)
+
+        return self.model.predict(X)
+
+
+    def __getstate__(self):
+        # Handling unpicklable objects
+
+        d = self.__dict__.copy()
+        d["model"] = None
+        return d
+
+    def _more_tags(self):
+        return {"stateless": True, "requires_fit": False}
diff --git a/doc/baselines.rst b/doc/baselines.rst
index 2217e5af8bf96bb90dc7b6964dee14f5fbdca275..88c42a629cd6a9eb1ccd18fe92146aa3bbfc499c 100644
--- a/doc/baselines.rst
+++ b/doc/baselines.rst
@@ -56,3 +56,17 @@ Deep learning baselines
 * ``inception-resnetv1-casiawebface``: Inception Resnet v1 model trained using the Casia Web dataset in the context of the work published by [TFP18]_
 
 * ``arcface-insightface``: Arcface model from `Insightface <https://github.com/deepinsight/insightface>`_
+
+
+Deep learning baselines with different interfaces
+=================================================
+
+* ``mxnet-pipe``: ArcFace ResNet model under the MxNet interface, from `Insightface <https://github.com/deepinsight/insightface>`_
+
+* ``pytorch-pipe-v1``: PyTorch network that extracts 1000-dimensional features, trained by Manuel Günther, as described in [LGB18]_
+
+* ``pytorch-pipe-v2``: Inception ResNet face recognition model from `facenet_pytorch <https://github.com/timesler/facenet-pytorch>`_
+
+* ``tf-pipe``: Inception ResNet v2 model trained on the MSCeleb dataset in the context of the work published by [TFP18]_
+
+* ``opencv-pipe``: VGG Face descriptor pretrained model (`Caffe model <https://www.robots.ox.ac.uk/~vgg/software/vgg_face/>`_), run under the OpenCV DNN interface
diff --git a/doc/implemented.rst b/doc/implemented.rst
index a53938b64b5b8527bb7be9edaafc01a9a6ebe0ee..75acc7129ef0132eb8ad762c73ddf7e53cb5b204 100644
--- a/doc/implemented.rst
+++ b/doc/implemented.rst
@@ -13,7 +13,6 @@ Databases
 .. autosummary::
    bob.bio.face.database.ARFaceBioDatabase
    bob.bio.face.database.AtntBioDatabase
-   bob.bio.face.database.CasiaAfricaDatabase
    bob.bio.face.database.MobioDatabase
    bob.bio.face.database.ReplayBioDatabase
    bob.bio.face.database.ReplayMobileBioDatabase
@@ -35,6 +34,7 @@ Face Image Annotators
    bob.bio.face.annotator.BobIpFacedetect
    bob.bio.face.annotator.BobIpFlandmark
    bob.bio.face.annotator.BobIpMTCNN
+   bob.bio.face.annotator.BobIpTinyface
 
 
 Image Preprocessors
@@ -57,7 +57,11 @@ Image Feature Extractors
    bob.bio.face.extractor.DCTBlocks
    bob.bio.face.extractor.GridGraph
    bob.bio.face.extractor.LGBPHS
-
+   bob.bio.face.extractor.mxnet_model
+   bob.bio.face.extractor.pytorch_loaded_model
+   bob.bio.face.extractor.pytorch_library_model
+   bob.bio.face.extractor.tf_model
+   bob.bio.face.extractor.opencv_model
 
 Face Recognition Algorithms
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~
diff --git a/doc/references.rst b/doc/references.rst
index ea60e1ad5e565618d66c19fe260dabb1b4b4df45..28b98d5f0cbb37e59bfb9f9623b31b966493ed85 100644
--- a/doc/references.rst
+++ b/doc/references.rst
@@ -17,3 +17,4 @@ References
 .. [ZSQ09]  *W. Zhang, S. Shan, L. Qing, X. Chen and W. Gao*. **Are Gabor phases really useless for face recognition?** Pattern Analysis & Applications, 12:301-307, 2009.
 .. [TFP18] de Freitas Pereira, Tiago, André Anjos, and Sébastien Marcel. "Heterogeneous face recognition using domain specific units." IEEE Transactions on Information Forensics and Security 14.7 (2018): 1803-1816.
 .. [HRM06]   *G. Heusch, Y. Rodriguez, and S. Marcel*. **Local Binary Patterns as an Image Preprocessing for Face Authentication**. In IEEE International Conference on Automatic Face and Gesture Recognition (AFGR), 2006.
+.. [LGB18]  *C. Li, M. Günther and T. E. Boult*. **ECLIPSE: Ensembles of Centroids Leveraging Iteratively Processed Spatial Eclipse Clustering**. 2018 IEEE Winter Conference on Applications of Computer Vision (WACV), Lake Tahoe, NV, USA, 2018, pp. 131-140, doi: 10.1109/WACV.2018.00021.
diff --git a/requirements.txt b/requirements.txt
index 2bda4901dfc18d3de7932c96ca1afbe84f8146c3..02e30ac5109f4d32cb93c768ec7a78195f25de94 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -18,6 +18,8 @@ bob.bio.base
 bob.ip.facedetect
 bob.pipelines
 matplotlib   # for plotting
+mxnet
+opencv-python
 six
 scikit-image
 scikit-learn # for pipelines Tranformers
diff --git a/setup.py b/setup.py
index ded13e23b4f4326f28c48532700311235313affa..8a1c7be330b11bceddfba1b65ed1c3cfc1b132bd 100644
--- a/setup.py
+++ b/setup.py
@@ -121,12 +121,14 @@ setup(
             "facedetect-eye-estimate  = bob.bio.face.config.annotator.facedetect_eye_estimate:annotator",
             "flandmark                = bob.bio.face.config.annotator.flandmark:annotator",
             "mtcnn                    = bob.bio.face.config.annotator.mtcnn:annotator",
+            "tinyface                 = bob.bio.face.config.annotator.tinyface:annotator",
         ],
         "bob.bio.transformer": [
             "facedetect-eye-estimate = bob.bio.face.config.annotator.facedetect_eye_estimate:transformer",
             "facedetect = bob.bio.face.config.annotator.facedetect:transformer",
             "flandmark = bob.bio.face.config.annotator.flandmark:annotator",
             "mtcnn = bob.bio.face.config.annotator.mtcnn:transformer",
+            "tinyface = bob.bio.face.config.annotator.tinyface:transformer",
             "facenet-sanderberg = bob.bio.face.config.baseline.facenet_sanderberg:transformer",
             "inception-resnetv1-casiawebface = bob.bio.face.config.baseline.inception_resnetv1_casiawebface:transformer",
             "inception-resnetv2-casiawebface = bob.bio.face.config.baseline.inception_resnetv2_casiawebface:transformer",
@@ -136,6 +138,11 @@ setup(
             "gabor-graph = bob.bio.face.config.baseline.gabor_graph:transformer",
             "lgbphs = bob.bio.face.config.baseline.lgbphs:transformer",
             "dummy = bob.bio.face.config.baseline.dummy:transformer",
+            "mxnet-pipe = bob.bio.face.config.baseline.mxnet_pipe:transformer",
+            "pytorch-pipe-v1 = bob.bio.face.config.baseline.pytorch_pipe_v1:transformer",
+            "pytorch-pipe-v2 = bob.bio.face.config.baseline.pytorch_pipe_v2:transformer",
+            "tf-pipe = bob.bio.face.config.baseline.ty_pipe:transformer",
+            "opencv-pipe = bob.bio.face.config.baseline.opencv_pipe:transformer",
         ],
         # baselines
         "bob.bio.pipeline": [
@@ -150,8 +157,12 @@ setup(
             "lda = bob.bio.face.config.baseline.lda:pipeline",
             "dummy = bob.bio.face.config.baseline.dummy:pipeline",
             "resnet50-msceleb-arcface-2021 = bob.bio.face.config.baseline.resnet50_msceleb_arcface_2021:pipeline",
-            "resnet50-vgg2-arcface-2021 = bob.bio.face.config.baseline.resnet50_vgg2_arcface_2021:pipeline",
             "mobilenetv2-msceleb-arcface-2021 = bob.bio.face.config.baseline.mobilenetv2_msceleb_arcface_2021",
+            "mxnet-pipe = bob.bio.face.config.baseline.mxnet_pipe:pipeline",
+            "pytorch-pipe-v1 = bob.bio.face.config.baseline.pytorch_pipe_v1:pipeline",
+            "pytorch-pipe-v2 = bob.bio.face.config.baseline.pytorch_pipe_v2:pipeline",
+            "tf-pipe = bob.bio.face.config.baseline.ty_pipe:pipeline",
+            "opencv-pipe = bob.bio.face.config.baseline.opencv_pipe:pipeline",
         ],
         "bob.bio.config": [
             "facenet-sanderberg = bob.bio.face.config.baseline.facenet_sanderberg",
@@ -163,6 +174,11 @@ setup(
             "arcface-insightface = bob.bio.face.config.baseline.arcface_insightface",
             "lgbphs = bob.bio.face.config.baseline.lgbphs",
             "lda = bob.bio.face.config.baseline.lda",
+            "mxnet-pipe = bob.bio.face.config.baseline.mxnet_pipe",
+            "pytorch-pipe-v1 = bob.bio.face.config.baseline.pytorch_pipe_v1",
+            "pytorch-pipe-v2 = bob.bio.face.config.baseline.pytorch_pipe_v2",
+            "tf-pipe = bob.bio.face.config.baseline.ty_pipe",
+            "opencv-pipe = bob.bio.face.config.baseline.opencv_pipe",
             "arface            = bob.bio.face.config.database.arface",
             "atnt              = bob.bio.face.config.database.atnt",
             "gbu               = bob.bio.face.config.database.gbu",
@@ -184,7 +200,6 @@ setup(
             "pola-thermal = bob.bio.face.config.database.pola_thermal",
             "cbsr-nir-vis-2 = bob.bio.face.config.database.cbsr_nir_vis_2",
             "resnet50-msceleb-arcface-2021 = bob.bio.face.config.baseline.resnet50_msceleb_arcface_2021",
-            "resnet50-vgg2-arcface-2021 = bob.bio.face.config.baseline.resnet50_vgg2_arcface_2021",
             "mobilenetv2-msceleb-arcface-2021 = bob.bio.face.config.baseline.mobilenetv2_msceleb_arcface_2021",
         ],
         "bob.bio.cli": [