diff --git a/bob/ip/facedetect/tests/test_tinyface.py b/bob/ip/facedetect/tests/test_tinyface.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cc3b090ed6c8911a940303f5c95bb29c7e3b55d
--- /dev/null
+++ b/bob/ip/facedetect/tests/test_tinyface.py
@@ -0,0 +1,54 @@
+from bob.ip.facedetect.tests.utils import is_library_available
+
+import bob.io.image
+import bob.io.base
+import bob.io.base.test_utils
+
+import numpy
+
+
+# An image with one face
+face_image = bob.io.base.load(
+    bob.io.base.test_utils.datafile("testimage.jpg", "bob.ip.facedetect")
+)
+
+# An image with 6 faces
+face_image_multiple = bob.io.base.load(
+    bob.io.base.test_utils.datafile("test_image_multi_face.png", "bob.ip.facedetect")
+)
+
+
+def _assert_tinyface_annotations(annot):
+    """
+    Verifies that TinyFace returns the correct coordinates for ``testimage``.
+    """
+    assert len(annot) == 1, f"len: {len(annot)}; {annot}"
+    face = annot[0]
+    assert [int(x) for x in face["topleft"]] == [59, 57], face
+    assert [int(x) for x in face["bottomright"]] == [338, 284], face
+    assert numpy.allclose(
+        [x for x in face["reye"]], [162.23, 125.89], atol=10e-2, rtol=10e-2
+    )
+    assert numpy.allclose(
+        [x for x in face["leye"]], [162.23, 215.89], atol=10e-2, rtol=10e-2
+    )
+
+
+@is_library_available("mxnet")
+def test_tinyface():
+    """TinyFace should annotate one face correctly."""
+    from bob.ip.facedetect.tinyface import TinyFacesDetector
+
+    tinyface_annotator = TinyFacesDetector()
+    annot = tinyface_annotator.detect(face_image)
+    _assert_tinyface_annotations(annot)
+
+
+@is_library_available("mxnet")
+def test_tinyface_multiface():
+    """TinyFace should find multiple faces in an image."""
+    from bob.ip.facedetect.tinyface import TinyFacesDetector
+
+    tinyface_annotator = TinyFacesDetector()
+    annot = tinyface_annotator.detect(face_image_multiple)
+    assert len(annot) == 6
diff --git a/bob/ip/facedetect/tinyface.py b/bob/ip/facedetect/tinyface.py
new file mode 100644
index 0000000000000000000000000000000000000000..1efc682885eafe96557dd0a13bf11d2fb5418938
--- /dev/null
+++ b/bob/ip/facedetect/tinyface.py
@@ -0,0 +1,255 @@
+import mxnet as mx
+from mxnet import gluon
+from bob.ip.color import gray_to_rgb
+import logging
+import numpy as np
+import cv2 as cv
+import pickle
+import os, sys
+from collections import namedtuple
+import time
+from bob.io.image import to_matplotlib
+import pkg_resources
+from bob.extension import rc
+from bob.extension.download import download_and_unzip
+import os
+
+logger = logging.getLogger(__name__)
+Batch = namedtuple("Batch", ["data"])
+
+
+class TinyFacesDetector:
+
+    """TinyFace face detector. Original Model is ``ResNet101`` from 
+    https://github.com/peiyunh/tiny. Please check for details. The 
+    model used in this section is the MxNet version from 
+    https://github.com/chinakook/hr101_mxnet.
+
+    Attributes
+    ----------
+    prob_thresh: float
+        Thresholds are a trade-off between false positives and missed detections.
+    """
+
+    def __init__(self, prob_thresh=0.5):
+
+        internal_path = pkg_resources.resource_filename(
+            __name__, os.path.join("data", "tinyface_detector"),
+        )
+
+        checkpoint_path = (
+            internal_path
+            if rc["bob.ip.facedetect.models.tinyface_detector"] is None
+            else rc["bob.ip.facedetect.models.tinyface_detector"]
+        )
+
+        urls = [
+            "https://www.idiap.ch/software/bob/data/bob/bob.ip.facedetect/master/tinyface_detector.tar.gz"
+        ]
+
+        os.makedirs(checkpoint_path, exist_ok=True)
+        download_and_unzip(
+            urls, os.path.join(checkpoint_path, "tinyface_detector.tar.gz")
+        )
+
+        self.checkpoint_path = checkpoint_path
+
+        self.MAX_INPUT_DIM = 5000.0
+        self.prob_thresh = prob_thresh
+        self.nms_thresh = 0.1
+        self.model_root = pkg_resources.resource_filename(
+            __name__, self.checkpoint_path
+        )
+
+        sym, arg_params, aux_params = mx.model.load_checkpoint(
+            os.path.join(self.checkpoint_path, "hr101"), 0
+        )
+        all_layers = sym.get_internals()
+
+        meta_file = open(os.path.join(self.checkpoint_path, "meta.pkl"), "rb")
+        self.clusters = pickle.load(meta_file)
+        self.averageImage = pickle.load(meta_file)
+        meta_file.close()
+        self.clusters_h = self.clusters[:, 3] - self.clusters[:, 1] + 1
+        self.clusters_w = self.clusters[:, 2] - self.clusters[:, 0] + 1
+        self.normal_idx = np.where(self.clusters[:, 4] == 1)
+
+        self.mod = mx.mod.Module(
+            symbol=all_layers["fusex_output"], data_names=["data"], label_names=None
+        )
+        self.mod.bind(
+            for_training=False,
+            data_shapes=[("data", (1, 3, 224, 224))],
+            label_shapes=None,
+            force_rebind=False,
+        )
+        self.mod.set_params(
+            arg_params=arg_params, aux_params=aux_params, force_init=False
+        )
+
+    @staticmethod
+    def _nms(dets, prob_thresh):
+
+        x1 = dets[:, 0]
+        y1 = dets[:, 1]
+        x2 = dets[:, 2]
+        y2 = dets[:, 3]
+        scores = dets[:, 4]
+
+        areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+
+        order = scores.argsort()[::-1]
+
+        keep = []
+        while order.size > 0:
+            i = order[0]
+            keep.append(i)
+            xx1 = np.maximum(x1[i], x1[order[1:]])
+            yy1 = np.maximum(y1[i], y1[order[1:]])
+            xx2 = np.minimum(x2[i], x2[order[1:]])
+            yy2 = np.minimum(y2[i], y2[order[1:]])
+            w = np.maximum(0.0, xx2 - xx1 + 1)
+            h = np.maximum(0.0, yy2 - yy1 + 1)
+            inter = w * h
+
+            ovr = inter / (areas[i] + areas[order[1:]] - inter)
+            inds = np.where(ovr <= prob_thresh)[0]
+
+            order = order[inds + 1]
+        return keep
+
+    def detect(self, img):
+        """Detects and annotates all faces in the image.
+
+        Parameters
+        ----------
+        image : numpy.ndarray
+            An RGB image in Bob format.
+
+        Returns
+        -------
+        list
+            A list of annotations. Annotations are dictionaries that contain the
+            following keys: ``topleft``, ``bottomright``, ``reye``, ``leye``. 
+            (``reye`` and ``leye`` are the estimated results, not captured by the 
+            model.)
+        """
+        raw_img = img
+        if len(raw_img.shape) == 2:
+            raw_img = gray_to_rgb(raw_img)
+        assert img.shape[0] == 3, img.shape
+
+        raw_img = to_matplotlib(raw_img)
+        raw_img = raw_img[..., ::-1]
+
+        raw_h = raw_img.shape[0]
+        raw_w = raw_img.shape[1]
+
+        raw_img = cv.cvtColor(raw_img, cv.COLOR_BGR2RGB)
+        raw_img_f = raw_img.astype(np.float32)
+
+        min_scale = min(
+            np.floor(np.log2(np.max(self.clusters_w[self.normal_idx] / raw_w))),
+            np.floor(np.log2(np.max(self.clusters_h[self.normal_idx] / raw_h))),
+        )
+        max_scale = min(1.0, -np.log2(max(raw_h, raw_w) / self.MAX_INPUT_DIM))
+
+        scales_down = np.arange(min_scale, 0 + 0.0001, 1.0)
+        scales_up = np.arange(0.5, max_scale + 0.0001, 0.5)
+        scales_pow = np.hstack((scales_down, scales_up))
+        scales = np.power(2.0, scales_pow)
+
+        start = time.time()
+        bboxes = np.empty(shape=(0, 5))
+        for s in scales[::-1]:
+            img = cv.resize(raw_img_f, (0, 0), fx=s, fy=s)
+            img = np.transpose(img, (2, 0, 1))
+            img = img - self.averageImage
+
+            tids = []
+            if s <= 1.0:
+                tids = list(range(4, 12))
+            else:
+                tids = list(range(4, 12)) + list(range(18, 25))
+            ignoredTids = list(set(range(0, self.clusters.shape[0])) - set(tids))
+            img_h = img.shape[1]
+            img_w = img.shape[2]
+            img = img[np.newaxis, :]
+
+            self.mod.reshape(data_shapes=[("data", (1, 3, img_h, img_w))])
+            self.mod.forward(Batch([mx.nd.array(img)]))
+            self.mod.get_outputs()[0].wait_to_read()
+            fusex_res = self.mod.get_outputs()[0]
+
+            score_cls = mx.nd.slice_axis(
+                fusex_res, axis=1, begin=0, end=25, name="score_cls"
+            )
+            score_reg = mx.nd.slice_axis(
+                fusex_res, axis=1, begin=25, end=None, name="score_reg"
+            )
+            prob_cls = mx.nd.sigmoid(score_cls)
+
+            prob_cls_np = prob_cls.asnumpy()
+            prob_cls_np[0, ignoredTids, :, :] = 0.0
+
+            _, fc, fy, fx = np.where(prob_cls_np > self.prob_thresh)
+
+            cy = fy * 8 - 1
+            cx = fx * 8 - 1
+            ch = self.clusters[fc, 3] - self.clusters[fc, 1] + 1
+            cw = self.clusters[fc, 2] - self.clusters[fc, 0] + 1
+
+            Nt = self.clusters.shape[0]
+
+            score_reg_np = score_reg.asnumpy()
+            tx = score_reg_np[0, 0:Nt, :, :]
+            ty = score_reg_np[0, Nt : 2 * Nt, :, :]
+            tw = score_reg_np[0, 2 * Nt : 3 * Nt, :, :]
+            th = score_reg_np[0, 3 * Nt : 4 * Nt, :, :]
+
+            dcx = cw * tx[fc, fy, fx]
+            dcy = ch * ty[fc, fy, fx]
+            rcx = cx + dcx
+            rcy = cy + dcy
+            rcw = cw * np.exp(tw[fc, fy, fx])
+            rch = ch * np.exp(th[fc, fy, fx])
+
+            score_cls_np = score_cls.asnumpy()
+            scores = score_cls_np[0, fc, fy, fx]
+
+            tmp_bboxes = np.vstack(
+                (rcx - rcw / 2, rcy - rch / 2, rcx + rcw / 2, rcy + rch / 2)
+            )
+            tmp_bboxes = np.vstack((tmp_bboxes / s, scores))
+            tmp_bboxes = tmp_bboxes.transpose()
+            bboxes = np.vstack((bboxes, tmp_bboxes))
+
+        refind_idx = self._nms(bboxes, self.nms_thresh)
+        refind_bboxes = bboxes[refind_idx]
+        refind_bboxes = refind_bboxes.astype(np.int32)
+
+        annotations = refind_bboxes
+        annots = []
+        for i in range(len(refind_bboxes)):
+            topleft = float(annotations[i][1]), float(annotations[i][0])
+            bottomright = float(annotations[i][3]), float(annotations[i][2])
+            width = float(annotations[i][2]) - float(annotations[i][0])
+            length = float(annotations[i][3]) - float(annotations[i][1])
+            right_eye = (
+                (0.37) * length + float(annotations[i][1]),
+                (0.3) * width + float(annotations[i][0]),
+            )
+            left_eye = (
+                (0.37) * length + float(annotations[i][1]),
+                (0.7) * width + float(annotations[i][0]),
+            )
+            annots.append(
+                {
+                    "topleft": topleft,
+                    "bottomright": bottomright,
+                    "reye": right_eye,
+                    "leye": left_eye,
+                }
+            )
+
+        return annots
diff --git a/doc/img/detect_faces_tinyface.png b/doc/img/detect_faces_tinyface.png
new file mode 100644
index 0000000000000000000000000000000000000000..990b484b119344648c21f519f9f7c12598433772
Binary files /dev/null and b/doc/img/detect_faces_tinyface.png differ
diff --git a/doc/index.rst b/doc/index.rst
index 1bb83ec2800f300bd695c9c14ebd759c14966125..808730a6a8bf61a004bd4b3271adaaab5dae6c32 100644
--- a/doc/index.rst
+++ b/doc/index.rst
@@ -35,6 +35,7 @@ Documentation
 
    guide
    mtcnn
+   tinyface
    py_api
 
 
diff --git a/doc/plot/detect_faces_tinyface.py b/doc/plot/detect_faces_tinyface.py
new file mode 100644
index 0000000000000000000000000000000000000000..74e274e5aab6a07f21e0b83962025817f9bda68c
--- /dev/null
+++ b/doc/plot/detect_faces_tinyface.py
@@ -0,0 +1,41 @@
+import matplotlib.pyplot as plt
+from bob.io.base import load
+from bob.io.base.test_utils import datafile
+from bob.io.image import imshow
+from bob.ip.facedetect.tinyface import TinyFacesDetector
+from matplotlib.patches import Rectangle
+
+# load colored test image
+color_image = load(datafile("test_image_multi_face.png", "bob.ip.facedetect"))
+is_mxnet_available = True
+try:
+    import mxnet
+except Exception:
+    is_mxnet_available = False
+
+if not is_mxnet_available:
+    imshow(color_image)
+else:
+
+    # detect all faces
+    detector = TinyFacesDetector()
+    detections = detector.detect(color_image)
+
+    imshow(color_image)
+    plt.axis("off")
+
+    for annotations in detections:
+        topleft = annotations["topleft"]
+        bottomright = annotations["bottomright"]
+        size = bottomright[0] - topleft[0], bottomright[1] - topleft[1]
+        # draw bounding boxes
+        plt.gca().add_patch(
+            Rectangle(
+                topleft[::-1],
+                size[1],
+                size[0],
+                edgecolor="b",
+                facecolor="none",
+                linewidth=2,
+            )
+        )
\ No newline at end of file
diff --git a/doc/py_api.rst b/doc/py_api.rst
index 443dcf784c91d7a1ecdd4a58ad1e2609f8ba52e3..004c6f12db7db47d2d5a443a03447d9549954152 100644
--- a/doc/py_api.rst
+++ b/doc/py_api.rst
@@ -40,4 +40,4 @@ Detailed Information
 --------------------
 
 .. automodule:: bob.ip.facedetect
-.. automodule:: bob.ip.facedetect.mtcnn
+.. automodule:: bob.ip.facedetect.mtcnn
\ No newline at end of file
diff --git a/doc/tinyface.rst b/doc/tinyface.rst
new file mode 100644
index 0000000000000000000000000000000000000000..c29f1ecab65a9f1f36623999ec399694eb696a78
--- /dev/null
+++ b/doc/tinyface.rst
@@ -0,0 +1,31 @@
+
+.. _bob.ip.facedetect.tinyface:
+
+==============================
+ Face detection using TinyFace
+==============================
+
+This package comes with a TinyFace face detector. The Original Model is ``ResNet101`` 
+from https://github.com/peiyunh/tiny. Please check for more details on TinyFace. The 
+model is converted into MxNet Interface and the code used to implement the model are 
+from https://github.com/chinakook/hr101_mxnet.
+
+
+Implementation
+--------------
+
+See below for an example on how to use
+`bob.ip.facedetect.tinyface.TinyFacesDetector`:
+
+.. literalinclude:: plot/detect_faces_tinyface.py
+   :linenos:
+
+This face detector can be used for detecting single or multiple faces. If there are more than one face, the first entry of the returned annotation supposed to be the largest face in the image. 
+  
+  
+.. figure:: img/detect_faces_tinyface.png
+  :figwidth: 75%
+  :align: center
+  :alt: Multi-Face Detection results using TinyFace.
+
+  Multiple faces are detected by TinyFace.
diff --git a/requirements.txt b/requirements.txt
index c82bea78eb8e733e5aa1cf4ab3f9dc5b361ef38c..8d23d68d827832b2f84f460d90713fa7ce843287 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,4 +6,4 @@ bob.io.image
 bob.ip.base
 bob.ip.color
 scikit-image
-scipy
+scipy
\ No newline at end of file