Newer
Older
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# Andre Anjos <andre.dos.anjos@gmail.com>
# Tue 17 May 13:58:09 2011
"""This module provides the Dataset interface allowing the user to query the
replay attack database in the most obvious ways.
"""
import os
import logging
from bob.db import utils
from .models import *
from .driver import Interface
# Driver interface instance for this database (Interface comes from .driver)
INFO = Interface()
# Location of the database file: the first entry returned by INFO.files().
# Database.connect() tests this path for existence before opening a session.
SQLITE_FILE = INFO.files()[0]
class Database(object):
"""The dataset class opens and maintains a connection opened to the Database.
It provides many different ways to probe for the characteristics of the data
and for the data itself inside the database.
"""
def __init__(self):
    """Opens a session to the database, which is kept open until the end.

    All the work is delegated to connect().
    """
    self.connect()
def connect(self):
    """Tries connecting or re-connecting to the database.

    When the file at SQLITE_FILE exists, self.session is set to the result of
    utils.session_try_readonly(); otherwise self.session is set to None.
    """
    if os.path.exists(SQLITE_FILE):
        self.session = utils.session_try_readonly(INFO.type(), SQLITE_FILE)
    else:
        self.session = None
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
def is_valid(self):
    """Tells whether a session is currently available for reading the database.

    Returns True when self.session is set to anything other than None.
    """
    no_session = self.session is None
    return not no_session
def files(self, directory=None, extension=None,
          support=Attack.attack_support_choices,
          protocol='grandtest',
          groups=Client.set_choices,
          cls=('attack', 'real'),
          light=File.light_choices):
    """Returns a dictionary of file paths for the specific query by the user.

    Keyword Parameters:

    directory
      A directory name that will be prepended to the final filepath returned

    extension
      A filename extension that will be appended to the final filepath returned

    support
      One of the valid support types as returned by attack_supports() or all,
      as a tuple. If you set this parameter to an empty string or the value
      None, we reset it to the default, which is to get all.

    protocol
      The protocol for the attack. One of the ones returned by protocols(). If
      you set this parameter to an empty string or the value None, we reset
      it to the default, "grandtest".

    groups
      One of the protocolar subgroups of data as returned by groups() or a
      tuple with several of them. If you set this parameter to an empty string
      or the value None, we reset it to the default which is to get all.

    cls
      Either "attack", "real", "enroll" or any combination of those (in a
      tuple). Defines the class of data to be retrieved. If you set this
      parameter to an empty string or the value None, we reset it to the
      default, ("real", "attack").

    light
      One of the lighting conditions as returned by lights() or a combination
      of the two (in a tuple), which is also the default.

    Returns: A dictionary containing the resolved filenames considering all
    the filtering criteria. The keys of the dictionary are unique identities
    for each file in the replay attack database. Conserve these numbers if you
    wish to save processing results later on.

    Raises: RuntimeError if the database is not available or if any of the
    selection values is not one of the valid choices.
    """
    if not self.is_valid():
        raise RuntimeError("Database '%s' cannot be found at expected location '%s'. Create it and then try re-connecting using Database.connect()" % (INFO.name(), SQLITE_FILE))

    def check_validity(l, obj, valid, default):
        """Checks validity of user input data against a set of valid values"""
        if not l:
            return default
        elif isinstance(l, str):
            # a single value is treated as a one-element tuple
            return check_validity((l,), obj, valid, default)
        for k in l:
            if k not in valid:
                raise RuntimeError('Invalid %s "%s". Valid values are %s, or lists/tuples of those' % (obj, k, valid))
        return l

    def make_path(stem, directory, extension):
        """Builds [directory/]stem[extension] from the optional parts"""
        if not extension:
            extension = ''
        if directory:
            return os.path.join(directory, stem + extension)
        return stem + extension

    # check if groups set are valid
    VALID_GROUPS = self.groups()
    groups = check_validity(groups, "group", VALID_GROUPS, VALID_GROUPS)

    # check if supports set are valid
    VALID_SUPPORTS = self.attack_supports()
    support = check_validity(support, "support", VALID_SUPPORTS, VALID_SUPPORTS)

    # by default, do NOT grab enrollment data from the database
    VALID_CLASSES = ('real', 'attack', 'enroll')
    cls = check_validity(cls, "class", VALID_CLASSES, ('real', 'attack'))

    # check protocol validity
    if not protocol:
        protocol = 'grandtest'  # default
    VALID_PROTOCOLS = self.protocols()
    if protocol not in VALID_PROTOCOLS:
        raise RuntimeError('Invalid protocol "%s". Valid values are %s' % (protocol, VALID_PROTOCOLS))

    # resolve protocol object
    protocol = self.protocol(protocol)

    # checks if the light is valid
    VALID_LIGHTS = self.lights()
    light = check_validity(light, "light", VALID_LIGHTS, VALID_LIGHTS)

    # now query the database
    retval = {}

    def collect(query):
        """Stores the resolved path of every object in the query, by file id"""
        for k in query:
            retval[k.file.id] = make_path(str(k.file.path), directory, extension)

    # enrollment data: selected by purpose, group and light only - this query
    # carries no protocol filter
    if 'enroll' in cls:
        collect(self.session.query(RealAccess).with_lockmode('read')
                .join(File).join(Client)
                .filter(Client.set.in_(groups))
                .filter(RealAccess.purpose == 'enroll')
                .filter(File.light.in_(light))
                .order_by(Client.id))

    # real-accesses are simpler to query
    if 'real' in cls:
        collect(self.session.query(RealAccess).with_lockmode('read')
                .join(File).join(Client)
                .filter(RealAccess.protocols.contains(protocol))
                .filter(Client.set.in_(groups))
                .filter(File.light.in_(light))
                .order_by(Client.id))

    # attacks will have to be filtered a little bit more
    if 'attack' in cls:
        collect(self.session.query(Attack).with_lockmode('read')
                .join(File).join(Client)
                .filter(Attack.protocols.contains(protocol))
                .filter(Client.set.in_(groups))
                .filter(Attack.attack_support.in_(support))
                .filter(File.light.in_(light))
                .order_by(Client.id))

    return retval
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
def faces(self, filenames, directory=None):
    """Builds the names of the face location files for a set of videos.

    Keyword parameters:

    filenames
      The filename stems of the videos. This object should be a python
      iterable (such as a tuple or list).

    directory
      A directory name that will be prepended to the final filepaths returned.
      The face locations files should be located in this directory

    Returns:
      A list with one entry per input stem, made of the stem followed by
      '.faces' (and preceded by the directory, if one is given). The face
      location files contain the following information, tab delimited:

      * Frame number
      * Bounding box top-left X coordinate
      * Bounding box top-left Y coordinate
      * Bounding box width
      * Bounding box height

      There is one row for each frame, and not all the frames contain detected
      faces
    """
    names = [stem + '.faces' for stem in filenames]
    if not directory:
        return names
    return [os.path.join(directory, name) for name in names]
def faces_ids(self, ids, directory=None):
    """Builds the names of the face location files for a set of file ids.

    Keyword parameters:

    ids
      The ids of the objects in the database table "file". This object should
      be a python iterable (such as a tuple or list).

    directory
      A directory name that will be prepended to the final filepath returned.
      The face locations files should be located in this directory

    Returns:
      Whatever paths() returns for these ids, using the directory (or an empty
      string when none is given) as prefix and '.faces' as suffix. For a
      description of the face locations file format, see the documentation
      for faces()
    """
    return self.paths(ids, prefix=directory or '', suffix='.faces')
def protocols(self):
    """Returns the names of all registered protocols, as a tuple.

    Raises: RuntimeError if no database session is available.
    """
    if not self.is_valid():
        # call form of raise: valid on both Python 2 and Python 3
        raise RuntimeError("Database '%s' cannot be found at expected location '%s'. Create it and then try re-connecting using Database.connect()" % (INFO.name(), SQLITE_FILE))
    return tuple(k.name for k in self.session.query(Protocol).with_lockmode('read'))
def has_protocol(self, name):
    """Tells if a certain protocol is available.

    Keyword Parameters:

    name
      The protocol name to look for

    Returns True if at least one Protocol entry carries that name.

    Raises: RuntimeError if no database session is available.
    """
    if not self.is_valid():
        # call form of raise: valid on both Python 2 and Python 3
        raise RuntimeError("Database '%s' cannot be found at expected location '%s'. Create it and then try re-connecting using Database.connect()" % (INFO.name(), SQLITE_FILE))
    matches = self.session.query(Protocol).with_lockmode('read').filter(Protocol.name == name)
    return matches.count() != 0
def protocol(self, name):
    """Returns the protocol object in the database given a certain name.

    Keyword Parameters:

    name
      The name of the protocol to retrieve

    Raises: RuntimeError if no database session is available. The lookup
    itself is done with the query's one(), which is what signals a name that
    does not resolve to exactly one protocol.
    """
    if not self.is_valid():
        # call form of raise: valid on both Python 2 and Python 3
        raise RuntimeError("Database '%s' cannot be found at expected location '%s'. Create it and then try re-connecting using Database.connect()" % (INFO.name(), SQLITE_FILE))
    matches = self.session.query(Protocol).with_lockmode('read').filter(Protocol.name == name)
    return matches.one()
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
def groups(self):
    """Returns the names of all registered groups.

    The value is Client.set_choices, handed back as is; the database session
    is not used.
    """
    return Client.set_choices
def lights(self):
    """Returns the light variations available in the database.

    The value is File.light_choices, handed back as is; the database session
    is not used.
    """
    return File.light_choices
def attack_supports(self):
    """Returns the attack supports available in the database.

    The value is Attack.attack_support_choices, handed back as is; the
    database session is not used.
    """
    return Attack.attack_support_choices
def attack_devices(self):
    """Returns the attack devices available in the database.

    The value is Attack.attack_device_choices, handed back as is; the
    database session is not used.
    """
    return Attack.attack_device_choices
def attack_sampling_devices(self):
    """Returns the sampling devices available in the database.

    The value is Attack.sample_device_choices, handed back as is; the
    database session is not used.
    """
    return Attack.sample_device_choices
def attack_sample_types(self):
    """Returns the attack sample types available in the database.

    The value is Attack.sample_type_choices, handed back as is; the database
    session is not used.
    """
    return Attack.sample_type_choices
def paths(self, ids, prefix='', suffix=''):
    """Returns the full file paths considering particular file ids, a given
    directory and an extension

    Keyword Parameters:

    ids
      The ids of the objects in the database table "file". This object should
      be a python iterable (such as a tuple or list).

    prefix
      The bit of path to be prepended to the filename stem

    suffix
      The extension determines the suffix that will be appended to the filename
      stem.

    Returns a list (that may be empty) of the fully constructed paths given the
    file ids. The list follows the order of the input ids; ids that are not
    found in the database are skipped.

    Raises: RuntimeError if no database session is available.
    """
    if not self.is_valid():
        raise RuntimeError("Database '%s' cannot be found at expected location '%s'. Create it and then try re-connecting using Database.connect()" % (INFO.name(), SQLITE_FILE))

    # ids is traversed twice (query filter + ordering), so pin it down first;
    # a one-shot iterator would otherwise be exhausted by the filter
    ids = list(ids)
    if not ids:
        return []

    fobj = self.session.query(File).with_lockmode('read').filter(File.id.in_(ids))

    # iterate the query a single time and index the results by id, instead of
    # iterating it again for every requested id
    found = {}
    for k in fobj:
        found.setdefault(k.id, []).append(os.path.join(prefix, str(k.path) + suffix))

    retval = []
    for p in ids:
        retval.extend(found.get(p, ()))
    return retval
def reverse(self, paths):
    """Reverses the lookup: from certain stems, returning file ids

    Keyword Parameters:

    paths
      The filename stems I'll query for. This object should be a python
      iterable (such as a tuple or list)

    Returns a list (that may be empty) of file ids, following the order of the
    input stems; stems that are not found in the database are skipped.

    Raises: RuntimeError if no database session is available.
    """
    if not self.is_valid():
        raise RuntimeError("Database '%s' cannot be found at expected location '%s'. Create it and then try re-connecting using Database.connect()" % (INFO.name(), SQLITE_FILE))

    # paths is traversed twice (query filter + ordering), so pin it down
    # first; a one-shot iterator would otherwise be exhausted by the filter
    paths = list(paths)
    if not paths:
        return []

    fobj = self.session.query(File).with_lockmode('read').filter(File.path.in_(paths))

    # iterate the query a single time and index the results by stem, instead
    # of iterating it again for every requested stem
    found = {}
    for k in fobj:
        found.setdefault(k.path, []).append(k.id)

    retval = []
    for p in paths:
        retval.extend(found.get(p, ()))
    return retval
def save_one(self, id, obj, directory, extension):
    """Saves a single object supporting the bob save() protocol.

    This method will call save() on the given object using the correct
    database filename stem for the given id.

    Keyword Parameters:

    id
      The id of the object in the database table "file".

    obj
      The object that needs to be saved, respecting the bob save() protocol.

    directory
      This is the base directory to which you want to save the data. The
      directory of the output file is created through utils.makedirs_safe()
      before saving.

    extension
      The extension determines the way each of the arrays will be saved.

    Raises: RuntimeError if no database session is available.
    """
    if not self.is_valid():
        raise RuntimeError("Database '%s' cannot be found at expected location '%s'. Create it and then try re-connecting using Database.connect()" % (INFO.name(), SQLITE_FILE))

    fobj = self.session.query(File).with_lockmode('read').filter_by(id=id).one()
    fullpath = os.path.join(directory, str(fobj.path) + extension)
    # the stem may carry sub-directories, so create the parent of the final
    # path rather than just `directory`
    utils.makedirs_safe(os.path.dirname(fullpath))
    # NOTE(review): `save` is resolved as a module-level name (not this
    # class's save() method) and is not defined in the visible imports;
    # presumably it arrives through `from .models import *` - confirm.
    save(obj, fullpath)
def save(self, data, directory, extension):
    """This method takes a dictionary of blitz arrays or bob.database.Array's
    and saves the data respecting the original arrangement as returned by
    files().

    Keyword Parameters:

    data
      A dictionary mapping file ids from the original database (such as the
      keys of the dictionary returned by files()) to objects that support the
      bob "save()" protocol. An iterable of (id, object) pairs is accepted as
      well. Every pair is handed to save_one() as is.

    directory
      This is the base directory to which you want to save the data. It is
      passed on to save_one(), which takes care of creating directories.

    extension
      The extension determines the way each of the arrays will be saved.
    """
    # iterating a dictionary directly yields its keys only, which cannot be
    # unpacked into (id, object); go through items() for mappings
    pairs = data.items() if hasattr(data, 'items') else data
    for key, value in pairs:
        self.save_one(key, value, directory, extension)