Skip to content
Snippets Groups Projects
Commit 178e48d3 authored by Anjith GEORGE's avatar Anjith GEORGE
Browse files

unit tests, WIP

parent 28e747cc
No related branches found
No related tags found
1 merge request!86Batl loo protocols
Pipeline #26469 failed
...@@ -561,10 +561,11 @@ class BatlPadDatabase(PadDatabase): ...@@ -561,10 +561,11 @@ class BatlPadDatabase(PadDatabase):
protocol = 'nowig' if protocol == 'grandtest' else protocol protocol = 'nowig' if protocol == 'grandtest' else protocol
# Convert group names to low-level group names here. # Convert group names to low-level group names here.
groups = self.convert_names_to_lowlevel( groups = self.convert_names_to_lowlevel(groups, self.low_level_group_names, self.high_level_group_names)
groups, self.low_level_group_names, self.high_level_group_names)
if not isinstance(groups, list) and groups is not None: # if a single group is given make it a list
if not isinstance(groups, list) and groups is not None and not isinstance(groups,str): # if a single group is given make it a list
groups = list(groups) groups = list(groups)
if extra is not None and "join_train_dev" in extra: if extra is not None and "join_train_dev" in extra:
......
...@@ -219,3 +219,63 @@ def test_casiasurf(): ...@@ -219,3 +219,63 @@ def test_casiasurf():
raise SkipTest( raise SkipTest(
"The database could not be queried; probably the db.sql3 file is missing. Here is the error: '%s'" "The database could not be queried; probably the db.sql3 file is missing. Here is the error: '%s'"
% e) % e)
# Test the BATL database
def test_aggregated_db():
batl_db = bob.bio.base.load_resource(
'batl',
'database',
preferred_package='bob.pad.face',
package_prefix='bob.pad.')
try:
assert len(
batl_db.objects(groups=['train', 'dev', 'eval'])) == 1679
assert len(batl_db.objects(groups=['train', 'dev'])) == 1122
assert len(batl_db.objects(groups=['train'])) == 565
assert len(batl_db.objects(groups='train')) == 565
assert len(batl_db.objects(groups='dev')) == 557
assert len(batl_db.objects(groups='eval')) == 557
assert len(
batl_db.objects(
groups=['train', 'dev', 'eval'], protocol='grandtest')) == 1679
assert len(
batl_db.objects(
groups=['train', 'dev', 'eval'],
protocol='grandtest',
purposes='real')) == 347
assert len(
batl_db.objects(
groups=['train', 'dev', 'eval'],
protocol='grandtest',
purposes='attack')) == 1332
#tests for join_train_dev protocols
assert len(
batl_db.objects(
groups=['train', 'dev', 'eval'],
protocol='grandtest-color-50-join_train_dev')) == 1679
assert len(
batl_db.objects(
groups=['train', 'dev'], protocol='grandtest-color-50-join_train_dev')) == 1679
assert len(
batl_db.objects(groups='eval',
protocol='grandtest-color-50-join_train_dev')) == 557
# test for LOO_fakehead
assert len(
batl_db.objects(
groups=['train', 'dev', 'eval'],
protocol='grandtest-color-50-LOO_fakehead')) == 1149
assert len(
batl_db.objects(
groups=['train', 'dev'], protocol='grandtest-color-50-LOO_fakehead')) == 1017
assert len(
batl_db.objects(groups='eval',
protocol='grandtest-color-50-LOO_fakehead')) == 132
except IOError as e:
raise SkipTest(
"The database could not be queried; probably the db.sql3 file is missing. Here is the error: '%s'"
% e)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment