Commit 2baa9fbc authored by Amir Mohammadi's avatar Amir Mohammadi
Browse files

Add tests for the verification protocol

parent 52e7d915
Pipeline #7108 passed with stages
in 13 minutes and 28 seconds
......@@ -8,198 +8,265 @@ import os
import sys
import unittest
from .query import Database
from .verificationprotocol import Database as VerificationDatabase
from .models import *
authenticate_str = 'authenticate'
if sys.version_info[0] < 3:
authenticate_str = authenticate_str.encode('utf8')
authenticate_str = authenticate_str.encode('utf8')
enroll_str = 'enroll'
if sys.version_info[0] < 3:
enroll_str = enroll_str.encode('utf8')
enroll_str = enroll_str.encode('utf8')
def db_available(test):
"""Decorator for detecting if OpenCV/Python bindings are available"""
from bob.io.base.test_utils import datafile
from nose.plugins.skip import SkipTest
import functools
"""Decorator for detecting if OpenCV/Python bindings are available"""
from bob.io.base.test_utils import datafile
from nose.plugins.skip import SkipTest
import functools
@functools.wraps(test)
def wrapper(*args, **kwargs):
dbfile = datafile("db.sql3", __name__, None)
if os.path.exists(dbfile):
return test(*args, **kwargs)
else:
raise SkipTest("The database file '%s' is not available; did you forget to run 'bob_dbmanage.py %s create' ?" % (dbfile, 'replaymobile'))
@functools.wraps(test)
def wrapper(*args, **kwargs):
dbfile = datafile("db.sql3", __name__, None)
if os.path.exists(dbfile):
return test(*args, **kwargs)
else:
raise SkipTest("The database file '%s' is not available; did you forget to run 'bob_dbmanage.py %s create' ?" % (
dbfile, 'replaymobile'))
return wrapper
return wrapper
class ReplayMobileDatabaseTest(unittest.TestCase):
"""Performs various tests on the replay attack database."""
"""Performs various tests on the replay attack database."""
@db_available
def test01_queryRealAccesses(self):
@db_available
def test01_queryRealAccesses(self):
db = Database()
f = db.objects(cls='real')
# self.assertEqual(len(f), 400) # Still have to capture 1 users (client009)
self.assertEqual(len(f), 390)
for v in f[:10]: # only the 10 first...
self.assertTrue(isinstance(v.get_realaccess(), RealAccess))
o = v.get_realaccess()
self.assertEqual(o.purpose, authenticate_str)
db = Database()
f = db.objects(cls='real')
# self.assertEqual(len(f), 400) # Still have to capture 1 users
# (client009)
self.assertEqual(len(f), 390)
for v in f[:10]: # only the 10 first...
self.assertTrue(isinstance(v.get_realaccess(), RealAccess))
o = v.get_realaccess()
self.assertEqual(o.purpose, authenticate_str)
train = db.objects(cls='real', groups='train')
self.assertEqual(len(train), 120)
train = db.objects(cls='real', groups='train')
self.assertEqual(len(train), 120)
dev = db.objects(cls='real', groups='devel')
# self.assertEqual(len(dev), 120) # Still have to capture 1 users (client009)
self.assertEqual(len(dev), 160)
dev = db.objects(cls='real', groups='devel')
# self.assertEqual(len(dev), 120) # Still have to capture 1 users
# (client009)
self.assertEqual(len(dev), 160)
test = db.objects(cls='real', groups='test')
self.assertEqual(len(test), 110)
test = db.objects(cls='real', groups='test')
self.assertEqual(len(test), 110)
# tests train, devel and test files are distinct
s = set(train + dev + test)
# self.assertEqual(len(s), 400) # Still have to capture 1 users (client009)
self.assertEqual(len(s), 390)
# tests train, devel and test files are distinct
s = set(train + dev + test)
# self.assertEqual(len(s), 400) # Still have to capture 1 users
# (client009)
self.assertEqual(len(s), 390)
@db_available
def queryAttackType(self, protocol, N):
db = Database()
f = db.objects(cls='attack', protocol=protocol)
@db_available
def queryAttackType(self, protocol, N):
db = Database()
f = db.objects(cls='attack', protocol=protocol)
self.assertEqual(len(f), N)
for k in f[:10]: # only the 10 first...
k.get_attack()
self.assertRaises(RuntimeError, k.get_realaccess)
self.assertEqual(len(f), N)
for k in f[:10]: # only the 10 first...
k.get_attack()
self.assertRaises(RuntimeError, k.get_realaccess)
train = db.objects(cls='attack', groups='train', protocol=protocol)
self.assertEqual(len(train), int(round(0.3 * N)))
train = db.objects(cls='attack', groups='train', protocol=protocol)
self.assertEqual(len(train), int(round(0.3 * N)))
dev = db.objects(cls='attack', groups='devel', protocol=protocol)
self.assertEqual(len(dev), int(round(0.4 * N)))
dev = db.objects(cls='attack', groups='devel', protocol=protocol)
self.assertEqual(len(dev), int(round(0.4 * N)))
test = db.objects(cls='attack', groups='test', protocol=protocol)
self.assertEqual(len(test), int(round(0.3 * N)))
test = db.objects(cls='attack', groups='test', protocol=protocol)
self.assertEqual(len(test), int(round(0.3 * N)))
# tests train, devel and test files are distinct
s = set(train + dev + test)
self.assertEqual(len(s), N)
# tests train, devel and test files are distinct
s = set(train + dev + test)
self.assertEqual(len(s), N)
@db_available
def test02_queryAttacks(self):
@db_available
def test02_queryAttacks(self):
self.queryAttackType('grandtest', 640)
self.queryAttackType('grandtest', 640)
@db_available
def test03_queryPrintAttacks(self):
@db_available
def test03_queryPrintAttacks(self):
self.queryAttackType('print', 320)
self.queryAttackType('print', 320)
@db_available
def test04_queryMattescreenAttacks(self):
@db_available
def test04_queryMattescreenAttacks(self):
self.queryAttackType('mattescreen', 320)
self.queryAttackType('mattescreen', 320)
@db_available
def test05_queryEnrollments(self):
@db_available
def test05_queryEnrollments(self):
db = Database()
f = db.objects(cls='enroll')
self.assertEqual(len(f), 160) # 40 clients, 2 conditions 2 devices
for v in f:
self.assertEqual(v.get_realaccess().purpose, enroll_str)
db = Database()
f = db.objects(cls='enroll')
self.assertEqual(len(f), 160) # 40 clients, 2 conditions 2 devices
for v in f:
self.assertEqual(v.get_realaccess().purpose, enroll_str)
@db_available
def test06_queryClients(self):
@db_available
def test06_queryClients(self):
db = Database()
f = db.clients()
self.assertEqual(len(f), 40) # 40 clients
self.assertTrue(db.has_client_id(3))
self.assertTrue(db.has_client_id(40))
self.assertTrue(db.has_client_id(6))
self.assertTrue(db.has_client_id(21))
self.assertTrue(db.has_client_id(30))
self.assertFalse(db.has_client_id(0))
self.assertFalse(db.has_client_id(50))
self.assertFalse(db.has_client_id(60))
self.assertFalse(db.has_client_id(55))
@db_available
def test7_queryfacefile(self):
db = Database()
o = db.objects(clients=(1,))[0]
o.facefile()
@db_available
def test8_manage_files(self):
from bob.db.base.script.dbmanage import main
self.assertEqual(main('replaymobile files'.split()), 0)
@db_available
def test9_manage_dumplist_1(self):
from bob.db.base.script.dbmanage import main
self.assertEqual(main('replaymobile dumplist --self-test'.split()), 0)
@db_available
def test10_manage_dumplist_2(self):
from bob.db.base.script.dbmanage import main
self.assertEqual(main('replaymobile dumplist --class=attack --group=devel --support=hand --protocol=print --self-test'.split()), 0)
@db_available
def test11_manage_dumplist_client(self):
from bob.db.base.script.dbmanage import main
self.assertEqual(main('replaymobile dumplist --client 1 --self-test'.split()), 0)
@db_available
def test12_manage_checkfiles(self):
from bob.db.base.script.dbmanage import main
self.assertEqual(main('replaymobile checkfiles --self-test'.split()), 0)
@db_available
def test13_queryRealAccesses(self):
db = Database()
trainClients = ['001', '003', '008', '011', '012', '016', '018', '020', '026', '034', '037', '038']
develClients = ['005', '006', '013', '014', '015', '023', '024', '025', '028', '029', '031', '032', '035', '036', '039', '040']
testClients = ['002', '004', '007', '009', '010', '017', '019', '021', '022', '027', '030', '033']
f = db.objects(cls='real')
self.assertEqual(len(f), 390)
train = db.objects(cls='real', groups='train')
self.assertEqual(len(train), 120)
for cl in train:
clFilename = cl.videofile("")
clientPos = clFilename.find('client')
clientId = clFilename[clientPos + 6:clientPos + 9]
self.assertTrue(clientId in trainClients)
dev = db.objects(cls='real', groups='devel')
self.assertEqual(len(dev), 160)
for cl in dev:
clFilename = cl.videofile("")
clientPos = clFilename.find('client')
clientId = clFilename[clientPos + 6:clientPos + 9]
self.assertTrue(clientId in develClients)
test = db.objects(cls='real', groups='test')
self.assertEqual(len(test), 110)
for cl in test:
clFilename = cl.videofile("")
clientPos = clFilename.find('client')
clientId = clFilename[clientPos + 6:clientPos + 9]
self.assertTrue(clientId in testClients)
db = Database()
f = db.clients()
self.assertEqual(len(f), 40) # 40 clients
self.assertTrue(db.has_client_id(3))
self.assertTrue(db.has_client_id(40))
self.assertTrue(db.has_client_id(6))
self.assertTrue(db.has_client_id(21))
self.assertTrue(db.has_client_id(30))
self.assertFalse(db.has_client_id(0))
self.assertFalse(db.has_client_id(50))
self.assertFalse(db.has_client_id(60))
self.assertFalse(db.has_client_id(55))
@db_available
def test7_queryfacefile(self):
db = Database()
o = db.objects(clients=(1,))[0]
o.facefile()
@db_available
def test8_manage_files(self):
from bob.db.base.script.dbmanage import main
self.assertEqual(main('replaymobile files'.split()), 0)
@db_available
def test9_manage_dumplist_1(self):
from bob.db.base.script.dbmanage import main
self.assertEqual(main('replaymobile dumplist --self-test'.split()), 0)
@db_available
def test10_manage_dumplist_2(self):
from bob.db.base.script.dbmanage import main
self.assertEqual(main(
'replaymobile dumplist --class=attack --group=devel --support=hand --protocol=print --self-test'.split()), 0)
@db_available
def test11_manage_dumplist_client(self):
from bob.db.base.script.dbmanage import main
self.assertEqual(
main('replaymobile dumplist --client 1 --self-test'.split()), 0)
@db_available
def test12_manage_checkfiles(self):
from bob.db.base.script.dbmanage import main
self.assertEqual(
main('replaymobile checkfiles --self-test'.split()), 0)
@db_available
def test13_queryRealAccesses(self):
db = Database()
trainClients = ['001', '003', '008', '011', '012',
'016', '018', '020', '026', '034', '037', '038']
develClients = ['005', '006', '013', '014', '015', '023', '024',
'025', '028', '029', '031', '032', '035', '036',
'039', '040']
testClients = ['002', '004', '007', '009', '010',
'017', '019', '021', '022', '027', '030', '033']
f = db.objects(cls='real')
self.assertEqual(len(f), 390)
train = db.objects(cls='real', groups='train')
self.assertEqual(len(train), 120)
for cl in train:
clFilename = cl.videofile("")
clientPos = clFilename.find('client')
clientId = clFilename[clientPos + 6:clientPos + 9]
self.assertTrue(clientId in trainClients)
dev = db.objects(cls='real', groups='devel')
self.assertEqual(len(dev), 160)
for cl in dev:
clFilename = cl.videofile("")
clientPos = clFilename.find('client')
clientId = clFilename[clientPos + 6:clientPos + 9]
self.assertTrue(clientId in develClients)
test = db.objects(cls='real', groups='test')
self.assertEqual(len(test), 110)
for cl in test:
clFilename = cl.videofile("")
clientPos = clFilename.find('client')
clientId = clFilename[clientPos + 6:clientPos + 9]
self.assertTrue(clientId in testClients)
def test_verification_protocol():
nframes = 10
db = VerificationDatabase(max_number_of_frames=nframes)
# default is licit protocol
files = db.objects()
assert len(files) == 550 * nframes
clients = list(set(f.client_id for f in files))
model_ids = db.model_ids_with_protocol()
assert len(clients) == 40
assert set(clients) == set(model_ids)
# make sure all files are real
assert all(f._f.is_real() for f in files)
for client in clients:
files = db.objects(model_ids=client, purposes='enroll')
assert len(files) == nframes * 4
for f in files:
# make sure to enroll with the same id
assert client == f.client_id
# make sure they are all real
assert f._f.is_real()
# check probe files
files = db.objects(model_ids=client, purposes='probe')
# make sure to probe against all clients
assert len(files) == nframes * 39 * 10 # nframes frames 39 clients
for f in files:
assert f._f.is_real()
# check the spoof protocol
files = db.objects(protocol='grandtest-spoof')
assert len(files) == 800 * nframes
model_ids = db.model_ids_with_protocol(protocol='grandtest-spoof')
assert len(model_ids) == 40
# all enroll files: real, laptop, same id
# all probe files: attack, same id and attack client_id, all qualities,
# against the same id
for client in clients:
files = db.objects(protocol='grandtest-spoof',
model_ids=client, purposes='enroll')
assert len(files) == nframes * 4
for f in files:
# make sure to enroll with the same id
assert client == f.client_id
# make sure they are from laptop
assert f._f.is_real()
# check probe files
files = db.objects(protocol='grandtest-spoof',
model_ids=client, purposes='probe')
# make sure to probe against only the same client
assert len(files) == nframes * 16
for f in files:
assert not f._f.is_real()
assert 'attack' in f.client_id
assert client == f._f.client_id
......@@ -13,17 +13,28 @@ from .query import Database as LDatabase
def selected_indices(total_number_of_indices, desired_number_of_indices=None):
"""Returns a list of indices that will contain exactly the number of desired indices (or the number of total items in the list, if this is smaller).
These indices are selected such that they are evenly spread over the whole sequence."""
if desired_number_of_indices is None or desired_number_of_indices >= total_number_of_indices or desired_number_of_indices < 0:
return range(total_number_of_indices)
increase = float(total_number_of_indices) / float(desired_number_of_indices)
# generate a regular quasi-random index list
return [int((i + .5) * increase) for i in range(desired_number_of_indices)]
"""
Returns a list of indices that will contain exactly the number of desired
indices (or the number of total items in the list, if this is smaller).
These indices are selected such that they are evenly spread over the whole
sequence.
"""
if desired_number_of_indices is None or \
desired_number_of_indices >= total_number_of_indices or \
desired_number_of_indices < 0:
return range(total_number_of_indices)
increase = float(total_number_of_indices) / \
float(desired_number_of_indices)
# generate a regular quasi-random index list
return [int((i + .5) * increase) for i in range(desired_number_of_indices)]
class File(BaseFile):
"""Replay Mobile low-level file used for vulnerability analysis in face recognition"""
"""
Replay Mobile low-level file used for vulnerability analysis in face
recognition
"""
def __init__(self, f, framen=None):
self._f = f
......@@ -78,19 +89,26 @@ class Database(BaseDatabase):
def groups(self):
return self.convert_names_to_highlevel(
self._db.groups(), self.low_level_group_names, self.high_level_group_names)
self._db.groups(), self.low_level_group_names,
self.high_level_group_names)
def model_ids_with_protocol(self, groups=None, protocol=None, **kwargs):
# since the low-level API does not support verification straight-forward-ly, we improvise.
files = self.objects(groups=groups, protocol=protocol, purposes='enroll', **kwargs)
# since the low-level API does not support verification
# straight-forward-ly, we improvise.
files = self.objects(groups=groups, protocol=protocol,
purposes='enroll', **kwargs)
return sorted(set(f.client_id for f in files))
def objects(self, groups=None, protocol=None, purposes=None, model_ids=None, **kwargs):
def objects(self, groups=None, protocol=None, purposes=None,
model_ids=None, **kwargs):
if protocol == '.':
protocol = None
protocol = self.check_parameter_for_validity(protocol, "protocol", self.protocol_names(), 'grandtest-licit')
groups = self.check_parameters_for_validity(groups, "group", self.groups(), self.groups())
purposes = self.check_parameters_for_validity(purposes, "purpose", ('enroll', 'probe'), ('enroll', 'probe'))
protocol = self.check_parameter_for_validity(
protocol, "protocol", self.protocol_names(), 'grandtest-licit')
groups = self.check_parameters_for_validity(
groups, "group", self.groups(), self.groups())
purposes = self.check_parameters_for_validity(
purposes, "purpose", ('enroll', 'probe'), ('enroll', 'probe'))
purposes = list(purposes)
groups = self.convert_names_to_lowlevel(
groups, self.low_level_group_names, self.high_level_group_names)
......@@ -107,28 +125,32 @@ class Database(BaseDatabase):
purposes.remove('probe')
purposes.append('real')
if len(purposes) == 1:
# making the model_ids to None will return all clients which make
# the impostor data also available.
# making the model_ids to None will return all clients
# which make the impostor data also available.
model_ids = None
elif model_ids:
raise NotImplementedError(
'Currently returning both enroll and probe for specific '
'client(s) in the licit protocol is not supported. '
'Please specify one purpose only.')
'Currently returning both enroll and probe for '
'specific client(s) in the licit protocol is not '
'supported. Please specify one purpose only.')
elif '-spoof' in protocol:
protocol = protocol.replace('-spoof', '')
# you need to replace probe with attack and real for the spoof protocols.
# You can add the real here also to create positives scores also
# but usually you get these scores when you run the licit protocol
# you need to replace probe with attack and real for the spoof
# protocols. You can add the real here also to create positives
# scores also but usually you get these scores when you run the
# licit protocol
if 'probe' in purposes:
purposes.remove('probe')
purposes.append('attack')
# now, query the actual Replay database
objects = self._db.objects(groups=groups, protocol=protocol, cls=purposes, clients=model_ids, **kwargs)
objects = self._db.objects(
groups=groups, protocol=protocol, cls=purposes, clients=model_ids,
**kwargs)
# make sure to return File representation of a file, not the database one
# also make sure you replace client ids with attack
# make sure to return File representation of a file, not the database
# one also make sure you replace client ids with attack
retval = []
for f in objects:
for i in self.indices:
......@@ -136,6 +158,8 @@ class Database(BaseDatabase):
retval.append(File(f, i))
else:
temp = File(f, i)
temp.client_id = 'attack'
attack = f.get_attack()
temp.client_id = 'attack/{}'.format(
attack.attack_device, attack.attack_support)
retval.append(temp)
return retval
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment