Skip to content
Snippets Groups Projects
Commit 687b85af authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

The SQLiteDatabase now accepts original_directory and original_extension

parent 20d3c243
No related branches found
No related tags found
No related merge requests found
Pipeline #
......@@ -31,6 +31,7 @@ import bob.db.base
SQLITE_FILE = Interface().files()[0]
class Database(bob.db.base.SQLiteDatabase):
"""The dataset class opens and maintains a connection opened to the Database.
......@@ -38,36 +39,37 @@ class Database(bob.db.base.SQLiteDatabase):
and for the data itself inside the database.
"""
def __init__(self, original_directory = None, original_extension = '.jpg', annotation_type = None):
def __init__(self, original_directory=None, original_extension='.jpg', annotation_type=None):
# call base class constructor
super(Database, self).__init__(SQLITE_FILE, File)
self.original_directory = original_directory
self.original_extension = original_extension
super(Database, self).__init__(SQLITE_FILE, File,
original_directory, original_extension)
self.m_valid_protocols = ('view1', 'fold1', 'fold2', 'fold3', 'fold4', 'fold5', 'fold6', 'fold7', 'fold8', 'fold9', 'fold10')
self.m_valid_protocols = ('view1', 'fold1', 'fold2', 'fold3',
'fold4', 'fold5', 'fold6', 'fold7', 'fold8', 'fold9', 'fold10')
self.m_valid_groups = ('world', 'dev', 'eval')
self.m_valid_purposes = ('enroll', 'probe')
self.m_valid_classes = ('matched', 'client', 'unmatched', 'impostor')
self.m_subworld_counts = {'onefolds':1, 'twofolds':2, 'threefolds':3, 'fourfolds':4, 'fivefolds':5, 'sixfolds':6, 'sevenfolds':7}
self.m_subworld_counts = {'onefolds': 1, 'twofolds': 2, 'threefolds': 3,
'fourfolds': 4, 'fivefolds': 5, 'sixfolds': 6, 'sevenfolds': 7}
self.m_valid_types = ('restricted', 'unrestricted')
self.m_valid_annotation_types = ('idiap', 'funneled')
if annotation_type is not None:
self.m_annotation_type = self.check_parameter_for_validity(annotation_type, "annotation type", self.m_valid_annotation_types)
self.m_annotation_type = self.check_parameter_for_validity(
annotation_type, "annotation type", self.m_valid_annotation_types)
else:
self.m_annotation_type = None
def __eval__(self, fold):
return int(fold[4:])
def __dev__(self, eval):
# take the two parts of the training set (the ones before the eval set) for dev
# take the two parts of the training set (the ones before the eval set)
# for dev
return ((eval + 7) % 10 + 1, (eval + 8) % 10 + 1)
def __dev_for__(self, fold):
return ["fold%d"%f for f in self.__dev__(self.__eval__(fold))]
return ["fold%d" % f for f in self.__dev__(self.__eval__(fold))]
def __world_for__(self, fold, subworld):
# the training sets for each fold are composed of all folds
......@@ -78,8 +80,7 @@ class Database(bob.db.base.SQLiteDatabase):
world = []
for i in range(world_count):
world.append((eval + i) % 10 + 1)
return ["fold%d"%f for f in world]
return ["fold%d" % f for f in world]
def protocol_names(self):
"""Returns the names of the valid protocols."""
......@@ -108,7 +109,6 @@ class Database(bob.db.base.SQLiteDatabase):
s = set([a.annotation_type for a in self.query(Annotation)])
return [str(t) for t in s]
def clients(self, protocol=None, groups=None, subworld='sevenfolds', world_type='unrestricted'):
"""Returns a list of Client objects for the specific query by the user.
......@@ -135,11 +135,15 @@ class Database(bob.db.base.SQLiteDatabase):
Returns: A list containing all Client objects which have the desired properties.
"""
protocols = self.check_parameters_for_validity(protocol, 'protocol', self.m_valid_protocols)
groups = self.check_parameters_for_validity(groups, 'group', self.m_valid_groups)
protocols = self.check_parameters_for_validity(
protocol, 'protocol', self.m_valid_protocols)
groups = self.check_parameters_for_validity(
groups, 'group', self.m_valid_groups)
if subworld != None:
subworld = self.check_parameter_for_validity(subworld, 'sub-world', list(self.m_subworld_counts.keys()))
world_type = self.check_parameter_for_validity(world_type, 'training type', self.m_valid_types)
subworld = self.check_parameter_for_validity(
subworld, 'sub-world', list(self.m_subworld_counts.keys()))
world_type = self.check_parameter_for_validity(
world_type, 'training type', self.m_valid_types)
queries = []
......@@ -148,46 +152,46 @@ class Database(bob.db.base.SQLiteDatabase):
if protocol == 'view1':
if 'world' in groups:
if world_type == 'restricted':
queries.append(\
self.query(Client).join(File).join((Pair, or_(File.id == Pair.enroll_file_id, File.id == Pair.probe_file_id))).\
filter(Pair.protocol == 'train').\
order_by(Client.id))
queries.append(
self.query(Client).join(File).join((Pair, or_(File.id == Pair.enroll_file_id, File.id == Pair.probe_file_id))).
filter(Pair.protocol == 'train').
order_by(Client.id))
else:
queries.append(\
self.query(Client).join(File).join(People).\
filter(People.protocol == 'train').\
order_by(Client.id))
queries.append(
self.query(Client).join(File).join(People).
filter(People.protocol == 'train').
order_by(Client.id))
if 'dev' in groups:
queries.append(\
self.query(Client).join(File).join(People).\
filter(People.protocol == 'test').\
order_by(Client.id))
queries.append(
self.query(Client).join(File).join(People).
filter(People.protocol == 'test').
order_by(Client.id))
else:
if 'world' in groups:
# select training set for the given fold
trainset = self.__world_for__(protocol, subworld)
if world_type == 'restricted':
queries.append(\
self.query(Client).join(File).join((Pair, or_(File.id == Pair.enroll_file_id, File.id == Pair.probe_file_id))).\
filter(Pair.protocol.in_(trainset)).\
order_by(Client.id))
queries.append(
self.query(Client).join(File).join((Pair, or_(File.id == Pair.enroll_file_id, File.id == Pair.probe_file_id))).
filter(Pair.protocol.in_(trainset)).
order_by(Client.id))
else:
queries.append(\
self.query(Client).join(File).join(People).\
filter(People.protocol.in_(trainset)).\
order_by(Client.id))
queries.append(
self.query(Client).join(File).join(People).
filter(People.protocol.in_(trainset)).
order_by(Client.id))
if 'dev' in groups:
# select development set for the given fold
devset = self.__dev_for__(protocol)
queries.append(\
self.query(Client).join(File).join(People).\
filter(People.protocol.in_(devset)).\
order_by(Client.id))
queries.append(
self.query(Client).join(File).join(People).
filter(People.protocol.in_(devset)).
order_by(Client.id))
if 'eval' in groups:
queries.append(\
self.query(Client).join(File).join(People).\
filter(People.protocol == protocol).\
order_by(Client.id))
queries.append(
self.query(Client).join(File).join(People).
filter(People.protocol == protocol).
order_by(Client.id))
# all queries are made; now collect the clients
retval = []
......@@ -197,7 +201,6 @@ class Database(bob.db.base.SQLiteDatabase):
return self.uniquify(retval)
def models(self, protocol=None, groups=None):
"""Returns a list of File objects (there are multiple models per client) for the specific query by the user.
For the 'dev' and 'eval' groups, the first element of each pair is extracted.
......@@ -214,8 +217,10 @@ class Database(bob.db.base.SQLiteDatabase):
Returns: A list containing all File objects which have the desired properties.
"""
protocols = self.check_parameters_for_validity(protocol, 'protocol', self.m_valid_protocols)
groups = self.check_parameters_for_validity(groups, 'group', ('dev', 'eval'))
protocols = self.check_parameters_for_validity(
protocol, 'protocol', self.m_valid_protocols)
groups = self.check_parameters_for_validity(
groups, 'group', ('dev', 'eval'))
# the restricted case...
queries = []
......@@ -227,18 +232,18 @@ class Database(bob.db.base.SQLiteDatabase):
queries.append(\
# enroll files
self.query(File).join((Pair, File.id == Pair.enroll_file_id)).\
filter(Pair.protocol == 'test'))
filter(Pair.protocol == 'test'))
else:
if 'dev' in groups:
# select development set for the given fold
devset = self.__dev_for__(protocol)
queries.append(\
self.query(File).join((Pair, File.id == Pair.enroll_file_id)).\
filter(Pair.protocol.in_(devset)))
queries.append(
self.query(File).join((Pair, File.id == Pair.enroll_file_id)).
filter(Pair.protocol.in_(devset)))
if 'eval' in groups:
queries.append(\
self.query(File).join((Pair, File.id == Pair.enroll_file_id)).\
filter(Pair.protocol == protocol))
queries.append(
self.query(File).join((Pair, File.id == Pair.enroll_file_id)).
filter(Pair.protocol == protocol))
# all queries are made; now collect the files
retval = []
......@@ -247,7 +252,6 @@ class Database(bob.db.base.SQLiteDatabase):
return self.uniquify(retval)
def model_ids(self, protocol=None, groups=None):
"""Returns a list of model ids for the specific query by the user.
For the 'dev' and 'eval' groups, the first element of each pair is extracted.
......@@ -263,8 +267,7 @@ class Database(bob.db.base.SQLiteDatabase):
Returns: A list containing all model ids which have the desired properties.
"""
return [file.id for file in self.models(protocol,groups)]
return [file.id for file in self.models(protocol, groups)]
def get_client_id_from_file_id(self, file_id, **kwargs):
"""Returns the client_id (real client id) attached to the given file_id
......@@ -279,12 +282,11 @@ class Database(bob.db.base.SQLiteDatabase):
self.assert_validity()
q = self.query(File).\
filter(File.id == file_id)
filter(File.id == file_id)
assert q.count() == 1
return q.first().client_id
def get_client_id_from_model_id(self, model_id, **kwargs):
"""Returns the client_id (real client id) attached to the given model id
......@@ -304,7 +306,6 @@ class Database(bob.db.base.SQLiteDatabase):
# since there is one model per file, we can re-use the function above.
return self.get_client_id_from_file_id(model_id)
def objects(self, protocol=None, model_ids=None, groups=None, purposes=None, subworld='sevenfolds', world_type='unrestricted'):
"""Returns a list of File objects for the specific query by the user.
......@@ -337,15 +338,20 @@ class Database(bob.db.base.SQLiteDatabase):
Returns: A list of File objects considering all the filtering criteria.
"""
protocols = self.check_parameters_for_validity(protocol, "protocol", self.m_valid_protocols)
groups = self.check_parameters_for_validity(groups, "group", self.m_valid_groups)
purposes = self.check_parameters_for_validity(purposes, "purpose", self.m_valid_purposes)
world_type = self.check_parameter_for_validity(world_type, 'training type', self.m_valid_types)
protocols = self.check_parameters_for_validity(
protocol, "protocol", self.m_valid_protocols)
groups = self.check_parameters_for_validity(
groups, "group", self.m_valid_groups)
purposes = self.check_parameters_for_validity(
purposes, "purpose", self.m_valid_purposes)
world_type = self.check_parameter_for_validity(
world_type, 'training type', self.m_valid_types)
if subworld != None:
subworld = self.check_parameter_for_validity(subworld, 'sub-world', list(self.m_subworld_counts.keys()))
subworld = self.check_parameter_for_validity(
subworld, 'sub-world', list(self.m_subworld_counts.keys()))
if(isinstance(model_ids,six.string_types)):
if(isinstance(model_ids, six.string_types)):
model_ids = (model_ids,)
queries = []
......@@ -357,25 +363,25 @@ class Database(bob.db.base.SQLiteDatabase):
if 'world' in groups:
# training files of view1
if world_type == 'restricted':
queries.append(\
self.query(File).join((Pair, or_(File.id == Pair.enroll_file_id, File.id == Pair.probe_file_id))).\
filter(Pair.protocol == 'train'))
queries.append(
self.query(File).join((Pair, or_(File.id == Pair.enroll_file_id, File.id == Pair.probe_file_id))).
filter(Pair.protocol == 'train'))
else:
queries.append(\
self.query(File).join(People).\
filter(People.protocol == 'train'))
queries.append(
self.query(File).join(People).
filter(People.protocol == 'train'))
if 'dev' in groups:
# test files of view1
if 'enroll' in purposes:
queries.append(\
self.query(File).join((Pair, File.id == Pair.enroll_file_id)).\
filter(Pair.protocol == 'test'))
queries.append(
self.query(File).join((Pair, File.id == Pair.enroll_file_id)).
filter(Pair.protocol == 'test'))
if 'probe' in purposes:
probe_queries.append(\
self.query(File).\
join((Pair, File.id == Pair.probe_file_id)).\
join((file_alias, Pair.enroll_file_id == file_alias.id)).\
filter(Pair.protocol == 'test'))
probe_queries.append(
self.query(File).
join((Pair, File.id == Pair.probe_file_id)).
join((file_alias, Pair.enroll_file_id == file_alias.id)).
filter(Pair.protocol == 'test'))
else:
# view 2
......@@ -383,40 +389,40 @@ class Database(bob.db.base.SQLiteDatabase):
# world set of current fold of view 2
trainset = self.__world_for__(protocol, subworld)
if world_type == 'restricted':
queries.append(\
self.query(File).join((Pair, or_(File.id == Pair.enroll_file_id, File.id == Pair.probe_file_id))).\
filter(Pair.protocol.in_(trainset)))
queries.append(
self.query(File).join((Pair, or_(File.id == Pair.enroll_file_id, File.id == Pair.probe_file_id))).
filter(Pair.protocol.in_(trainset)))
else:
queries.append(\
self.query(File).join(People).\
filter(People.protocol.in_(trainset)))
queries.append(
self.query(File).join(People).
filter(People.protocol.in_(trainset)))
if 'dev' in groups:
# development set of current fold of view 2
devset = self.__dev_for__(protocol)
if 'enroll' in purposes:
queries.append(\
self.query(File).join((Pair, File.id == Pair.enroll_file_id)).\
filter(Pair.protocol.in_(devset)))
queries.append(
self.query(File).join((Pair, File.id == Pair.enroll_file_id)).
filter(Pair.protocol.in_(devset)))
if 'probe' in purposes:
probe_queries.append(\
self.query(File).\
join((Pair, File.id == Pair.probe_file_id)).\
join((file_alias, file_alias.id == Pair.enroll_file_id)).\
filter(Pair.protocol.in_(devset)))
probe_queries.append(
self.query(File).
join((Pair, File.id == Pair.probe_file_id)).
join((file_alias, file_alias.id == Pair.enroll_file_id)).
filter(Pair.protocol.in_(devset)))
if 'eval' in groups:
# evaluation set of current fold of view 2; this is the REAL fold
if 'enroll' in purposes:
queries.append(\
self.query(File).join((Pair, File.id == Pair.enroll_file_id)).\
filter(Pair.protocol == protocol))
queries.append(
self.query(File).join((Pair, File.id == Pair.enroll_file_id)).
filter(Pair.protocol == protocol))
if 'probe' in purposes:
probe_queries.append(\
self.query(File).\
join((Pair, File.id == Pair.probe_file_id)).\
join((file_alias, file_alias.id == Pair.enroll_file_id)).\
filter(Pair.protocol == protocol))
probe_queries.append(
self.query(File).
join((Pair, File.id == Pair.probe_file_id)).
join((file_alias, file_alias.id == Pair.enroll_file_id)).
filter(Pair.protocol == protocol))
retval = []
for query in queries:
......@@ -434,7 +440,6 @@ class Database(bob.db.base.SQLiteDatabase):
return self.uniquify(retval)
def pairs(self, protocol=None, groups=None, classes=None, subworld='sevenfolds'):
"""Queries a list of Pair's of files.
......@@ -459,14 +464,18 @@ class Database(bob.db.base.SQLiteDatabase):
def default_query():
return self.query(Pair).\
join((File1, File1.id == Pair.enroll_file_id)).\
join((File2, File2.id == Pair.probe_file_id))
protocol = self.check_parameter_for_validity(protocol, "protocol", self.m_valid_protocols)
groups = self.check_parameters_for_validity(groups, "group", self.m_valid_groups)
classes = self.check_parameters_for_validity(classes, "class", self.m_valid_classes)
join((File1, File1.id == Pair.enroll_file_id)).\
join((File2, File2.id == Pair.probe_file_id))
protocol = self.check_parameter_for_validity(
protocol, "protocol", self.m_valid_protocols)
groups = self.check_parameters_for_validity(
groups, "group", self.m_valid_groups)
classes = self.check_parameters_for_validity(
classes, "class", self.m_valid_classes)
if subworld != None:
subworld = self.check_parameter_for_validity(subworld, 'sub-world', list(self.m_subworld_counts.keys()))
subworld = self.check_parameter_for_validity(
subworld, 'sub-world', list(self.m_subworld_counts.keys()))
queries = []
File1 = aliased(File)
......@@ -516,26 +525,26 @@ class Database(bob.db.base.SQLiteDatabase):
if annotation_type is None:
annotation_type = self.m_annotation_type
annotation_type = self.check_parameters_for_validity(annotation_type, "annotation type", self.m_valid_annotation_types)
annotation_type = self.check_parameters_for_validity(
annotation_type, "annotation type", self.m_valid_annotation_types)
query = self.query(Annotation).filter(Annotation.annotation_type.in_(annotation_type)).join(File).filter(File.id==file.id)
query = self.query(Annotation).filter(Annotation.annotation_type.in_(
annotation_type)).join(File).filter(File.id == file.id)
assert query.count() == 1
annotation = query.first()
# return the annotations as returned by the call function of the Annotation object
# return the annotations as returned by the call function of the
# Annotation object
return annotation()
def t_model_ids(self, protocol, groups = 'dev', **kwargs):
def t_model_ids(self, protocol, groups='dev', **kwargs):
"""Returns the list of model ids used for T-Norm of the given protocol for the given group that satisfy your query."""
return self.uniquify(self.tmodel_ids(protocol=protocol, groups=groups, **kwargs))
def t_enroll_files(self, protocol, model_id, groups = 'dev', **kwargs):
def t_enroll_files(self, protocol, model_id, groups='dev', **kwargs):
"""Returns the list of T-Norm model enrollment File objects from the given model id of the given protocol for the given group that satisfy your query."""
return self.uniquify(self.tobjects(protocol=protocol, groups=groups, model_ids=(model_id,), **kwargs))
def z_probe_files(self, protocol, groups = 'dev', **kwargs):
def z_probe_files(self, protocol, groups='dev', **kwargs):
"""Returns the list of Z-Norm probe File objects to probe the model with the given model id of the given protocol for the given group that satisfy your query."""
return self.uniquify(self.zobjects(protocol=protocol, groups=groups, **kwargs))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment