Commit 71e5bd4b authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

The SQLiteDatabase now accepts original_directory and original_extension

parent 8d0f994e
Pipeline #9192 passed with stages
in 35 minutes and 10 seconds
......@@ -13,6 +13,7 @@ import bob.db.base
SQLITE_FILE = Interface().files()[0]
class Database(bob.db.base.SQLiteDatabase):
"""The dataset class opens and maintains a connection opened to the Database.
......@@ -20,20 +21,23 @@ class Database(bob.db.base.SQLiteDatabase):
and for the data itself inside the database.
"""
def __init__(self, original_directory = None, original_extension = db_file_extension):
def __init__(self, original_directory=None, original_extension=db_file_extension):
# call base class constructor
super(Database, self).__init__(SQLITE_FILE, File)
self.original_directory = original_directory
self.original_extension = original_extension
super(Database, self).__init__(SQLITE_FILE, File,
original_directory, original_extension)
def __group_replace_eval_by_genuine__(self, l):
"""Replace 'eval' by 'Genuine' and returns the new list"""
if not l: return l
elif isinstance(l, six.string_types): return self.__group_replace_eval_by_genuine__((l,))
if not l:
return l
elif isinstance(l, six.string_types):
return self.__group_replace_eval_by_genuine__((l,))
l2 = []
for val in l:
if (val == 'eval'): l2.append('Genuine')
elif (val in Client.type_choices): l2.append(val)
if (val == 'eval'):
l2.append('Genuine')
elif (val in Client.type_choices):
l2.append(val)
return tuple(set(l2))
def groups(self, protocol=None):
......@@ -68,7 +72,8 @@ class Database(bob.db.base.SQLiteDatabase):
# List of the clients
#q = self.query(Client)
if (protocol):
q = self.query(Client).join(File).join((ProtocolPurpose, File.protocolPurposes)).join(Protocol).filter(and_(Protocol.name.in_((protocol,)), ProtocolPurpose.sgroup.in_((groups,))))
q = self.query(Client).join(File).join((ProtocolPurpose, File.protocolPurposes)).join(
Protocol).filter(and_(Protocol.name.in_((protocol,)), ProtocolPurpose.sgroup.in_((groups,))))
else:
q = self.query(Client)
"""if groups:
......@@ -99,7 +104,8 @@ class Database(bob.db.base.SQLiteDatabase):
# List of the clients
#q = self.query(Client)
if (protocol):
q = self.query(Client).join(File).join((ProtocolPurpose, File.protocolPurposes)).join(Protocol).filter(and_(Protocol.name.in_((protocol,)), ProtocolPurpose.sgroup.in_((groups,))))
q = self.query(Client).join(File).join((ProtocolPurpose, File.protocolPurposes)).join(
Protocol).filter(and_(Protocol.name.in_((protocol,)), ProtocolPurpose.sgroup.in_((groups,))))
else:
q = self.query(Client)
"""if groups:
......@@ -132,13 +138,13 @@ class Database(bob.db.base.SQLiteDatabase):
def has_client_id(self, id):
"""Returns True if we have a client with a certain integer identifier"""
return self.query(Client).filter(Client.id==id).count() != 0
return self.query(Client).filter(Client.id == id).count() != 0
def client(self, id):
"""Returns the client object in the database given a certain id. Raises
an error if that does not exist."""
return self.query(Client).filter(Client.id==id).one()
return self.query(Client).filter(Client.id == id).one()
def objects(self, protocol=None, purposes=None, model_ids=None, groups=None,
classes=None):
......@@ -174,15 +180,18 @@ class Database(bob.db.base.SQLiteDatabase):
"""
#groups = self.__group_replace_alias_clients__(groups)
protocol = self.check_parameters_for_validity(protocol, "protocol", self.protocol_names())
purposes = self.check_parameters_for_validity(purposes, "purpose", self.purposes())
protocol = self.check_parameters_for_validity(
protocol, "protocol", self.protocol_names())
purposes = self.check_parameters_for_validity(
purposes, "purpose", self.purposes())
groups = self.check_parameters_for_validity(groups, "group", self.groups())
classes = self.check_parameters_for_validity(classes, "class", ('client', 'impostor'))
classes = self.check_parameters_for_validity(
classes, "class", ('client', 'impostor'))
import collections
if(model_ids is None):
model_ids = ()
elif(not isinstance(model_ids,collections.Iterable)):
elif(not isinstance(model_ids, collections.Iterable)):
model_ids = (model_ids,)
# Now query the database
......@@ -191,7 +200,8 @@ class Database(bob.db.base.SQLiteDatabase):
if ('eval' in groups):
if('enrol' in purposes):
q = self.query(File).join(Client).join((ProtocolPurpose, File.protocolPurposes)).join(Protocol).\
filter(and_(Protocol.name.in_(protocol), ProtocolPurpose.sgroup.in_(groups), ProtocolPurpose.purpose == 'enrol'))
filter(and_(Protocol.name.in_(protocol), ProtocolPurpose.sgroup.in_(
groups), ProtocolPurpose.purpose == 'enrol'))
if model_ids:
q = q.filter(Client.id.in_(model_ids))
q = q.order_by(File.client_id, File.session_id, File.shot_id)
......@@ -200,7 +210,8 @@ class Database(bob.db.base.SQLiteDatabase):
if('probe' in purposes):
if('client' in classes):
q = self.query(File).join(Client).join((ProtocolPurpose, File.protocolPurposes)).join(Protocol).\
filter(and_(Protocol.name.in_(protocol), ProtocolPurpose.sgroup.in_(groups), ProtocolPurpose.purpose == 'probe'))
filter(and_(Protocol.name.in_(protocol), ProtocolPurpose.sgroup.in_(
groups), ProtocolPurpose.purpose == 'probe'))
if model_ids:
q = q.filter(Client.id.in_(model_ids))
q = q.order_by(File.client_id, File.session_id, File.shot_id)
......@@ -208,13 +219,14 @@ class Database(bob.db.base.SQLiteDatabase):
if('impostor' in classes):
q = self.query(File).join(Client).join((ProtocolPurpose, File.protocolPurposes)).join(Protocol).\
filter(and_(Protocol.name.in_(protocol), ProtocolPurpose.sgroup.in_(groups), ProtocolPurpose.purpose == 'probe'))
filter(and_(Protocol.name.in_(protocol), ProtocolPurpose.sgroup.in_(
groups), ProtocolPurpose.purpose == 'probe'))
if model_ids:
q = q.filter(Client.id.in_(model_ids))
q = q.order_by(File.client_id, File.session_id, File.shot_id)
retval += list(q)
return list(set(retval)) # To remove duplicates
return list(set(retval)) # To remove duplicates
def protocol_names(self):
"""Returns all registered protocol names"""
......@@ -231,13 +243,13 @@ class Database(bob.db.base.SQLiteDatabase):
def has_protocol(self, name):
"""Tells if a certain protocol is available"""
return self.query(Protocol).filter(Protocol.name==name).count() != 0
return self.query(Protocol).filter(Protocol.name == name).count() != 0
def protocol(self, name):
"""Returns the protocol object in the database given a certain name. Raises
an error if that does not exist."""
return self.query(Protocol).filter(Protocol.name==name).one()
return self.query(Protocol).filter(Protocol.name == name).one()
def protocol_purposes(self):
"""Returns all registered protocol purposes"""
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment