From d39402f9c737281cc814f9fa5d4a275a3e82fa38 Mon Sep 17 00:00:00 2001
From: Amir MOHAMMADI <amir.mohammadi@idiap.ch>
Date: Thu, 25 Mar 2021 16:12:26 +0100
Subject: [PATCH] Add SWAN database

---
 bob/pad/face/config/swan.py         | 32 +++++++++++++
 bob/pad/face/database/__init__.py   |  2 +
 bob/pad/face/database/database.py   |  7 ++-
 bob/pad/face/database/swan.py       | 56 ++++++++++++++++++++++
 bob/pad/face/test/test_databases.py | 73 ++++++++++++++++++++++++-----
 doc/links.rst                       |  1 +
 setup.py                            |  4 ++
 7 files changed, 161 insertions(+), 14 deletions(-)
 create mode 100644 bob/pad/face/config/swan.py
 create mode 100644 bob/pad/face/database/swan.py

diff --git a/bob/pad/face/config/swan.py b/bob/pad/face/config/swan.py
new file mode 100644
index 00000000..5f1cad9f
--- /dev/null
+++ b/bob/pad/face/config/swan.py
@@ -0,0 +1,32 @@
+"""The Swan_ Database.
+
+To configure the location of the database on your computer, run::
+
+    bob config set bob.db.swan.directory /path/to/swan/database
+
+
+The Idiap part of the dataset comprises 150 subjects that are captured in six
+different sessions reflecting real-life scenarios of smartphone assisted
+authentication. One of the unique features of this dataset is that it is
+collected in four different geographic locations representing a diverse
+population and ethnicity. Additionally, it also contains a multimodal
+Presentation Attack (PA) or spoofing dataset using low-cost Presentation Attack
+Instruments (PAI) such as print and electronic display attacks. The novel
+acquisition protocols and the diversity of the data subjects collected from
+different geographic locations will allow developing a novel algorithm for
+either unimodal or multimodal biometrics.
+
+PAD protocols are created according to the SWAN-PAD-protocols document.
+Bona-fide session 2 data is split into 3 sets of training, development, and
+evaluation. The bona-fide data from sessions 3,4,5,6 are used for evaluation as
+well. PA samples are randomly split into 3 sets of training, development, and
+evaluation. All the random splits are done 10 times to create 10 different
+protocols. The PAD protocols contain only one type of attack. For convenience,
+PA_F and PA_V protocols are created for face and voice, respectively, which
+contain all the attacks.
+
+.. include:: links.rst
+"""
+from bob.pad.face.database import SwanPadDatabase
+
+database = SwanPadDatabase()
diff --git a/bob/pad/face/database/__init__.py b/bob/pad/face/database/__init__.py
index 1318192d..082597cb 100644
--- a/bob/pad/face/database/__init__.py
+++ b/bob/pad/face/database/__init__.py
@@ -5,6 +5,7 @@ from .casiasurf import CasiaSurfPadDatabase
 from .maskattack import MaskAttackPadDatabase
 from .replay_attack import ReplayAttackPadDatabase
 from .replay_mobile import ReplayMobilePadDatabase
+from .swan import SwanPadDatabase
 
 
 # gets sphinx autodoc done right - don't remove it
@@ -30,6 +31,7 @@ __appropriate__(
     MaskAttackPadDatabase,
     CasiaSurfPadDatabase,
     CasiaFasdPadDatabase,
+    SwanPadDatabase,
 )
 
 __all__ = [_ for _ in dir() if not _.startswith('_')]
diff --git a/bob/pad/face/database/database.py b/bob/pad/face/database/database.py
index b6450063..d4323f00 100644
--- a/bob/pad/face/database/database.py
+++ b/bob/pad/face/database/database.py
@@ -17,6 +17,7 @@ def delayed_video_load(
     max_number_of_frames=None,
     step_size=None,
     get_transform=None,
+    keep_extension_for_annotation=False,
 ):
     if get_transform is None:
         def get_transform(x):
@@ -37,7 +38,9 @@ def delayed_video_load(
         )
         annotations, delayed_attributes = None, None
         if annotation_directory:
-            path = os.path.splitext(sample.filename)[0]
+            path = sample.filename
+            if not keep_extension_for_annotation:
+                path = os.path.splitext(sample.filename)[0]
             delayed_annotations = partial(
                 read_annotation_file,
                 file_name=f"{annotation_directory}:{path}.json",
@@ -63,6 +66,7 @@ def VideoPadSample(
     max_number_of_frames=None,
     step_size=None,
     get_transform=None,
+    keep_extension_for_annotation=False,
 ):
     return FunctionTransformer(
         delayed_video_load,
@@ -74,6 +78,7 @@ def VideoPadSample(
             max_number_of_frames=max_number_of_frames,
             step_size=step_size,
             get_transform=get_transform,
+            keep_extension_for_annotation=keep_extension_for_annotation,
         ),
     )
 
diff --git a/bob/pad/face/database/swan.py b/bob/pad/face/database/swan.py
new file mode 100644
index 00000000..5d0bb75e
--- /dev/null
+++ b/bob/pad/face/database/swan.py
@@ -0,0 +1,56 @@
+import logging
+
+from bob.extension import rc
+from bob.extension.download import get_file
+from bob.pad.base.database import FileListPadDatabase
+from bob.pad.face.database import VideoPadSample
+
+logger = logging.getLogger(__name__)
+
+
+def SwanPadDatabase(
+    protocol="pad_p2_face_f1",
+    selection_style=None,
+    max_number_of_frames=None,
+    step_size=None,
+    annotation_directory=None,
+    annotation_type=None,
+    fixed_positions=None,
+    **kwargs,
+):
+    """Return a ``FileListPadDatabase`` for the SWAN face PAD protocols.
+
+    Without ``annotation_directory``, the ``annotations-swan-mtcnn.tar.xz``
+    archive is fetched and ``annotation_type`` is set to ``"eyes-center"``.
+    """
+
+    def fetch(name, cache_subdir):
+        # file_hash not passed; commented-out values "a8e31cc3"/"3ecbfa3c" unverified.
+        return get_file(
+            name,
+            [f"http://www.idiap.ch/software/bob/data/bob/bob.pad.face/{name}"],
+            cache_subdir=cache_subdir,
+        )
+
+    dataset_protocols_path = fetch("pad-face-swan.tar.gz", "datasets")
+
+    if annotation_directory is None:
+        annotation_directory = fetch("annotations-swan-mtcnn.tar.xz", "annotations")
+        annotation_type = "eyes-center"
+
+    transformer = VideoPadSample(
+        original_directory=rc.get("bob.db.swan.directory"),
+        annotation_directory=annotation_directory,
+        selection_style=selection_style,
+        max_number_of_frames=max_number_of_frames,
+        step_size=step_size,
+        # The annotation lookup then keeps the file extension before ".json".
+        keep_extension_for_annotation=True,
+    )
+
+    database = FileListPadDatabase(
+        dataset_protocols_path, protocol, transformer=transformer, **kwargs
+    )
+    database.annotation_type = annotation_type
+    database.fixed_positions = fixed_positions
+    return database
diff --git a/bob/pad/face/test/test_databases.py b/bob/pad/face/test/test_databases.py
index bb4d5129..c1162b24 100644
--- a/bob/pad/face/test/test_databases.py
+++ b/bob/pad/face/test/test_databases.py
@@ -15,17 +15,26 @@ def test_replayattack():
         package_prefix="bob.pad.",
     )
 
-    assert database.protocols() == ['digitalphoto', 'grandtest', 'highdef', 'mobile', 'photo', 'print', 'smalltest', 'video']
+    assert database.protocols() == [
+        "digitalphoto",
+        "grandtest",
+        "highdef",
+        "mobile",
+        "photo",
+        "print",
+        "smalltest",
+        "video",
+    ]
     assert database.groups() == ["dev", "eval", "train"]
     assert len(database.samples(groups=["train", "dev", "eval"])) == 1200
     assert len(database.samples(groups=["train", "dev"])) == 720
     assert len(database.samples(groups=["train"])) == 360
-    assert len(database.samples(groups=["train", "dev", "eval"])) == 1200
     assert (
         len(database.samples(groups=["train", "dev", "eval"], purposes="real")) == 200
     )
     assert (
-        len(database.samples(groups=["train", "dev", "eval"], purposes="attack")) == 1000
+        len(database.samples(groups=["train", "dev", "eval"], purposes="attack"))
+        == 1000
     )
 
     sample = database.sort(database.samples())[0]
@@ -45,7 +54,6 @@ def test_replayattack():
         raise SkipTest(e)
 
 
-
 def test_replaymobile():
     database = bob.bio.base.load_resource(
         "replay-mobile",
@@ -59,7 +67,6 @@ def test_replaymobile():
     assert len(database.samples(groups=["train", "dev", "eval"])) == 1030
     assert len(database.samples(groups=["train", "dev"])) == 728
     assert len(database.samples(groups=["train"])) == 312
-    assert len(database.samples(groups=["train", "dev", "eval"])) == 1030
     assert (
         len(database.samples(groups=["train", "dev", "eval"], purposes="real")) == 390
     )
@@ -94,8 +101,7 @@ def test_maskattack():
     )
     # all real sequences: 2 sessions, 5 recordings for 17 individuals
     assert (
-        len(maskattack.samples(groups=["train", "dev", "eval"], purposes="real"))
-        == 170
+        len(maskattack.samples(groups=["train", "dev", "eval"], purposes="real")) == 170
     )
     # all attacks: 1 session, 5 recordings for 17 individuals
     assert (
@@ -139,6 +145,7 @@ def test_maskattack():
 #         == 57710
 #     )
 
+
 def test_casiasurf_color_protocol():
     casiasurf = bob.bio.base.load_resource(
         "casiasurf-color",
@@ -150,18 +157,14 @@ def test_casiasurf_color_protocol():
     assert len(casiasurf.samples(groups=["train"], purposes="attack")) == 20324
     assert len(casiasurf.samples(groups=("dev",), purposes=("real",))) == 2994
     assert len(casiasurf.samples(groups=("dev",), purposes=("attack",))) == 6614
-    assert (
-        len(casiasurf.samples(groups=("dev",), purposes=("real", "attack"))) == 9608
-    )
+    assert len(casiasurf.samples(groups=("dev",), purposes=("real", "attack"))) == 9608
     assert len(casiasurf.samples(groups=("eval",), purposes=("real",))) == 17458
     assert len(casiasurf.samples(groups=("eval",), purposes=("attack",))) == 40252
     assert (
-        len(casiasurf.samples(groups=("eval",), purposes=("real", "attack")))
-        == 57710
+        len(casiasurf.samples(groups=("eval",), purposes=("real", "attack"))) == 57710
     )
 
 
-
 def test_casia_fasd():
     casia_fasd = bob.bio.base.load_resource(
         "casiafasd",
@@ -177,3 +180,47 @@ def test_casia_fasd():
     assert len(casia_fasd.samples(groups="train")) == 180
     assert len(casia_fasd.samples(groups="dev")) == 60
     assert len(casia_fasd.samples(groups="eval")) == 360
+
+
+def test_swan():
+    """Check SWAN protocol names, sample counts, and the first sorted sample."""
+    database = bob.bio.base.load_resource(
+        "swan",
+        "database",
+        preferred_package="bob.pad.face",
+        package_prefix="bob.pad.",
+    )
+    every_group = ["train", "dev", "eval"]
+
+    assert database.protocols() == [
+        "pad_p2_face_f1",
+        "pad_p2_face_f2",
+        "pad_p2_face_f3",
+        "pad_p2_face_f4",
+        "pad_p2_face_f5",
+    ]
+    assert database.groups() == ["dev", "eval", "train"]
+    # Sample counts per group selection, then per purpose over all groups.
+    assert len(database.samples(groups=every_group)) == 5802
+    assert len(database.samples(groups=["train", "dev"])) == 2803
+    assert len(database.samples(groups=["train"])) == 2001
+    assert len(database.samples(groups=every_group, purposes="real")) == 3300
+    assert len(database.samples(groups=every_group, purposes="attack")) == 2502
+
+    first_sample = database.sort(database.samples())[0]
+    expected_annotations = {
+        "bottomright": [849, 564],
+        "leye": [511, 453],
+        "mouthleft": [709, 271],
+        "mouthright": [711, 445],
+        "nose": [590, 357],
+        "reye": [510, 265],
+        "topleft": [301, 169],
+    }
+    try:
+        # A RuntimeError raised while reading the sample skips the test.
+        assert dict(first_sample.annotations["0"]) == expected_annotations
+        assert first_sample.data.shape == (20, 3, 720, 1280)
+        assert first_sample.data[0][0, 0, 0] == 87
+    except RuntimeError as e:
+        raise SkipTest(e)
diff --git a/doc/links.rst b/doc/links.rst
index cd652a07..09f5c2d7 100644
--- a/doc/links.rst
+++ b/doc/links.rst
@@ -14,3 +14,4 @@
 .. _dependencies: https://gitlab.idiap.ch/bob/bob/wikis/Dependencies
 .. _MIFS: http://www.antitza.com/makeup-datasets.html
 .. _CELEBA: http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html
+.. _Swan: https://www.idiap.ch/dataset/swan
diff --git a/setup.py b/setup.py
index 088ed775..10450c19 100644
--- a/setup.py
+++ b/setup.py
@@ -62,6 +62,7 @@ setup(
             "maskattack = bob.pad.face.config.maskattack:database",
             "casiasurf-color = bob.pad.face.config.casiasurf_color:database",
             "casiasurf = bob.pad.face.config.casiasurf:database",
+            "swan = bob.pad.face.config.swan:database",
         ],
         # registered configurations:
         "bob.pad.config": [
@@ -70,6 +71,9 @@ setup(
             "replay-mobile = bob.pad.face.config.replay_mobile",
             "casiafasd = bob.pad.face.config.casiafasd",
             "maskattack = bob.pad.face.config.maskattack",
+            "casiasurf-color = bob.pad.face.config.casiasurf_color",
+            "casiasurf = bob.pad.face.config.casiasurf",
+            "swan = bob.pad.face.config.swan",
             # LBPs
             "lbp = bob.pad.face.config.lbp_64",
             # quality measure
-- 
GitLab