From 92072b7d396c717ca4a593e2f18d2e526d48ec0d Mon Sep 17 00:00:00 2001
From: Guillaume HEUSCH <guillaume.heusch@idiap.ch>
Date: Tue, 4 Feb 2020 16:02:38 +0100
Subject: [PATCH] [database] fixed HQWMCA high-level DB, added tests

---
 bob/pad/face/config/hqwmca.py       |  7 +++++++
 bob/pad/face/database/hqwmca.py     |  7 +++----
 bob/pad/face/test/test_databases.py | 31 +++++++++++++++++++++++++++++
 setup.py                            |  1 +
 4 files changed, 42 insertions(+), 4 deletions(-)
 create mode 100644 bob/pad/face/config/hqwmca.py

diff --git a/bob/pad/face/config/hqwmca.py b/bob/pad/face/config/hqwmca.py
new file mode 100644
index 00000000..911b3cd9
--- /dev/null
+++ b/bob/pad/face/config/hqwmca.py
@@ -0,0 +1,7 @@
+#!/usr/bin/env python
+# encoding: utf-8
+
+from bob.pad.face.database import HQWMCAPadDatabase 
+from bob.extension import rc
+
+database = HQWMCAPadDatabase()
diff --git a/bob/pad/face/database/hqwmca.py b/bob/pad/face/database/hqwmca.py
index 0df009e1..ec33246d 100644
--- a/bob/pad/face/database/hqwmca.py
+++ b/bob/pad/face/database/hqwmca.py
@@ -131,7 +131,7 @@ class HQWMCAPadDatabase(PadDatabase):
 
     def objects(self,
                 groups=None,
-                protocol='grand_test',
+                protocol=None,
                 purposes=None,
                 model_ids=None,
                 attack_types=None,
@@ -170,7 +170,7 @@ class HQWMCAPadDatabase(PadDatabase):
         if not isinstance(groups, list) and groups is not None and groups is not str: 
           groups = list(groups)
 
-        files = self.db.objects(protocol=self.protocol,
+        files = self.db.objects(protocol=protocol,
                                 groups=groups,
                                 purposes=purposes,
                                 attacks=attack_types,
@@ -180,10 +180,9 @@ class HQWMCAPadDatabase(PadDatabase):
 
 
     def annotations(self, file):
-        """ Generate / retrieve annotations
+        """ retrieve annotations
         
         This function will retrieve annotations (if exisiting and provided).
-        Otherwise, it will generate them using the MTCNN landmarks detector.  
         
         """
         if self.annotations_dir is not None:
diff --git a/bob/pad/face/test/test_databases.py b/bob/pad/face/test/test_databases.py
index 5b1c6f2e..3f005680 100644
--- a/bob/pad/face/test/test_databases.py
+++ b/bob/pad/face/test/test_databases.py
@@ -232,6 +232,37 @@ def test_casiasurf():
             % e)
 
 
+# Test the HQ-WMCA database
+@db_available('hqwmca')
+def test_hqwmca():
+    hqwmca = bob.bio.base.load_resource(
+        'hqwmca',
+        'database',
+        preferred_package='bob.pad.face',
+        package_prefix='bob.pad.')
+    try:
+        assert len(hqwmca.objects(protocol='grand_test')) == 2904 
+        assert len(hqwmca.objects(protocol='impersonation')) == 1843
+        assert len(hqwmca.objects(protocol='obfuscation')) == 1616
+
+        assert len(hqwmca.objects(protocol='grand_test', groups=('train',))) == 970
+        assert len(hqwmca.objects(protocol='grand_test', groups=('dev',))) == 968
+        assert len(hqwmca.objects(protocol='grand_test', groups=('eval',))) == 966
+
+        assert len(hqwmca.objects(protocol='impersonation', groups=('train',))) == 612
+        assert len(hqwmca.objects(protocol='impersonation', groups=('dev',))) == 609
+        assert len(hqwmca.objects(protocol='impersonation', groups=('eval',))) == 622
+
+        assert len(hqwmca.objects(protocol='obfuscation', groups=('train',))) == 586
+        assert len(hqwmca.objects(protocol='obfuscation', groups=('dev',))) == 504
+        assert len(hqwmca.objects(protocol='obfuscation', groups=('eval',))) == 526
+
+    except IOError as e:
+        raise SkipTest(
+            "The database could not be queried; probably the db.sql3 file is missing. Here is the error: '%s'"
+            % e)
+
+
 @db_available('brsu')
 def test_brsu():
     brsu = bob.bio.base.load_resource(
diff --git a/setup.py b/setup.py
index 35e05f38..450551dd 100644
--- a/setup.py
+++ b/setup.py
@@ -79,6 +79,7 @@ setup(
             'casiasurf-color = bob.pad.face.config.casiasurf_color:database',
             'casiasurf = bob.pad.face.config.casiasurf:database',
             'brsu = bob.pad.face.config.brsu:database',
+            'hqwmca = bob.pad.face.config.hqwmca:database',
         ],
 
         # registered configurations:
-- 
GitLab