From 04d02d4fdd676a98345f20369a1c1fc7dc1c62f7 Mon Sep 17 00:00:00 2001 From: Andre Anjos <andre.anjos@idiap.ch> Date: Mon, 17 Sep 2012 19:50:10 +0200 Subject: [PATCH] Adds possibility to query for specific client identifiers --- xbob/db/replay/checkfiles.py | 4 ++ xbob/db/replay/dumplist.py | 4 ++ xbob/db/replay/query.py | 108 ++++++++++++++++++++++++++--------- xbob/db/replay/test.py | 12 +++- 4 files changed, 97 insertions(+), 31 deletions(-) diff --git a/xbob/db/replay/checkfiles.py b/xbob/db/replay/checkfiles.py index a48bfc7..ba07274 100644 --- a/xbob/db/replay/checkfiles.py +++ b/xbob/db/replay/checkfiles.py @@ -26,6 +26,7 @@ def checkfiles(args): groups=args.group, cls=args.cls, light=args.light, + clients=args.client, ) # go through all files, check if they are available on the filesystem @@ -62,8 +63,10 @@ def add_command(subparsers): if not db.is_valid(): protocols = ('waiting','for','database','creation') + clients = tuple() else: protocols = db.protocols() + clients = db.clients() parser.add_argument('-d', '--directory', dest="directory", default='', help="if given, this path will be prepended to every entry checked (defaults to '%(default)s')") parser.add_argument('-e', '--extension', dest="extension", default='', help="if given, this extension will be appended to every entry checked (defaults to '%(default)s')") @@ -72,6 +75,7 @@ def add_command(subparsers): parser.add_argument('-s', '--support', dest="support", default='', help="if given, this value will limit the check to those files using this type of attack support. (defaults to '%(default)s')", choices=db.attack_supports()) parser.add_argument('-x', '--protocol', dest="protocol", default='', help="if given, this value will limit the check to those files for a given protocol. (defaults to '%(default)s')", choices=protocols) parser.add_argument('-l', '--light', dest="light", default='', help="if given, this value will limit the check to those files shot under a given lighting. (defaults to '%(default)s')", choices=db.lights()) + parser.add_argument('-C', '--client', dest="client", default=None, type=int, help="if given, limits the dump to a particular client (defaults to '%(default)s')", choices=clients) parser.add_argument('--self-test', dest="selftest", default=False, action='store_true', help=SUPPRESS) diff --git a/xbob/db/replay/dumplist.py b/xbob/db/replay/dumplist.py index 3cbe51b..34f9d52 100644 --- a/xbob/db/replay/dumplist.py +++ b/xbob/db/replay/dumplist.py @@ -26,6 +26,7 @@ def dumplist(args): groups=args.group, cls=args.cls, light=args.light, + clients=args.client, ) output = sys.stdout @@ -51,8 +52,10 @@ def add_command(subparsers): if not db.is_valid(): protocols = ('waiting','for','database','creation') + clients = tuple() else: protocols = db.protocols() + clients = db.clients() parser.add_argument('-d', '--directory', dest="directory", default='', help="if given, this path will be prepended to every entry returned (defaults to '%(default)s')") parser.add_argument('-e', '--extension', dest="extension", default='', help="if given, this extension will be appended to every entry returned (defaults to '%(default)s')") @@ -61,6 +64,7 @@ def add_command(subparsers): parser.add_argument('-s', '--support', dest="support", default='', help="if given, this value will limit the output files to those using this type of attack support. (defaults to '%(default)s')", choices=db.attack_supports()) parser.add_argument('-x', '--protocol', dest="protocol", default='', help="if given, this value will limit the output files to those for a given protocol. (defaults to '%(default)s')", choices=protocols) parser.add_argument('-l', '--light', dest="light", default='', help="if given, this value will limit the output files to those shot under a given lighting. (defaults to '%(default)s')", choices=db.lights()) + parser.add_argument('-C', '--client', dest="client", default=None, type=int, help="if given, limits the dump to a particular client (defaults to '%(default)s')", choices=clients) parser.add_argument('--self-test', dest="selftest", default=False, action='store_true', help=SUPPRESS) diff --git a/xbob/db/replay/query.py b/xbob/db/replay/query.py index 20ac027..a690930 100644 --- a/xbob/db/replay/query.py +++ b/xbob/db/replay/query.py @@ -46,7 +46,8 @@ class Database(object): protocol='grandtest', groups=Client.set_choices, cls=('attack', 'real'), - light=File.light_choices): + light=File.light_choices, + clients=None): """Returns a set of filenames for the specific query by the user. Keyword Parameters: @@ -82,6 +83,11 @@ class Database(object): One of the lighting conditions as returned by lights() or a combination of the two (in a tuple), which is also the default. + clients + If set, should be a single integer or a list of integers that define the + client identifiers from which files should be retrieved. If ommited, set + to None or an empty list, then data from all clients is retrieved. + Returns: A dictionary containing the resolved filenames considering all the filtering criteria. The keys of the dictionary are unique identities for each file in the replay attack database. Conserve these numbers if you @@ -94,7 +100,8 @@ class Database(object): def check_validity(l, obj, valid, default): """Checks validity of user input data against a set of valid values""" if not l: return default - elif isinstance(l, str): return check_validity((l,), obj, valid, default) + elif not isinstance(l, (tuple, list)): + return check_validity((l,), obj, valid, default) for k in l: if k not in valid: raise RuntimeError, 'Invalid %s "%s". Valid values are %s, or lists/tuples of those' % (obj, k, valid) @@ -124,6 +131,12 @@ class Database(object): raise RuntimeError, 'Invalid protocol "%s". Valid values are %s' % \ (protocol, VALID_PROTOCOLS) + # checks client identity validity + VALID_CLIENTS = self.clients() + clients = check_validity(clients, "client", VALID_CLIENTS, None) + + if clients is None: clients = VALID_CLIENTS + # resolve protocol object protocol = self.protocol(protocol) @@ -136,39 +149,42 @@ class Database(object): # real-accesses are simpler to query if 'enroll' in cls: - q = self.session.query(RealAccess).with_lockmode('read').join(File).join(Client).filter(Client.set.in_(groups)).filter(RealAccess.purpose=='enroll').filter(File.light.in_(light)).order_by(Client.id) + q = self.session.query(RealAccess).join(File).join(Client).filter(Client.set.in_(groups)).filter(Client.id.in_(clients)).filter(RealAccess.purpose=='enroll').filter(File.light.in_(light)).order_by(Client.id) for key, value in [(k.file.id, k.file.path) for k in q]: retval[key] = make_path(str(value), directory, extension) # real-accesses are simpler to query if 'real' in cls: - q = self.session.query(RealAccess).with_lockmode('read').join(File).join(Client).filter(RealAccess.protocols.contains(protocol)).filter(Client.set.in_(groups)).filter(File.light.in_(light)).order_by(Client.id) + q = self.session.query(RealAccess).join(File).join(Client).filter(RealAccess.protocols.contains(protocol)).filter(Client.id.in_(clients)).filter(Client.set.in_(groups)).filter(File.light.in_(light)).order_by(Client.id) for key, value in [(k.file.id, k.file.path) for k in q]: retval[key] = make_path(str(value), directory, extension) # attacks will have to be filtered a little bit more if 'attack' in cls: - q = self.session.query(Attack).with_lockmode('read').join(File).join(Client).filter(Attack.protocols.contains(protocol)).filter(Client.set.in_(groups)).filter(Attack.attack_support.in_(support)).filter(File.light.in_(light)).order_by(Client.id) + q = self.session.query(Attack).join(File).join(Client).filter(Attack.protocols.contains(protocol)).filter(Client.id.in_(clients)).filter(Client.set.in_(groups)).filter(Attack.attack_support.in_(support)).filter(File.light.in_(light)).order_by(Client.id) for key, value in [(k.file.id, k.file.path) for k in q]: retval[key] = make_path(str(value), directory, extension) return retval - def facefiles(self, filenames, directory=None): - """Queries the files containing the face locations for the frames in the videos specified by the input parameter filenames + """Queries the files containing the face locations for the frames in the + videos specified by the input parameter filenames Keyword parameters: filenames - The filenames of the videos. This object should be a python iterable (such as a tuple or list). + The filenames of the videos. This object should be a python iterable + (such as a tuple or list). directory - A directory name that will be prepended to the final filepaths returned. The face locations files should be located in this directory + A directory name that will be prepended to the final filepaths returned. + The face locations files should be located in this directory Returns: - A list of filenames with face locations. The face location files contain the following information, tab delimited: + A list of filenames with face locations. The face location files contain + the following information, space delimited: * Frame number * Bounding box top-left X coordinate @@ -176,21 +192,29 @@ class Database(object): * Bounding box width * Bounding box height - There is one row for each frame, and not all the frames contain detected faces + There is one row for each frame, and not all the frames contain detected + faces """ - if directory: return [os.path.join(directory, stem + '.face') for stem in filenames] + + if directory: + return [os.path.join(directory, stem + '.face') for stem in filenames] return [stem + '.face' for stem in filenames] def facebbx(self, filenames, directory=None): - """Reads the files containing the face locations for the frames in the videos specified by the input parameter filenames + """Reads the files containing the face locations for the frames in the + videos specified by the input parameter filenames Keyword parameters: filenames - The filenames of the videos. This object should be a python iterable (such as a tuple or list). + The filenames of the videos. This object should be a python iterable + (such as a tuple or list). Returns: - A list of numpy.ndarrays containing information about the locatied faces in the videos. Each element in the list corresponds to one input filename. Each row of the numpy.ndarray corresponds for one frame. The five columns of the numpy.ndarray denote: + A list of numpy.ndarrays containing information about the locatied faces + in the videos. Each element in the list corresponds to one input + filename. Each row of the numpy.ndarray corresponds for one frame. The + five columns of the numpy.ndarray denote: * Frame number * Bounding box top-left X coordinate @@ -200,6 +224,7 @@ class Database(object): Note that not all the frames contain detected faces. """ + facefiles = self.facefiles(filenames, directory) facesbbx = [] for facef in facefiles: @@ -214,34 +239,44 @@ class Database(object): return facesbbx def facefiles_ids(self, ids, directory=None): - """Queries the files containing the face locations for the frames in the videos specified by the input parameter ids + """Queries the files containing the face locations for the frames in the + videos specified by the input parameter ids Keyword parameters: ids - The ids of the objects in the database table "file". This object should be a python iterable (such as a tuple or list). + The ids of the objects in the database table "file". This object should + be a python iterable (such as a tuple or list). directory - A directory name that will be prepended to the final filepath returned. The face locations files should be located in this directory + A directory name that will be prepended to the final filepath returned. + The face locations files should be located in this directory Returns: - A list of filenames with face locations. For description on the face locations file format, see the documentation for faces() + A list of filenames with face locations. For description on the face + locations file format, see the documentation for faces() """ + if not directory: directory = '' facespaths = self.paths(ids, prefix=directory, suffix='.face') return facespaths def facebbx_ids(self, ids, directory=None): - """Reads the files containing the face locations for the frames in the videos specified by the input parameter filenames + """Reads the files containing the face locations for the frames in the + videos specified by the input parameter filenames Keyword parameters: filenames - The filenames of the videos. This object should be a python iterable (such as a tuple or list). + The filenames of the videos. This object should be a python iterable + (such as a tuple or list). Returns: - A list of numpy.ndarrays containing information about the locatied faces in the videos. Each element in the list corresponds to one input filename. Each row of the numpy.ndarray corresponds for one frame. The five columns of the numpy.ndarray denote: + A list of numpy.ndarrays containing information about the locatied faces + in the videos. Each element in the list corresponds to one input + filename. Each row of the numpy.ndarray corresponds for one frame. The + five columns of the numpy.ndarray denote: * Frame number * Bounding box top-left X coordinate @@ -251,6 +286,7 @@ class Database(object): Note that not all the frames contain detected faces. """ + facefiles = self.facefiles_ids(ids, directory) facesbbx = [] for facef in facefiles: @@ -264,13 +300,29 @@ class Database(object): facesbbx.append(bbx) return facesbbx + def clients(self): + """Returns the integer identifiers for all known clients""" + + if not self.is_valid(): + raise RuntimeError, "Database '%s' cannot be found at expected location '%s'. Create it and then try re-connecting using Database.connect()" % (INFO.name(), SQLITE_FILE) + + return tuple([k.id for k in self.session.query(Client)]) + + def has_client(self, id): + """Returns True if we have a client with a certain integer identifier""" + + if not self.is_valid(): + raise RuntimeError, "Database '%s' cannot be found at expected location '%s'. Create it and then try re-connecting using Database.connect()" % (INFO.name(), SQLITE_FILE) + + return self.session.query(Client).filter(Client.id==id).count() != 0 + def protocols(self): """Returns the names of all registered protocols""" if not self.is_valid(): raise RuntimeError, "Database '%s' cannot be found at expected location '%s'. Create it and then try re-connecting using Database.connect()" % (INFO.name(), SQLITE_FILE) - return tuple([k.name for k in self.session.query(Protocol).with_lockmode('read')]) + return tuple([k.name for k in self.session.query(Protocol)]) def has_protocol(self, name): """Tells if a certain protocol is available""" @@ -278,7 +330,7 @@ class Database(object): if not self.is_valid(): raise RuntimeError, "Database '%s' cannot be found at expected location '%s'. Create it and then try re-connecting using Database.connect()" % (INFO.name(), SQLITE_FILE) - return self.session.query(Protocol).with_lockmode('read').filter(Protocol.name==name).count() != 0 + return self.session.query(Protocol).filter(Protocol.name==name).count() != 0 def protocol(self, name): """Returns the protocol object in the database given a certain name. Raises @@ -287,7 +339,7 @@ class Database(object): if not self.is_valid(): raise RuntimeError, "Database '%s' cannot be found at expected location '%s'. Create it and then try re-connecting using Database.connect()" % (INFO.name(), SQLITE_FILE) - return self.session.query(Protocol).with_lockmode('read').filter(Protocol.name==name).one() + return self.session.query(Protocol).filter(Protocol.name==name).one() def groups(self): """Returns the names of all registered groups""" @@ -337,7 +389,7 @@ class Database(object): if not self.is_valid(): raise RuntimeError, "Database '%s' cannot be found at expected location '%s'. Create it and then try re-connecting using Database.connect()" % (INFO.name(), SQLITE_FILE) - fobj = self.session.query(File).with_lockmode('read').filter(File.id.in_(ids)) + fobj = self.session.query(File).filter(File.id.in_(ids)) retval = [] for p in ids: retval.extend([os.path.join(prefix, str(k.path) + suffix) @@ -359,7 +411,7 @@ class Database(object): if not self.is_valid(): raise RuntimeError, "Database '%s' cannot be found at expected location '%s'. Create it and then try re-connecting using Database.connect()" % (INFO.name(), SQLITE_FILE) - fobj = self.session.query(File).with_lockmode('read').filter(File.path.in_(paths)) + fobj = self.session.query(File).filter(File.path.in_(paths)) retval = [] for p in paths: retval.extend([k.id for k in fobj if k.path == p]) @@ -393,7 +445,7 @@ class Database(object): from bob.io import save - fobj = self.session.query(File).with_lockmode('read').filter_by(id=id).one() + fobj = self.session.query(File).filter_by(id=id).one() fullpath = os.path.join(directory, str(fobj.path) + extension) fulldir = os.path.dirname(fullpath) utils.makedirs_safe(fulldir) diff --git a/xbob/db/replay/test.py b/xbob/db/replay/test.py index 8e0e8ee..5296a5a 100644 --- a/xbob/db/replay/test.py +++ b/xbob/db/replay/test.py @@ -122,17 +122,23 @@ class ReplayDatabaseTest(unittest.TestCase): self.assertEqual(main('replay dumplist --class=attack --group=devel --support=hand --protocol=highdef --self-test'.split()), 0) - def test12_manage_checkfiles(self): + def test12_manage_dumplist_client(self): + + from bob.db.script.dbmanage import main + + self.assertEqual(main('replay dumplist --client=117 --self-test'.split()), 0) + + def test13_manage_checkfiles(self): from bob.db.script.dbmanage import main self.assertEqual(main('replay checkfiles --self-test'.split()), 0) - def test13_queryfacefile(self): + def test14_queryfacefile(self): db = Database() self.assertEqual(db.facefiles(('foo',), directory = 'dir')[0], 'dir/foo.face',) - def test14_queryfacefile_key(self): + def test15_queryfacefile_key(self): db = Database() self.assertEqual(db.facefiles_ids(ids=(1,), directory='dir'), db.paths(ids=(1,), prefix='dir', suffix='.face')) -- GitLab