From 178e48d383dd6769d18e382ed315ab68d5e35cd7 Mon Sep 17 00:00:00 2001
From: ageorge <anjith.george@idiap.ch>
Date: Thu, 31 Jan 2019 14:59:52 +0100
Subject: [PATCH] unit tests, WIP

---
 bob/pad/face/database/batl.py       |  7 ++--
 bob/pad/face/test/test_databases.py | 60 +++++++++++++++++++++++++++++
 2 files changed, 64 insertions(+), 3 deletions(-)

diff --git a/bob/pad/face/database/batl.py b/bob/pad/face/database/batl.py
index 400fe8e3..850c276e 100644
--- a/bob/pad/face/database/batl.py
+++ b/bob/pad/face/database/batl.py
@@ -561,10 +561,11 @@ class BatlPadDatabase(PadDatabase):
         protocol = 'nowig' if protocol == 'grandtest' else protocol
 
         # Convert group names to low-level group names here.
-        groups = self.convert_names_to_lowlevel(
-            groups, self.low_level_group_names, self.high_level_group_names)
+        groups = self.convert_names_to_lowlevel(groups, self.low_level_group_names, self.high_level_group_names)
 
-        if not isinstance(groups, list) and groups is not None:  # if a single group is given make it a list
+            
+
+        if not isinstance(groups, list) and groups is not None and not isinstance(groups,str):  # if a single group is given make it a list
             groups = list(groups)
 
         if extra is not None and "join_train_dev" in extra:
diff --git a/bob/pad/face/test/test_databases.py b/bob/pad/face/test/test_databases.py
index 3ed90510..eebd65c2 100644
--- a/bob/pad/face/test/test_databases.py
+++ b/bob/pad/face/test/test_databases.py
@@ -219,3 +219,63 @@ def test_casiasurf():
         raise SkipTest(
             "The database could not be queried; probably the db.sql3 file is missing. Here is the error: '%s'"
             % e)
+
+# Test the BATL database
+def test_aggregated_db():
+    batl_db = bob.bio.base.load_resource(
+        'batl',
+        'database',
+        preferred_package='bob.pad.face',
+        package_prefix='bob.pad.')
+    try:
+
+        assert len(
+            batl_db.objects(groups=['train', 'dev', 'eval'])) == 1679
+        assert len(batl_db.objects(groups=['train', 'dev'])) == 1122
+        assert len(batl_db.objects(groups=['train'])) == 565
+
+        assert len(batl_db.objects(groups='train')) == 565
+        assert len(batl_db.objects(groups='dev')) == 557
+        assert len(batl_db.objects(groups='eval')) == 557
+
+        assert len(
+            batl_db.objects(
+                groups=['train', 'dev', 'eval'], protocol='grandtest')) == 1679
+        assert len(
+            batl_db.objects(
+                groups=['train', 'dev', 'eval'],
+                protocol='grandtest',
+                purposes='real')) == 347
+        assert len(
+            batl_db.objects(
+                groups=['train', 'dev', 'eval'],
+                protocol='grandtest',
+                purposes='attack')) == 1332
+        #tests for join_train_dev protocols
+        assert len(
+            batl_db.objects(
+                groups=['train', 'dev', 'eval'],
+                protocol='grandtest-color-50-join_train_dev')) == 1679
+        assert len(
+            batl_db.objects(
+                groups=['train', 'dev'], protocol='grandtest-color-50-join_train_dev')) == 1679
+        assert len(
+            batl_db.objects(groups='eval',
+                                  protocol='grandtest-color-50-join_train_dev')) == 557
+        # test for LOO_fakehead
+        assert len(
+            batl_db.objects(
+                groups=['train', 'dev', 'eval'],
+                protocol='grandtest-color-50-LOO_fakehead')) == 1149
+        assert len(
+            batl_db.objects(
+                groups=['train', 'dev'], protocol='grandtest-color-50-LOO_fakehead')) == 1017
+        assert len(
+            batl_db.objects(groups='eval',
+                                  protocol='grandtest-color-50-LOO_fakehead')) == 132
+
+
+    except IOError as e:
+        raise SkipTest(
+            "The database could not be queried; probably the db.sql3 file is missing. Here is the error: '%s'"
+            % e)
\ No newline at end of file
-- 
GitLab