From ad1160e114a8de0d96428ebdf0a775e3e7d48be7 Mon Sep 17 00:00:00 2001 From: Andre Anjos <andre.anjos@idiap.ch> Date: Tue, 18 Sep 2012 15:28:53 +0200 Subject: [PATCH] Simplified API to the Replay Attack Database --- xbob/db/replay/__init__.py | 3 +- xbob/db/replay/checkfiles.py | 24 +-- xbob/db/replay/dumplist.py | 12 +- xbob/db/replay/models.py | 174 ++++++++++++++++- xbob/db/replay/query.py | 359 +++++++++++------------------------ xbob/db/replay/test.py | 132 +++++-------- 6 files changed, 354 insertions(+), 350 deletions(-) diff --git a/xbob/db/replay/__init__.py b/xbob/db/replay/__init__.py index 13cb82f..eeb7d01 100644 --- a/xbob/db/replay/__init__.py +++ b/xbob/db/replay/__init__.py @@ -25,5 +25,6 @@ on your references: """ from .query import Database +from .models import Client, File, Protocol, RealAccess, Attack -__all__ = ['Database'] +__all__ = dir() diff --git a/xbob/db/replay/checkfiles.py b/xbob/db/replay/checkfiles.py index ba07274..9ab3027 100644 --- a/xbob/db/replay/checkfiles.py +++ b/xbob/db/replay/checkfiles.py @@ -18,9 +18,7 @@ def checkfiles(args): from .query import Database db = Database() - r = db.files( - directory=args.directory, - extension=args.extension, + r = db.objects( protocol=args.protocol, support=args.support, groups=args.group, @@ -30,11 +28,13 @@ def checkfiles(args): ) # go through all files, check if they are available on the filesystem - good = {} - bad = {} - for id, f in r.items(): - if os.path.exists(f): good[id] = f - else: bad[id] = f + good = [] + bad = [] + for f in r: + if os.path.exists(f.make_path(args.directory, args.extension)): + good.append(f) + else: + bad.append(f) # report output = sys.stdout @@ -43,8 +43,8 @@ def checkfiles(args): output = null() if bad: - for id, f in bad.items(): - output.write('Cannot find file "%s"\n' % (f,)) + for f in bad: + output.write('Cannot find file "%s"\n' % (f.make_path(args.directory, args.extension),)) output.write('%d files (out of %d) were not found at "%s"\n' % \ (len(bad), len(r), args.directory)) @@ -65,8 +65,8 @@ def add_command(subparsers): protocols = ('waiting','for','database','creation') clients = tuple() else: - protocols = db.protocols() - clients = db.clients() + protocols = [k.name for k in db.protos()] + clients = [k.id for k in db.clients()] parser.add_argument('-d', '--directory', dest="directory", default='', help="if given, this path will be prepended to every entry checked (defaults to '%(default)s')") parser.add_argument('-e', '--extension', dest="extension", default='', help="if given, this extension will be appended to every entry checked (defaults to '%(default)s')") diff --git a/xbob/db/replay/dumplist.py b/xbob/db/replay/dumplist.py index 34f9d52..85d9351 100644 --- a/xbob/db/replay/dumplist.py +++ b/xbob/db/replay/dumplist.py @@ -18,9 +18,7 @@ def dumplist(args): from .query import Database db = Database() - r = db.files( - directory=args.directory, - extension=args.extension, + r = db.objects( protocol=args.protocol, support=args.support, groups=args.group, @@ -34,8 +32,8 @@ def dumplist(args): from bob.db.utils import null output = null() - for id, f in r.items(): - output.write('%s\n' % (f,)) + for f in r: + output.write('%s\n' % (f.make_path(args.directory, args.extension),)) return 0 @@ -54,8 +52,8 @@ def add_command(subparsers): protocols = ('waiting','for','database','creation') clients = tuple() else: - protocols = db.protocols() - clients = db.clients() + protocols = [k.name for k in db.protos()] + clients = [k.id for k in db.clients()] parser.add_argument('-d', '--directory', dest="directory", default='', help="if given, this path will be prepended to every entry returned (defaults to '%(default)s')") parser.add_argument('-e', '--extension', dest="extension", default='', help="if given, this extension will be appended to every entry returned (defaults to '%(default)s')") diff --git a/xbob/db/replay/models.py b/xbob/db/replay/models.py index 6428d72..6618e70 100644 --- a/xbob/db/replay/models.py +++ b/xbob/db/replay/models.py @@ -6,40 +6,61 @@ """Table models and functionality for the Replay Attack DB. """ +import os from sqlalchemy import Table, Column, Integer, String, ForeignKey from bob.db.sqlalchemy_migration import Enum, relationship +import bob.db.utils from sqlalchemy.orm import backref from sqlalchemy.ext.declarative import declarative_base +import numpy Base = declarative_base() class Client(Base): + """Database clients, marked by an integer identifier and the set they belong + to""" + __tablename__ = 'client' set_choices = ('train', 'devel', 'test') - + """Possible groups to which clients may belong to""" + id = Column(Integer, primary_key=True) + """Key identifier for clients""" + set = Column(Enum(*set_choices)) + """Set to which this client belongs to""" def __init__(self, id, set): self.id = id self.set = set def __repr__(self): - return "<Client('%s', '%s')>" % (self.id, self.set) + return "Client('%s', '%s')" % (self.id, self.set) class File(Base): + """Generic file container""" + __tablename__ = 'file' light_choices = ('controlled', 'adverse') + """List of illumination conditions for data taking""" id = Column(Integer, primary_key=True) + """Key identifier for files""" + client_id = Column(Integer, ForeignKey('client.id')) # for SQL + """The client identifier to which this file is bound to""" + path = Column(String(100), unique=True) + """The (unique) path to this file inside the database""" + light = Column(Enum(*light_choices)) + """The illumination condition in which the data for this file was taken""" # for Python client = relationship(Client, backref=backref('files', order_by=id)) + """A direct link to the client object that this file belongs to""" def __init__(self, client, path, light): self.client = client @@ -47,7 +68,109 @@ class File(Base): self.light = light def __repr__(self): - print "<File('%s')>" % self.path + return "File('%s')" % self.path + + def make_path(self, directory=None, extension=None): + """Wraps the current path so that a complete path is formed + + Keyword parameters: + + directory + An optional directory name that will be prefixed to the returned result. + + extension + An optional extension that will be suffixed to the returned filename. The + extension normally includes the leading ``.`` character as in ``.jpg`` or + ``.hdf5``. + + Returns a string containing the newly generated file path. + """ + + if not directory: directory = '' + if not extension: extension = '' + + return os.path.join(directory, self.path + extension) + + def facefile(self, directory=None): + """Returns the path to the companion face bounding-box file + + Keyword parameters: + + directory + An optional directory name that will be prefixed to the returned result. + + Returns a string containing the face file path. + """ + + if not directory: directory = '' + directory = os.path.join(directory, 'face-locations') + return self.make_path(directory, '.face') + + def bbx(self, directory=None): + """Reads the file containing the face locations for the frames in the + current video + + Keyword parameters: + + directory + A directory name that will be prepended to the final filepaths where the + face bounding boxes are located, if not on the current directory. + + Returns: + A :py:class:`numpy.ndarray` containing information about the located + faces in the videos. Each row of the :py:class:`numpy.ndarray` + corresponds for one frame. The five columns of the + :py:class:`numpy.ndarray` are (all integers): + + * Frame number (int) + * Bounding box top-left X coordinate (int) + * Bounding box top-left Y coordinate (int) + * Bounding box width (int) + * Bounding box height (int) + + Note that **not** all the frames may contain detected faces. + """ + + return numpy.loadtxt(self.facefile(directory), dtype=int) + + def is_real(self): + """Returns True if this file belongs to a real access, False otherwise""" + + return bool(self.realaccess) + + def get_realaccess(self): + """Returns the real-access object equivalent to this file or raise""" + if len(self.realaccess) == 0: + raise RuntimeError, "%s is not a real-access" % self + return self.realaccess[0] + + def get_attack(self): + """Returns the attack object equivalent to this file or raise""" + if len(self.attack) == 0: + raise RuntimeError, "%s is not an attack" % self + return self.attack[0] + + def save(self, data, directory=None, extension='.hdf5'): + """Saves the input data at the specified location and using the given + extension. + + Keyword parameters: + + data + The data blob to be saved (normally a :py:class:`numpy.ndarray`). + + directory + If not empty or None, this directory is prefixed to the final file + destination + + extension + The extension of the filename - this will control the type of output and + the codec for saving the input blob. + """ + + path = self.make_path(directory, extension) + bob.utils.makedirs_safe(os.path.dirname(path)) + bob.io.save(data, path) # Intermediate mapping from RealAccess's to Protocol's realaccesses_protocols = Table('realaccesses_protocols', Base.metadata, @@ -62,31 +185,49 @@ attacks_protocols = Table('attacks_protocols', Base.metadata, ) class Protocol(Base): + """Replay attack protocol""" + __tablename__ = 'protocol' id = Column(Integer, primary_key=True) + """Unique identifier for the protocol (integer)""" + name = Column(String(20), unique=True) + """Protocol name""" def __init__(self, name): self.name = name def __repr__(self): - return "<Protocol('%s')>" % (self.name,) + return "Protocol('%s')" % (self.name,) class RealAccess(Base): + """Defines Real-Accesses (licit attempts to authenticate)""" + __tablename__ = 'realaccess' purpose_choices = ('authenticate', 'enroll') + """Types of purpose for this video""" id = Column(Integer, primary_key=True) + """Unique identifier for this real-access object""" + file_id = Column(Integer, ForeignKey('file.id')) # for SQL + """The file identifier the current real-access is bound to""" + purpose = Column(Enum(*purpose_choices)) + """The purpose of this video""" + take = Column(Integer) + """Take number""" # for Python file = relationship(File, backref=backref('realaccess', order_by=id)) + """A direct link to the :py:class:`.File` object this real-access belongs to""" + protocols = relationship("Protocol", secondary=realaccesses_protocols, backref='realaccesses') + """A direct link to the protocols this file is linked to""" def __init__(self, file, purpose, take): self.file = file @@ -94,27 +235,50 @@ class RealAccess(Base): self.take = take def __repr__(self): - return "<RealAccess('%s')>" % (self.file.path) + return "RealAccess('%s')" % (self.file.path) class Attack(Base): + """Defines Spoofing Attacks (illicit attempts to authenticate)""" + __tablename__ = 'attack' attack_support_choices = ('fixed', 'hand') + """Types of attack support""" + attack_device_choices = ('print', 'mobile', 'highdef', 'mask') + """Types of devices used for spoofing""" + sample_type_choices = ('video', 'photo') + """Original sample type""" + sample_device_choices = ('mobile', 'highdef') + """Original sample device""" id = Column(Integer, primary_key=True) + """Unique identifier for this attack""" + file_id = Column(Integer, ForeignKey('file.id')) # for SQL + """The file identifier this attack is linked to""" + attack_support = Column(Enum(*attack_support_choices)) + """The attack support""" + attack_device = Column(Enum(*attack_device_choices)) + """The attack device""" + sample_type = Column(Enum(*sample_type_choices)) + """The attack sample type""" + sample_device = Column(Enum(*sample_device_choices)) + """The attack sample device""" # for Python file = relationship(File, backref=backref('attack', order_by=id)) + """A direct link to the :py:class:`.File` object bound to this attack""" + protocols = relationship("Protocol", secondary=attacks_protocols, backref='attacks') + """A direct link to the protocols this file is linked to""" def __init__(self, file, attack_support, attack_device, sample_type, sample_device): self.file = file diff --git a/xbob/db/replay/query.py b/xbob/db/replay/query.py index 46b75b8..228fe5c 100644 --- a/xbob/db/replay/query.py +++ b/xbob/db/replay/query.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # vim: set fileencoding=utf-8 : # Andre Anjos <andre.dos.anjos@gmail.com> -# Tue 17 May 13:58:09 2011 +# Tue 17 May 13:58:09 2011 """This module provides the Dataset interface allowing the user to query the replay attack database in the most obvious ways. @@ -27,7 +27,7 @@ class Database(object): def __init__(self): # opens a session to the database - keep it open until the end self.connect() - + def connect(self): """Tries connecting or re-connecting to the database""" if not os.path.exists(SQLITE_FILE): @@ -41,22 +41,19 @@ class Database(object): return self.session is not None - def files(self, directory=None, extension=None, - support=Attack.attack_support_choices, - protocol='grandtest', - groups=Client.set_choices, - cls=('attack', 'real'), - light=File.light_choices, - clients=None): - """Returns a set of filenames for the specific query by the user. + def assert_validity(self): + """Raise a RuntimeError if the database backend is not available""" - Keyword Parameters: + if not self.is_valid(): + raise RuntimeError, "Database '%s' cannot be found at expected location '%s'. Create it and then try re-connecting using Database.connect()" % (INFO.name(), SQLITE_FILE) - directory - A directory name that will be prepended to the final filepath returned + def objects(self, support=Attack.attack_support_choices, + protocol='grandtest', groups=Client.set_choices, cls=('attack', 'real'), + light=File.light_choices, clients=None): + """Returns a list of unique :py:class:`.File` objects for the specific + query by the user. - extension - A filename extension that will be appended to the final filepath returned + Keyword parameters: support One of the valid support types as returned by attack_supports() or all, @@ -88,37 +85,28 @@ class Database(object): client identifiers from which files should be retrieved. If ommited, set to None or an empty list, then data from all clients is retrieved. - Returns: A dictionary containing the resolved filenames considering all - the filtering criteria. The keys of the dictionary are unique identities - for each file in the replay attack database. Conserve these numbers if you - wish to save processing results later on. + Returns: A list of :py:class:`.File` objects. """ - if not self.is_valid(): - raise RuntimeError, "Database '%s' cannot be found at expected location '%s'. Create it and then try re-connecting using Database.connect()" % (INFO.name(), SQLITE_FILE) + self.assert_validity() def check_validity(l, obj, valid, default): """Checks validity of user input data against a set of valid values""" if not l: return default - elif not isinstance(l, (tuple, list)): + elif not isinstance(l, (tuple, list)): return check_validity((l,), obj, valid, default) for k in l: if k not in valid: raise RuntimeError, 'Invalid %s "%s". Valid values are %s, or lists/tuples of those' % (obj, k, valid) return l - def make_path(stem, directory, extension): - if not extension: extension = '' - if directory: return os.path.join(directory, stem + extension) - return stem + extension - # check if groups set are valid VALID_GROUPS = self.groups() - groups = check_validity(groups, "group", VALID_GROUPS, VALID_GROUPS) + groups = check_validity(groups, "group", VALID_GROUPS, None) # check if supports set are valid VALID_SUPPORTS = self.attack_supports() - support = check_validity(support, "support", VALID_SUPPORTS, VALID_SUPPORTS) + support = check_validity(support, "support", VALID_SUPPORTS, None) # by default, do NOT grab enrollment data from the database VALID_CLASSES = ('real', 'attack', 'enroll') @@ -126,292 +114,169 @@ class Database(object): # check protocol validity if not protocol: protocol = 'grandtest' #default - VALID_PROTOCOLS = self.protocols() + VALID_PROTOCOLS = [k.name for k in self.protos()] if protocol not in VALID_PROTOCOLS: raise RuntimeError, 'Invalid protocol "%s". Valid values are %s' % \ (protocol, VALID_PROTOCOLS) # checks client identity validity - VALID_CLIENTS = self.clients() + VALID_CLIENTS = [k.id for k in self.clients()] clients = check_validity(clients, "client", VALID_CLIENTS, None) - if clients is None: clients = VALID_CLIENTS - # resolve protocol object protocol = self.protocol(protocol) # checks if the light is valid VALID_LIGHTS = self.lights() - light = check_validity(light, "light", VALID_LIGHTS, VALID_LIGHTS) + light = check_validity(light, "light", VALID_LIGHTS, None) # now query the database - retval = {} + retval = [] # real-accesses are simpler to query if 'enroll' in cls: - q = self.session.query(RealAccess).join(File).join(Client).filter(Client.set.in_(groups)).filter(Client.id.in_(clients)).filter(RealAccess.purpose=='enroll').filter(File.light.in_(light)).order_by(Client.id) - for key, value in [(k.file.id, k.file.path) for k in q]: - retval[key] = make_path(str(value), directory, extension) - + q = self.session.query(File).join(RealAccess).join(Client) + if groups: q = q.filter(Client.set.in_(groups)) + if clients: q = q.filter(Client.id.in_(clients)) + if light: q = q.filter(File.light.in_(light)) + q = q.filter(RealAccess.purpose=='enroll') + q = q.order_by(Client.id) + retval += list(q) + # real-accesses are simpler to query if 'real' in cls: - q = self.session.query(RealAccess).join(File).join(Client).filter(RealAccess.protocols.contains(protocol)).filter(Client.id.in_(clients)).filter(Client.set.in_(groups)).filter(File.light.in_(light)).order_by(Client.id) - for key, value in [(k.file.id, k.file.path) for k in q]: - retval[key] = make_path(str(value), directory, extension) + q = self.session.query(File).join(RealAccess).join(Client) + if groups: q = q.filter(Client.set.in_(groups)) + if clients: q = q.filter(Client.id.in_(clients)) + if light: q = q.filter(File.light.in_(light)) + q = q.filter(RealAccess.protocols.contains(protocol)) + q = q.order_by(Client.id) + retval += list(q) # attacks will have to be filtered a little bit more if 'attack' in cls: - q = self.session.query(Attack).join(File).join(Client).filter(Attack.protocols.contains(protocol)).filter(Client.id.in_(clients)).filter(Client.set.in_(groups)).filter(Attack.attack_support.in_(support)).filter(File.light.in_(light)).order_by(Client.id) - - for key, value in [(k.file.id, k.file.path) for k in q]: - retval[key] = make_path(str(value), directory, extension) + q = self.session.query(File).join(Attack).join(Client) + if groups: q = q.filter(Client.set.in_(groups)) + if clients: q = q.filter(Client.id.in_(clients)) + if support: q = q.filter(Attack.attack_support.in_(support)) + if light: q = q.filter(File.light.in_(light)) + q = q.filter(Attack.protocols.contains(protocol)) + q = q.order_by(Client.id) + retval += list(q) return retval - def facefiles(self, filenames, directory=None): - """Queries the files containing the face locations for the frames in the - videos specified by the input parameter filenames - - Keyword parameters: - - filenames - The filenames of the videos. This object should be a python iterable - (such as a tuple or list). - - directory - A directory name that will be prepended to the final filepaths returned. - The face locations files should be located in this directory - - Returns: - A list of filenames with face locations. The face location files contain - the following information, space delimited: - - * Frame number - * Bounding box top-left X coordinate - * Bounding box top-left Y coordinate - * Bounding box width - * Bounding box height - - There is one row for each frame, and not all the frames contain detected - faces - """ + def files(self, directory=None, extension=None, **object_query): + """Returns a set of filenames for the specific query by the user. - if directory: - return [os.path.join(directory, stem + '.face') for stem in filenames] - return [stem + '.face' for stem in filenames] + .. deprecated:: 1.1.0 - def facebbx(self, filenames, directory=None): - """Reads the files containing the face locations for the frames in the - videos specified by the input parameter filenames + This function is *deprecated*, use :py:meth:`.Database.objects` instead. - Keyword parameters: - - filenames - The filenames of the videos. This object should be a python iterable - (such as a tuple or list). - - Returns: - A list of numpy.ndarrays containing information about the locatied faces - in the videos. Each element in the list corresponds to one input - filename. Each row of the numpy.ndarray corresponds for one frame. The - five columns of the numpy.ndarray denote: - - * Frame number - * Bounding box top-left X coordinate - * Bounding box top-left Y coordinate - * Bounding box width - * Bounding box height - - Note that not all the frames contain detected faces. - """ - - facefiles = self.facefiles(filenames, directory) - facesbbx = [] - for facef in facefiles: - lines = open(facef, "r").readlines() - bbx = numpy.ndarray((len(lines), 5), dtype='int') - lc = 0 - for l in lines: - words = l.split() - bbx[lc] = [int(w) for w in words] - lc+=1 - facesbbx.append(bbx) - return facesbbx - - def facefiles_ids(self, ids, directory=None): - """Queries the files containing the face locations for the frames in the - videos specified by the input parameter ids + Keyword Parameters: - Keyword parameters: - - ids - The ids of the objects in the database table "file". This object should - be a python iterable (such as a tuple or list). - - directory - A directory name that will be prepended to the final filepath returned. - The face locations files should be located in this directory - - Returns: - A list of filenames with face locations. For description on the face - locations file format, see the documentation for faces() - """ + directory + A directory name that will be prepended to the final filepath returned - if not directory: - directory = '' - facespaths = self.paths(ids, prefix=directory, suffix='.face') - return facespaths + extension + A filename extension that will be appended to the final filepath returned - def facebbx_ids(self, ids, directory=None): - """Reads the files containing the face locations for the frames in the - videos specified by the input parameter filenames + object_query + All remaining arguments are passed to :py:meth:`.Database.objects` + untouched. Please check the documentation for such method for more + details. - Keyword parameters: - - filenames - The filenames of the videos. This object should be a python iterable - (such as a tuple or list). - - Returns: - A list of numpy.ndarrays containing information about the locatied faces - in the videos. Each element in the list corresponds to one input - filename. Each row of the numpy.ndarray corresponds for one frame. The - five columns of the numpy.ndarray denote: - - * Frame number - * Bounding box top-left X coordinate - * Bounding box top-left Y coordinate - * Bounding box width - * Bounding box height - - Note that not all the frames contain detected faces. + Returns: A dictionary containing the resolved filenames considering all + the filtering criteria. The keys of the dictionary are unique identities + for each file in the replay attack database. Conserve these numbers if you + wish to save processing results later on. """ - facefiles = self.facefiles_ids(ids, directory) - facesbbx = [] - for facef in facefiles: - lines = open(facef, "r").readlines() - bbx = numpy.ndarray((len(lines), 5), dtype='int') - lc = 0 - for l in lines: - words = l.split() - bbx[lc] = [int(w) for w in words] - lc+=1 - facesbbx.append(bbx) - return facesbbx + import warnings + warnings.warn("The method Database.files() is deprecated, use Database.objects() for more powerful object retrieval", DeprecationWarning) - def clients(self): - """Returns the integer identifiers for all known clients""" + return dict([(k.id, k.make_path(directory, extension)) for k in self.objects(**object_query)]) - if not self.is_valid(): - raise RuntimeError, "Database '%s' cannot be found at expected location '%s'. Create it and then try re-connecting using Database.connect()" % (INFO.name(), SQLITE_FILE) + def clients(self): + """Returns an iterable with all known clients""" - return tuple([k.id for k in self.session.query(Client)]) + self.assert_validity() + return list(self.session.query(Client)) - def has_client(self, id): + def has_client_id(self, id): """Returns True if we have a client with a certain integer identifier""" - - if not self.is_valid(): - raise RuntimeError, "Database '%s' cannot be found at expected location '%s'. Create it and then try re-connecting using Database.connect()" % (INFO.name(), SQLITE_FILE) + self.assert_validity() return self.session.query(Client).filter(Client.id==id).count() != 0 def protocols(self): - """Returns the names of all registered protocols""" + """Returns the names of all registered protocols - if not self.is_valid(): - raise RuntimeError, "Database '%s' cannot be found at expected location '%s'. Create it and then try re-connecting using Database.connect()" % (INFO.name(), SQLITE_FILE) + .. deprecated:: 1.1.0 + + This function is *deprecated*, use :py:meth:`.Database.protos` instead. + + """ + import warnings + warnings.warn("The method Database.protocols() is deprecated, use Database.protos() for more powerful object retrieval", DeprecationWarning) + + self.assert_validity() return tuple([k.name for k in self.session.query(Protocol)]) + def protos(self): + """Returns all registered protocols""" + + self.assert_validity() + return list(self.session.query(Protocol)) + def has_protocol(self, name): """Tells if a certain protocol is available""" - if not self.is_valid(): - raise RuntimeError, "Database '%s' cannot be found at expected location '%s'. Create it and then try re-connecting using Database.connect()" % (INFO.name(), SQLITE_FILE) - + self.assert_validity() return self.session.query(Protocol).filter(Protocol.name==name).count() != 0 def protocol(self, name): """Returns the protocol object in the database given a certain name. Raises an error if that does not exist.""" - - if not self.is_valid(): - raise RuntimeError, "Database '%s' cannot be found at expected location '%s'. Create it and then try re-connecting using Database.connect()" % (INFO.name(), SQLITE_FILE) + self.assert_validity() return self.session.query(Protocol).filter(Protocol.name==name).one() def groups(self): """Returns the names of all registered groups""" + return Client.set_choices def lights(self): """Returns light variations available in the database""" + return File.light_choices def attack_supports(self): """Returns attack supports available in the database""" + return Attack.attack_support_choices def attack_devices(self): """Returns attack devices available in the database""" + return Attack.attack_device_choices def attack_sampling_devices(self): """Returns sampling devices available in the database""" + return Attack.sample_device_choices def attack_sample_types(self): """Returns attack sample types available in the database""" - return Attack.sample_type_choices - - def info(self, ids): - """Returns a dictionary of information for each input id - - Keyword Parameters: - - id - The ids of the object in the database table "file". This object should be - a python iterable (such as a tuple or list). - - Returns a list (that may be empty) of dictionaries containing each of the - identities properties. - """ - if not self.is_valid(): - raise RuntimeError, "Database '%s' cannot be found at expected location '%s'. Create it and then try re-connecting using Database.connect()" % (INFO.name(), SQLITE_FILE) - - fobj = self.session.query(File).filter(File.id.in_(ids)) - retval = [] - for f in fobj: - insert = {} - insert['client'] = f.client.id - insert['group'] = f.client.set - insert['path'] = f.path - insert['light'] = f.light - - if f.attack: - insert['real'] = False - o = f.attack[0] - insert['attack_support'] = o.attack_support - insert['attack_device'] = o.attack_device - insert['sample_type'] = o.sample_type - insert['sample_device'] = o.sample_device - - else: #it's a real access - insert['real'] = True - o = f.realaccess[0] - insert['purpose'] = o.purpose - insert['take'] = o.take - - retval.append(insert) - - return retval + return Attack.sample_type_choices def paths(self, ids, prefix='', suffix=''): """Returns a full file paths considering particular file ids, a given directory and an extension - + Keyword Parameters: id @@ -429,19 +294,17 @@ class Database(object): file ids. """ - if not self.is_valid(): - raise RuntimeError, "Database '%s' cannot be found at expected location '%s'. Create it and then try re-connecting using Database.connect()" % (INFO.name(), SQLITE_FILE) + self.assert_validity() fobj = self.session.query(File).filter(File.id.in_(ids)) retval = [] for p in ids: - retval.extend([os.path.join(prefix, str(k.path) + suffix) - for k in fobj if k.id == p]) + retval.extend([k.make_path(prefix, suffix) for k in fobj if k.id == p]) return retval def reverse(self, paths): """Reverses the lookup: from certain stems, returning file ids - + Keyword Parameters: paths @@ -451,11 +314,9 @@ class Database(object): Returns a list (that may be empty). """ - if not self.is_valid(): - raise RuntimeError, "Database '%s' cannot be found at expected location '%s'. Create it and then try re-connecting using Database.connect()" % (INFO.name(), SQLITE_FILE) + self.assert_validity() fobj = self.session.query(File).filter(File.path.in_(paths)) - retval = [] for p in paths: retval.extend([k.id for k in fobj if k.path == p]) return retval @@ -463,9 +324,13 @@ class Database(object): def save_one(self, id, obj, directory, extension): """Saves a single object supporting the bob save() protocol. + .. deprecated:: 1.1.0 + + This function is *deprecated*, use :py:meth:`.File.save()` instead. + This method will call save() on the the given object using the correct database filename stem for the given id. - + Keyword Parameters: id @@ -483,12 +348,13 @@ class Database(object): The extension determines the way each of the arrays will be saved. """ - if not self.is_valid(): - raise RuntimeError, "Database '%s' cannot be found at expected location '%s'. Create it and then try re-connecting using Database.connect()" % (INFO.name(), SQLITE_FILE) + import warnings + warnings.warn("The method Database.save_one() is deprecated, use the File object directly as returned by Database.objects() for more powerful object manipulation.", DeprecationWarning) - from bob.io import save + self.assert_validity() fobj = self.session.query(File).filter_by(id=id).one() + fullpath = os.path.join(directory, str(fobj.path) + extension) fulldir = os.path.dirname(fullpath) utils.makedirs_safe(fulldir) @@ -499,6 +365,10 @@ class Database(object): and saves the data respecting the original arrangement as returned by files(). + .. deprecated:: 1.1.0 + + This function is *deprecated*, use :py:meth:`.File.save()` instead. + Keyword Parameters: data @@ -514,6 +384,9 @@ class Database(object): extension The extension determines the way each of the arrays will be saved. """ - + + import warnings + warnings.warn("The method Database.save() is deprecated, use the File object directly as returned by Database.objects() for more powerful object manipulation.", DeprecationWarning) + for key, value in data: self.save_one(key, value, directory, extension) diff --git a/xbob/db/replay/test.py b/xbob/db/replay/test.py index 89c5a3e..5853f29 100644 --- a/xbob/db/replay/test.py +++ b/xbob/db/replay/test.py @@ -23,6 +23,7 @@ import os, sys import unittest from .query import Database +from .models import * class ReplayDatabaseTest(unittest.TestCase): """Performs various tests on the replay attack database.""" @@ -30,46 +31,47 @@ class ReplayDatabaseTest(unittest.TestCase): def test01_queryRealAccesses(self): db = Database() - f = db.files(cls='real') - self.assertEqual(len(set(f.values())), 200) #200 unique auth sessions - for k,v in f.items(): - self.assertTrue( (v.find('authenticate') != -1) ) - self.assertTrue( (v.find('real') != -1) ) - self.assertTrue( (v.find('webcam') != -1) ) + f = db.objects(cls='real') + self.assertEqual(len(f), 200) #200 unique auth sessions + for v in f[:10]: #only the 10 first... + self.assertTrue(isinstance(v.get_realaccess(), RealAccess)) + o = v.get_realaccess() + self.assertEqual(o.purpose, u'authenticate') - train = db.files(cls='real', groups='train') - self.assertEqual(len(set(train.values())), 60) + train = db.objects(cls='real', groups='train') + self.assertEqual(len(train), 60) - dev = db.files(cls='real', groups='devel') - self.assertEqual(len(set(dev.values())), 60) + dev = db.objects(cls='real', groups='devel') + self.assertEqual(len(dev), 60) - test = db.files(cls='real', groups='test') - self.assertEqual(len(set(test.values())), 80) + test = db.objects(cls='real', groups='test') + self.assertEqual(len(test), 80) #tests train, devel and test files are distinct - s = set(train.values() + dev.values() + test.values()) + s = set(train + dev + test) self.assertEqual(len(s), 200) def queryAttackType(self, protocol, N): db = Database() - f = db.files(cls='attack', protocol=protocol) + f = db.objects(cls='attack', protocol=protocol) - self.assertEqual(len(set(f.values())), N) - for k,v in f.items(): - self.assertTrue(v.find('attack') != -1) + self.assertEqual(len(f), N) + for k in f[:10]: #only the 10 first... + k.get_attack() + self.assertRaises(RuntimeError, k.get_realaccess) - train = db.files(cls='attack', groups='train', protocol=protocol) - self.assertEqual(len(set(train.values())), int(round(0.3*N))) + train = db.objects(cls='attack', groups='train', protocol=protocol) + self.assertEqual(len(train), int(round(0.3*N))) - dev = db.files(cls='attack', groups='devel', protocol=protocol) - self.assertEqual(len(set(dev.values())), int(round(0.3*N))) + dev = db.objects(cls='attack', groups='devel', protocol=protocol) + self.assertEqual(len(dev), int(round(0.3*N))) - test = db.files(cls='attack', groups='test', protocol=protocol) - self.assertEqual(len(set(test.values())), int(round(0.4*N))) + test = db.objects(cls='attack', groups='test', protocol=protocol) + self.assertEqual(len(test), int(round(0.4*N))) #tests train, devel and test files are distinct - s = set(train.values() + dev.values() + test.values()) + s = set(train + dev + test) self.assertEqual(len(s), N) def test02_queryAttacks(self): @@ -99,91 +101,57 @@ class ReplayDatabaseTest(unittest.TestCase): def test08_queryEnrollments(self): db = Database() - f = db.files(cls='enroll') - self.assertEqual(len(set(f.values())), 100) #50 clients, 2 conditions - for k,v in f.items(): - self.assertTrue(v.find('enroll') != -1) + f = db.objects(cls='enroll') + self.assertEqual(len(f), 100) #50 clients, 2 conditions + for v in f: + self.assertEqual(v.get_realaccess().purpose, u'enroll') - def test08a_queryClients(self): + def test09_queryClients(self): db = Database() f = db.clients() self.assertEqual(len(f), 50) #50 clients - self.assertTrue(db.has_client(3)) - self.assertFalse(db.has_client(0)) - self.assertTrue(db.has_client(21)) - self.assertFalse(db.has_client(32)) - self.assertFalse(db.has_client(100)) - self.assertTrue(db.has_client(101)) - self.assertTrue(db.has_client(119)) - self.assertFalse(db.has_client(120)) + self.assertTrue(db.has_client_id(3)) + self.assertFalse(db.has_client_id(0)) + self.assertTrue(db.has_client_id(21)) + self.assertFalse(db.has_client_id(32)) + self.assertFalse(db.has_client_id(100)) + self.assertTrue(db.has_client_id(101)) + self.assertTrue(db.has_client_id(119)) + self.assertFalse(db.has_client_id(120)) - def test09_manage_files(self): + def test10_queryfacefile(self): + + db = Database() + o = db.objects(clients=(1,))[0] + o.facefile() + + def test11_manage_files(self): from bob.db.script.dbmanage import main self.assertEqual(main('replay files'.split()), 0) - def test10_manage_dumplist_1(self): + def test12_manage_dumplist_1(self): from bob.db.script.dbmanage import main self.assertEqual(main('replay dumplist --self-test'.split()), 0) - def test11_manage_dumplist_2(self): + def test13_manage_dumplist_2(self): from bob.db.script.dbmanage import main self.assertEqual(main('replay dumplist --class=attack --group=devel --support=hand --protocol=highdef --self-test'.split()), 0) - def test12_manage_dumplist_client(self): + def test14_manage_dumplist_client(self): from bob.db.script.dbmanage import main self.assertEqual(main('replay dumplist --client=117 --self-test'.split()), 0) - def test13_manage_checkfiles(self): + def test15_manage_checkfiles(self): from bob.db.script.dbmanage import main self.assertEqual(main('replay checkfiles --self-test'.split()), 0) - - def test14_queryfacefile(self): - - db = Database() - self.assertEqual(db.facefiles(('foo',), directory = 'dir')[0], 'dir/foo.face',) - - def test15_queryfacefile_key(self): - db = Database() - self.assertEqual(db.facefiles_ids(ids=(1,), directory='dir'), db.paths(ids=(1,), prefix='dir', suffix='.face')) - - def test16_queryInfo(self): - - db = Database() - res = db.info((1,)) - self.assertEqual(len(res), 1) - - res = db.info((1,2)) - self.assertEqual(len(res), 2) - - res = db.info(db.reverse(('devel/attack/fixed/attack_highdef_client030_session01_highdef_photo_adverse',))) - self.assertEqual(len(res), 1) - res = res[0] - self.assertFalse(res['real']) - self.assertEqual(res['sample_device'], u'highdef') - self.assertEqual(res['group'], u'devel') - self.assertEqual(res['light'], u'adverse') - self.assertEqual(res['client'], 30) - self.assertEqual(res['attack_support'], u'fixed') - self.assertEqual(res['sample_type'], u'photo') - self.assertEqual(res['attack_device'], u'highdef') - - res = db.info(db.reverse(('train/real/client001_session01_webcam_authenticate_adverse_1',))) - self.assertEqual(len(res), 1) - res = res[0] - self.assertTrue(res['real']) - self.assertEqual(res['group'], u'train') - self.assertEqual(res['light'], u'adverse') - self.assertEqual(res['client'], 1) - self.assertEqual(res['take'], 1) - self.assertEqual(res['purpose'], u'authenticate') -- GitLab