Skip to content
Snippets Groups Projects
Commit ad1160e1 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

Simplified API to the Replay Attack Database

parent 59765985
No related branches found
No related tags found
No related merge requests found
...@@ -25,5 +25,6 @@ on your references: ...@@ -25,5 +25,6 @@ on your references:
""" """
from .query import Database from .query import Database
from .models import Client, File, Protocol, RealAccess, Attack
__all__ = ['Database'] __all__ = dir()
...@@ -18,9 +18,7 @@ def checkfiles(args): ...@@ -18,9 +18,7 @@ def checkfiles(args):
from .query import Database from .query import Database
db = Database() db = Database()
r = db.files( r = db.objects(
directory=args.directory,
extension=args.extension,
protocol=args.protocol, protocol=args.protocol,
support=args.support, support=args.support,
groups=args.group, groups=args.group,
...@@ -30,11 +28,13 @@ def checkfiles(args): ...@@ -30,11 +28,13 @@ def checkfiles(args):
) )
# go through all files, check if they are available on the filesystem # go through all files, check if they are available on the filesystem
good = {} good = []
bad = {} bad = []
for id, f in r.items(): for f in r:
if os.path.exists(f): good[id] = f if os.path.exists(f.make_path(args.directory, args.extension)):
else: bad[id] = f good.append(f)
else:
bad.append(f)
# report # report
output = sys.stdout output = sys.stdout
...@@ -43,8 +43,8 @@ def checkfiles(args): ...@@ -43,8 +43,8 @@ def checkfiles(args):
output = null() output = null()
if bad: if bad:
for id, f in bad.items(): for f in bad:
output.write('Cannot find file "%s"\n' % (f,)) output.write('Cannot find file "%s"\n' % (f.make_path(args.directory, args.extension),))
output.write('%d files (out of %d) were not found at "%s"\n' % \ output.write('%d files (out of %d) were not found at "%s"\n' % \
(len(bad), len(r), args.directory)) (len(bad), len(r), args.directory))
...@@ -65,8 +65,8 @@ def add_command(subparsers): ...@@ -65,8 +65,8 @@ def add_command(subparsers):
protocols = ('waiting','for','database','creation') protocols = ('waiting','for','database','creation')
clients = tuple() clients = tuple()
else: else:
protocols = db.protocols() protocols = [k.name for k in db.protos()]
clients = db.clients() clients = [k.id for k in db.clients()]
parser.add_argument('-d', '--directory', dest="directory", default='', help="if given, this path will be prepended to every entry checked (defaults to '%(default)s')") parser.add_argument('-d', '--directory', dest="directory", default='', help="if given, this path will be prepended to every entry checked (defaults to '%(default)s')")
parser.add_argument('-e', '--extension', dest="extension", default='', help="if given, this extension will be appended to every entry checked (defaults to '%(default)s')") parser.add_argument('-e', '--extension', dest="extension", default='', help="if given, this extension will be appended to every entry checked (defaults to '%(default)s')")
......
...@@ -18,9 +18,7 @@ def dumplist(args): ...@@ -18,9 +18,7 @@ def dumplist(args):
from .query import Database from .query import Database
db = Database() db = Database()
r = db.files( r = db.objects(
directory=args.directory,
extension=args.extension,
protocol=args.protocol, protocol=args.protocol,
support=args.support, support=args.support,
groups=args.group, groups=args.group,
...@@ -34,8 +32,8 @@ def dumplist(args): ...@@ -34,8 +32,8 @@ def dumplist(args):
from bob.db.utils import null from bob.db.utils import null
output = null() output = null()
for id, f in r.items(): for f in r:
output.write('%s\n' % (f,)) output.write('%s\n' % (f.make_path(args.directory, args.extension),))
return 0 return 0
...@@ -54,8 +52,8 @@ def add_command(subparsers): ...@@ -54,8 +52,8 @@ def add_command(subparsers):
protocols = ('waiting','for','database','creation') protocols = ('waiting','for','database','creation')
clients = tuple() clients = tuple()
else: else:
protocols = db.protocols() protocols = [k.name for k in db.protos()]
clients = db.clients() clients = [k.id for k in db.clients()]
parser.add_argument('-d', '--directory', dest="directory", default='', help="if given, this path will be prepended to every entry returned (defaults to '%(default)s')") parser.add_argument('-d', '--directory', dest="directory", default='', help="if given, this path will be prepended to every entry returned (defaults to '%(default)s')")
parser.add_argument('-e', '--extension', dest="extension", default='', help="if given, this extension will be appended to every entry returned (defaults to '%(default)s')") parser.add_argument('-e', '--extension', dest="extension", default='', help="if given, this extension will be appended to every entry returned (defaults to '%(default)s')")
......
...@@ -6,40 +6,61 @@ ...@@ -6,40 +6,61 @@
"""Table models and functionality for the Replay Attack DB. """Table models and functionality for the Replay Attack DB.
""" """
import os
from sqlalchemy import Table, Column, Integer, String, ForeignKey from sqlalchemy import Table, Column, Integer, String, ForeignKey
from bob.db.sqlalchemy_migration import Enum, relationship from bob.db.sqlalchemy_migration import Enum, relationship
import bob.db.utils
from sqlalchemy.orm import backref from sqlalchemy.orm import backref
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
import numpy
Base = declarative_base() Base = declarative_base()
class Client(Base): class Client(Base):
"""Database clients, marked by an integer identifier and the set they belong
to"""
__tablename__ = 'client' __tablename__ = 'client'
set_choices = ('train', 'devel', 'test') set_choices = ('train', 'devel', 'test')
"""Possible groups to which clients may belong to"""
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
"""Key identifier for clients"""
set = Column(Enum(*set_choices)) set = Column(Enum(*set_choices))
"""Set to which this client belongs to"""
def __init__(self, id, set): def __init__(self, id, set):
self.id = id self.id = id
self.set = set self.set = set
def __repr__(self): def __repr__(self):
return "<Client('%s', '%s')>" % (self.id, self.set) return "Client('%s', '%s')" % (self.id, self.set)
class File(Base): class File(Base):
"""Generic file container"""
__tablename__ = 'file' __tablename__ = 'file'
light_choices = ('controlled', 'adverse') light_choices = ('controlled', 'adverse')
"""List of illumination conditions for data taking"""
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
"""Key identifier for files"""
client_id = Column(Integer, ForeignKey('client.id')) # for SQL client_id = Column(Integer, ForeignKey('client.id')) # for SQL
"""The client identifier to which this file is bound to"""
path = Column(String(100), unique=True) path = Column(String(100), unique=True)
"""The (unique) path to this file inside the database"""
light = Column(Enum(*light_choices)) light = Column(Enum(*light_choices))
"""The illumination condition in which the data for this file was taken"""
# for Python # for Python
client = relationship(Client, backref=backref('files', order_by=id)) client = relationship(Client, backref=backref('files', order_by=id))
"""A direct link to the client object that this file belongs to"""
def __init__(self, client, path, light): def __init__(self, client, path, light):
self.client = client self.client = client
...@@ -47,7 +68,109 @@ class File(Base): ...@@ -47,7 +68,109 @@ class File(Base):
self.light = light self.light = light
def __repr__(self): def __repr__(self):
print "<File('%s')>" % self.path return "File('%s')" % self.path
def make_path(self, directory=None, extension=None):
"""Wraps the current path so that a complete path is formed
Keyword parameters:
directory
An optional directory name that will be prefixed to the returned result.
extension
An optional extension that will be suffixed to the returned filename. The
extension normally includes the leading ``.`` character as in ``.jpg`` or
``.hdf5``.
Returns a string containing the newly generated file path.
"""
if not directory: directory = ''
if not extension: extension = ''
return os.path.join(directory, self.path + extension)
def facefile(self, directory=None):
"""Returns the path to the companion face bounding-box file
Keyword parameters:
directory
An optional directory name that will be prefixed to the returned result.
Returns a string containing the face file path.
"""
if not directory: directory = ''
directory = os.path.join(directory, 'face-locations')
return self.make_path(directory, '.face')
def bbx(self, directory=None):
"""Reads the file containing the face locations for the frames in the
current video
Keyword parameters:
directory
A directory name that will be prepended to the final filepaths where the
face bounding boxes are located, if not on the current directory.
Returns:
A :py:class:`numpy.ndarray` containing information about the located
faces in the videos. Each row of the :py:class:`numpy.ndarray`
corresponds for one frame. The five columns of the
:py:class:`numpy.ndarray` are (all integers):
* Frame number (int)
* Bounding box top-left X coordinate (int)
* Bounding box top-left Y coordinate (int)
* Bounding box width (int)
* Bounding box height (int)
Note that **not** all the frames may contain detected faces.
"""
return numpy.loadtxt(self.facefile(directory), dtype=int)
def is_real(self):
"""Returns True if this file belongs to a real access, False otherwise"""
return bool(self.realaccess)
def get_realaccess(self):
"""Returns the real-access object equivalent to this file or raise"""
if len(self.realaccess) == 0:
raise RuntimeError, "%s is not a real-access" % self
return self.realaccess[0]
def get_attack(self):
"""Returns the attack object equivalent to this file or raise"""
if len(self.attack) == 0:
raise RuntimeError, "%s is not an attack" % self
return self.attack[0]
def save(self, data, directory=None, extension='.hdf5'):
"""Saves the input data at the specified location and using the given
extension.
Keyword parameters:
data
The data blob to be saved (normally a :py:class:`numpy.ndarray`).
directory
If not empty or None, this directory is prefixed to the final file
destination
extension
The extension of the filename - this will control the type of output and
the codec for saving the input blob.
"""
path = self.make_path(directory, extension)
bob.utils.makedirs_safe(os.path.dirname(path))
bob.io.save(data, path)
# Intermediate mapping from RealAccess's to Protocol's # Intermediate mapping from RealAccess's to Protocol's
realaccesses_protocols = Table('realaccesses_protocols', Base.metadata, realaccesses_protocols = Table('realaccesses_protocols', Base.metadata,
...@@ -62,31 +185,49 @@ attacks_protocols = Table('attacks_protocols', Base.metadata, ...@@ -62,31 +185,49 @@ attacks_protocols = Table('attacks_protocols', Base.metadata,
) )
class Protocol(Base): class Protocol(Base):
"""Replay attack protocol"""
__tablename__ = 'protocol' __tablename__ = 'protocol'
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
"""Unique identifier for the protocol (integer)"""
name = Column(String(20), unique=True) name = Column(String(20), unique=True)
"""Protocol name"""
def __init__(self, name): def __init__(self, name):
self.name = name self.name = name
def __repr__(self): def __repr__(self):
return "<Protocol('%s')>" % (self.name,) return "Protocol('%s')" % (self.name,)
class RealAccess(Base): class RealAccess(Base):
"""Defines Real-Accesses (licit attempts to authenticate)"""
__tablename__ = 'realaccess' __tablename__ = 'realaccess'
purpose_choices = ('authenticate', 'enroll') purpose_choices = ('authenticate', 'enroll')
"""Types of purpose for this video"""
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
"""Unique identifier for this real-access object"""
file_id = Column(Integer, ForeignKey('file.id')) # for SQL file_id = Column(Integer, ForeignKey('file.id')) # for SQL
"""The file identifier the current real-access is bound to"""
purpose = Column(Enum(*purpose_choices)) purpose = Column(Enum(*purpose_choices))
"""The purpose of this video"""
take = Column(Integer) take = Column(Integer)
"""Take number"""
# for Python # for Python
file = relationship(File, backref=backref('realaccess', order_by=id)) file = relationship(File, backref=backref('realaccess', order_by=id))
"""A direct link to the :py:class:`.File` object this real-access belongs to"""
protocols = relationship("Protocol", secondary=realaccesses_protocols, protocols = relationship("Protocol", secondary=realaccesses_protocols,
backref='realaccesses') backref='realaccesses')
"""A direct link to the protocols this file is linked to"""
def __init__(self, file, purpose, take): def __init__(self, file, purpose, take):
self.file = file self.file = file
...@@ -94,27 +235,50 @@ class RealAccess(Base): ...@@ -94,27 +235,50 @@ class RealAccess(Base):
self.take = take self.take = take
def __repr__(self): def __repr__(self):
return "<RealAccess('%s')>" % (self.file.path) return "RealAccess('%s')" % (self.file.path)
class Attack(Base): class Attack(Base):
"""Defines Spoofing Attacks (illicit attempts to authenticate)"""
__tablename__ = 'attack' __tablename__ = 'attack'
attack_support_choices = ('fixed', 'hand') attack_support_choices = ('fixed', 'hand')
"""Types of attack support"""
attack_device_choices = ('print', 'mobile', 'highdef', 'mask') attack_device_choices = ('print', 'mobile', 'highdef', 'mask')
"""Types of devices used for spoofing"""
sample_type_choices = ('video', 'photo') sample_type_choices = ('video', 'photo')
"""Original sample type"""
sample_device_choices = ('mobile', 'highdef') sample_device_choices = ('mobile', 'highdef')
"""Original sample device"""
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
"""Unique identifier for this attack"""
file_id = Column(Integer, ForeignKey('file.id')) # for SQL file_id = Column(Integer, ForeignKey('file.id')) # for SQL
"""The file identifier this attack is linked to"""
attack_support = Column(Enum(*attack_support_choices)) attack_support = Column(Enum(*attack_support_choices))
"""The attack support"""
attack_device = Column(Enum(*attack_device_choices)) attack_device = Column(Enum(*attack_device_choices))
"""The attack device"""
sample_type = Column(Enum(*sample_type_choices)) sample_type = Column(Enum(*sample_type_choices))
"""The attack sample type"""
sample_device = Column(Enum(*sample_device_choices)) sample_device = Column(Enum(*sample_device_choices))
"""The attack sample device"""
# for Python # for Python
file = relationship(File, backref=backref('attack', order_by=id)) file = relationship(File, backref=backref('attack', order_by=id))
"""A direct link to the :py:class:`.File` object bound to this attack"""
protocols = relationship("Protocol", secondary=attacks_protocols, protocols = relationship("Protocol", secondary=attacks_protocols,
backref='attacks') backref='attacks')
"""A direct link to the protocols this file is linked to"""
def __init__(self, file, attack_support, attack_device, sample_type, sample_device): def __init__(self, file, attack_support, attack_device, sample_type, sample_device):
self.file = file self.file = file
......
This diff is collapsed.
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
import os, sys import os, sys
import unittest import unittest
from .query import Database from .query import Database
from .models import *
class ReplayDatabaseTest(unittest.TestCase): class ReplayDatabaseTest(unittest.TestCase):
"""Performs various tests on the replay attack database.""" """Performs various tests on the replay attack database."""
...@@ -30,46 +31,47 @@ class ReplayDatabaseTest(unittest.TestCase): ...@@ -30,46 +31,47 @@ class ReplayDatabaseTest(unittest.TestCase):
def test01_queryRealAccesses(self): def test01_queryRealAccesses(self):
db = Database() db = Database()
f = db.files(cls='real') f = db.objects(cls='real')
self.assertEqual(len(set(f.values())), 200) #200 unique auth sessions self.assertEqual(len(f), 200) #200 unique auth sessions
for k,v in f.items(): for v in f[:10]: #only the 10 first...
self.assertTrue( (v.find('authenticate') != -1) ) self.assertTrue(isinstance(v.get_realaccess(), RealAccess))
self.assertTrue( (v.find('real') != -1) ) o = v.get_realaccess()
self.assertTrue( (v.find('webcam') != -1) ) self.assertEqual(o.purpose, u'authenticate')
train = db.files(cls='real', groups='train') train = db.objects(cls='real', groups='train')
self.assertEqual(len(set(train.values())), 60) self.assertEqual(len(train), 60)
dev = db.files(cls='real', groups='devel') dev = db.objects(cls='real', groups='devel')
self.assertEqual(len(set(dev.values())), 60) self.assertEqual(len(dev), 60)
test = db.files(cls='real', groups='test') test = db.objects(cls='real', groups='test')
self.assertEqual(len(set(test.values())), 80) self.assertEqual(len(test), 80)
#tests train, devel and test files are distinct #tests train, devel and test files are distinct
s = set(train.values() + dev.values() + test.values()) s = set(train + dev + test)
self.assertEqual(len(s), 200) self.assertEqual(len(s), 200)
def queryAttackType(self, protocol, N): def queryAttackType(self, protocol, N):
db = Database() db = Database()
f = db.files(cls='attack', protocol=protocol) f = db.objects(cls='attack', protocol=protocol)
self.assertEqual(len(set(f.values())), N) self.assertEqual(len(f), N)
for k,v in f.items(): for k in f[:10]: #only the 10 first...
self.assertTrue(v.find('attack') != -1) k.get_attack()
self.assertRaises(RuntimeError, k.get_realaccess)
train = db.files(cls='attack', groups='train', protocol=protocol) train = db.objects(cls='attack', groups='train', protocol=protocol)
self.assertEqual(len(set(train.values())), int(round(0.3*N))) self.assertEqual(len(train), int(round(0.3*N)))
dev = db.files(cls='attack', groups='devel', protocol=protocol) dev = db.objects(cls='attack', groups='devel', protocol=protocol)
self.assertEqual(len(set(dev.values())), int(round(0.3*N))) self.assertEqual(len(dev), int(round(0.3*N)))
test = db.files(cls='attack', groups='test', protocol=protocol) test = db.objects(cls='attack', groups='test', protocol=protocol)
self.assertEqual(len(set(test.values())), int(round(0.4*N))) self.assertEqual(len(test), int(round(0.4*N)))
#tests train, devel and test files are distinct #tests train, devel and test files are distinct
s = set(train.values() + dev.values() + test.values()) s = set(train + dev + test)
self.assertEqual(len(s), N) self.assertEqual(len(s), N)
def test02_queryAttacks(self): def test02_queryAttacks(self):
...@@ -99,91 +101,57 @@ class ReplayDatabaseTest(unittest.TestCase): ...@@ -99,91 +101,57 @@ class ReplayDatabaseTest(unittest.TestCase):
def test08_queryEnrollments(self): def test08_queryEnrollments(self):
db = Database() db = Database()
f = db.files(cls='enroll') f = db.objects(cls='enroll')
self.assertEqual(len(set(f.values())), 100) #50 clients, 2 conditions self.assertEqual(len(f), 100) #50 clients, 2 conditions
for k,v in f.items(): for v in f:
self.assertTrue(v.find('enroll') != -1) self.assertEqual(v.get_realaccess().purpose, u'enroll')
def test08a_queryClients(self): def test09_queryClients(self):
db = Database() db = Database()
f = db.clients() f = db.clients()
self.assertEqual(len(f), 50) #50 clients self.assertEqual(len(f), 50) #50 clients
self.assertTrue(db.has_client(3)) self.assertTrue(db.has_client_id(3))
self.assertFalse(db.has_client(0)) self.assertFalse(db.has_client_id(0))
self.assertTrue(db.has_client(21)) self.assertTrue(db.has_client_id(21))
self.assertFalse(db.has_client(32)) self.assertFalse(db.has_client_id(32))
self.assertFalse(db.has_client(100)) self.assertFalse(db.has_client_id(100))
self.assertTrue(db.has_client(101)) self.assertTrue(db.has_client_id(101))
self.assertTrue(db.has_client(119)) self.assertTrue(db.has_client_id(119))
self.assertFalse(db.has_client(120)) self.assertFalse(db.has_client_id(120))
def test09_manage_files(self): def test10_queryfacefile(self):
db = Database()
o = db.objects(clients=(1,))[0]
o.facefile()
def test11_manage_files(self):
from bob.db.script.dbmanage import main from bob.db.script.dbmanage import main
self.assertEqual(main('replay files'.split()), 0) self.assertEqual(main('replay files'.split()), 0)
def test10_manage_dumplist_1(self): def test12_manage_dumplist_1(self):
from bob.db.script.dbmanage import main from bob.db.script.dbmanage import main
self.assertEqual(main('replay dumplist --self-test'.split()), 0) self.assertEqual(main('replay dumplist --self-test'.split()), 0)
def test11_manage_dumplist_2(self): def test13_manage_dumplist_2(self):
from bob.db.script.dbmanage import main from bob.db.script.dbmanage import main
self.assertEqual(main('replay dumplist --class=attack --group=devel --support=hand --protocol=highdef --self-test'.split()), 0) self.assertEqual(main('replay dumplist --class=attack --group=devel --support=hand --protocol=highdef --self-test'.split()), 0)
def test12_manage_dumplist_client(self): def test14_manage_dumplist_client(self):
from bob.db.script.dbmanage import main from bob.db.script.dbmanage import main
self.assertEqual(main('replay dumplist --client=117 --self-test'.split()), 0) self.assertEqual(main('replay dumplist --client=117 --self-test'.split()), 0)
def test13_manage_checkfiles(self): def test15_manage_checkfiles(self):
from bob.db.script.dbmanage import main from bob.db.script.dbmanage import main
self.assertEqual(main('replay checkfiles --self-test'.split()), 0) self.assertEqual(main('replay checkfiles --self-test'.split()), 0)
def test14_queryfacefile(self):
db = Database()
self.assertEqual(db.facefiles(('foo',), directory = 'dir')[0], 'dir/foo.face',)
def test15_queryfacefile_key(self):
db = Database()
self.assertEqual(db.facefiles_ids(ids=(1,), directory='dir'), db.paths(ids=(1,), prefix='dir', suffix='.face'))
def test16_queryInfo(self):
db = Database()
res = db.info((1,))
self.assertEqual(len(res), 1)
res = db.info((1,2))
self.assertEqual(len(res), 2)
res = db.info(db.reverse(('devel/attack/fixed/attack_highdef_client030_session01_highdef_photo_adverse',)))
self.assertEqual(len(res), 1)
res = res[0]
self.assertFalse(res['real'])
self.assertEqual(res['sample_device'], u'highdef')
self.assertEqual(res['group'], u'devel')
self.assertEqual(res['light'], u'adverse')
self.assertEqual(res['client'], 30)
self.assertEqual(res['attack_support'], u'fixed')
self.assertEqual(res['sample_type'], u'photo')
self.assertEqual(res['attack_device'], u'highdef')
res = db.info(db.reverse(('train/real/client001_session01_webcam_authenticate_adverse_1',)))
self.assertEqual(len(res), 1)
res = res[0]
self.assertTrue(res['real'])
self.assertEqual(res['group'], u'train')
self.assertEqual(res['light'], u'adverse')
self.assertEqual(res['client'], 1)
self.assertEqual(res['take'], 1)
self.assertEqual(res['purpose'], u'authenticate')
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment