Commit 8a0f87b0 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

The SQLiteDatabase now accepts original_directory and original_extension

parent af3d5ecf
Pipeline #9195 passed with stages
in 34 minutes and 42 seconds
......@@ -31,6 +31,7 @@ logger = bob.core.log.setup("bob.db.nist_sre12")
SQLITE_FILE = Interface().files()[0]
class Database(bob.db.base.SQLiteDatabase):
"""The dataset class opens and maintains a connection opened to the Database.
......@@ -38,19 +39,21 @@ class Database(bob.db.base.SQLiteDatabase):
and for the data itself inside the database.
"""
def __init__(self, original_directory = None, original_extension = ".sph"):
def __init__(self, original_directory=None, original_extension=".sph"):
# call base class constructors
bob.db.base.SQLiteDatabase.__init__(self, SQLITE_FILE, File)
bob.db.base.SQLiteDatabase.__init__(
self, SQLITE_FILE, File, original_directory, original_extension)
def groups(self, protocol=None):
"""Returns the names of all registered groups"""
return ProtocolPurpose.group_choices # Same as Model.group_choices for this database
# Same as Model.group_choices for this database
return ProtocolPurpose.group_choices
def genders(self):
"""Returns the names of all registered groups"""
return ('male','female')
return ('male', 'female')
def clients(self, protocol=None, groups=None, filter_ids_unknown=True):
"""Returns a set of clients for the specific query by the user.
......@@ -71,7 +74,6 @@ class Database(bob.db.base.SQLiteDatabase):
return self.models(protocol, groups, filter_ids_unknown)
def models(self, protocol=None, groups=None, filter_ids_unknown=True):
"""Returns a set of models for the specific query by the user.
......@@ -88,20 +90,22 @@ class Database(bob.db.base.SQLiteDatabase):
Returns: A list containing all the models belonging to the given group.
"""
protocol = self.check_parameters_for_validity(protocol, "protocol", self.protocol_names())
groups = self.check_parameters_for_validity(groups, "groups", self.groups())
# List of the clients
retval = []
protocol = self.check_parameters_for_validity(
protocol, "protocol", self.protocol_names())
groups = self.check_parameters_for_validity(
groups, "groups", self.groups())
# List of the clients
retval = []
q = self.query(Model).join((ProtocolPurpose, Model.protocolPurposes)).join((Protocol, ProtocolPurpose.protocol)).\
filter(Protocol.name.in_(protocol)).filter(ProtocolPurpose.sgroup.in_(groups)).filter(ProtocolPurpose.purpose == 'enroll')
if filter_ids_unknown == True:
q = q.filter(not_(Model.id.in_(['F_ID_X', 'M_ID_X'])))
q = q.order_by(Model.id)
retval += list(q)
return list(set(retval))
filter(Protocol.name.in_(protocol)).filter(ProtocolPurpose.sgroup.in_(
groups)).filter(ProtocolPurpose.purpose == 'enroll')
if filter_ids_unknown == True:
q = q.filter(not_(Model.id.in_(['F_ID_X', 'M_ID_X'])))
q = q.order_by(Model.id)
retval += list(q)
return list(set(retval))
def model_ids(self, protocol=None, groups=None, filter_ids_unknown=True):
"""Returns a list of model ids for the specific query by the user.
......@@ -143,14 +147,13 @@ class Database(bob.db.base.SQLiteDatabase):
def has_client_id(self, id):
"""Returns True if we have a client with a certain integer identifier"""
return self.query(Model).filter(Model.id==id).count() != 0
return self.query(Model).filter(Model.id == id).count() != 0
def client(self, id):
"""Returns the client object in the database given a certain id. Raises
an error if that does not exist."""
return self.query(Model).filter(Model.client_id==id).one()
return self.query(Model).filter(Model.client_id == id).one()
def get_client_id_from_model_id(self, model_id, **kwargs):
"""Returns the client_id attached to the given model_id
......@@ -162,10 +165,9 @@ class Database(bob.db.base.SQLiteDatabase):
Returns: The client_id attached to the given model_id
"""
model = self.query(Model).filter(Model.model_id==model_id).one()
model = self.query(Model).filter(Model.model_id == model_id).one()
return model.client_id
def objects(self, protocol=None, purposes=None, model_ids=None, groups=None, gender=None):
"""Returns a set of filenames for the specific query by the user.
WARNING: Files used as impostor access for several different models are
......@@ -200,8 +202,10 @@ class Database(bob.db.base.SQLiteDatabase):
Returns: A list of files which have the given properties.
"""
protocol = self.check_parameters_for_validity(protocol, "protocol", self.protocol_names(), 'core-all')
purposes = self.check_parameters_for_validity(purposes, "purpose", self.purposes())
protocol = self.check_parameters_for_validity(
protocol, "protocol", self.protocol_names(), 'core-all')
purposes = self.check_parameters_for_validity(
purposes, "purpose", self.purposes())
groups = self.check_parameters_for_validity(groups, "group", self.groups())
import six
......@@ -217,113 +221,124 @@ class Database(bob.db.base.SQLiteDatabase):
if model_ids == ():
if gender == None:
q1l = self.query(ModelEnrollLink).join(Protocol).filter(Protocol.name.in_(protocol)).distinct().all()
q1l = self.query(ModelEnrollLink).join(Protocol).filter(
Protocol.name.in_(protocol)).distinct().all()
else:
q1l = self.query(ModelEnrollLink).join(Model).join(Protocol).filter(and_(Protocol.name.in_(protocol),Model.gender == gender )).distinct().all()
if len(q1l)>0:
file_ids_big = [ x.file_id for x in q1l]
q1l = self.query(ModelEnrollLink).join(Model).join(Protocol).filter(
and_(Protocol.name.in_(protocol), Model.gender == gender)).distinct().all()
if len(q1l) > 0:
file_ids_big = [x.file_id for x in q1l]
length = len(file_ids_big)
batches = int(length / 999) + 1 # 999 is the limit of sqlite in in_
batches = int(length / 999) + 1 # 999 is the limit of sqlite in in_
for i in range(batches):
logger.info('querying batch {} of {} batches'.format(i+1, batches))
file_ids = file_ids_big[i*999:(i+1)*999]
logger.info(
'querying batch {} of {} batches'.format(i + 1, batches))
file_ids = file_ids_big[i * 999:(i + 1) * 999]
if not file_ids:
continue
q = self.query(File).filter(File.id.in_(file_ids)).order_by(File.id)
if q.count()>0:
q = self.query(File).filter(
File.id.in_(file_ids)).order_by(File.id)
if q.count() > 0:
retval += list(q)
else:
if gender == None:
q1l = self.query(ModelEnrollLink).join(Protocol).filter(and_(ModelEnrollLink.model_id.in_(model_ids), Protocol.name.in_(protocol) )).all()
q1l = self.query(ModelEnrollLink).join(Protocol).filter(
and_(ModelEnrollLink.model_id.in_(model_ids), Protocol.name.in_(protocol))).all()
else:
q1l = self.query(ModelEnrollLink).join(Model).join(Protocol).filter(and_(ModelEnrollLink.model_id.in_(model_ids), Protocol.name.in_(protocol),Model.gender == gender )).distinct().all()
if len(q1l)>0:
file_ids_big = [ x.file_id for x in q1l]
q1l = self.query(ModelEnrollLink).join(Model).join(Protocol).filter(and_(ModelEnrollLink.model_id.in_(
model_ids), Protocol.name.in_(protocol), Model.gender == gender)).distinct().all()
if len(q1l) > 0:
file_ids_big = [x.file_id for x in q1l]
length = len(file_ids_big)
batches = int(length / 999) + 1 # 999 is the limit of sqlite in in_
batches = int(length / 999) + 1 # 999 is the limit of sqlite in in_
for i in range(batches):
logger.info('querying batch {} of {} batches'.format(i+1, batches))
file_ids = file_ids_big[i*999:(i+1)*999]
logger.info(
'querying batch {} of {} batches'.format(i + 1, batches))
file_ids = file_ids_big[i * 999:(i + 1) * 999]
if not file_ids:
continue
q = self.query(File).filter(File.id.in_(file_ids)).order_by(File.id)
if q.count()>0:
q = self.query(File).filter(
File.id.in_(file_ids)).order_by(File.id)
if q.count() > 0:
retval += list(q)
if('probe' in purposes):
if model_ids == ():
if gender == None:
q = self.query(File).join((ProtocolPurpose, File.protocolPurposes)).join(Protocol).filter(and_(Protocol.name.in_(protocol), ProtocolPurpose.purpose == 'probe'))
if q.count()>0:
q = self.query(File).join((ProtocolPurpose, File.protocolPurposes)).join(
Protocol).filter(and_(Protocol.name.in_(protocol), ProtocolPurpose.purpose == 'probe'))
if q.count() > 0:
retval += list(q)
else:
# import ipdb ; ipdb.set_trace()
q1l = self.query(ModelProbeLink).join(Model).join(Protocol).filter(and_(Protocol.name.in_(protocol), Model.gender == gender )).all()
if len(q1l)>0:
file_ids_big = list(set([ x.file_id for x in q1l]))
# import ipdb ; ipdb.set_trace()
q1l = self.query(ModelProbeLink).join(Model).join(Protocol).filter(
and_(Protocol.name.in_(protocol), Model.gender == gender)).all()
if len(q1l) > 0:
file_ids_big = list(set([x.file_id for x in q1l]))
length = len(file_ids_big)
batches = int(length / 999) + 1 # 999 is the limit of sqlite in in_
# 999 is the limit of sqlite in in_
batches = int(length / 999) + 1
for i in range(batches):
logger.info('querying batch {} of {} batches'.format(i+1, batches))
file_ids = file_ids_big[i*999:(i+1)*999]
logger.info(
'querying batch {} of {} batches'.format(i + 1, batches))
file_ids = file_ids_big[i * 999:(i + 1) * 999]
if not file_ids:
continue
q = self.query(File).filter(File.id.in_(file_ids)).order_by(File.id)
if q.count()>0:
q = self.query(File).filter(
File.id.in_(file_ids)).order_by(File.id)
if q.count() > 0:
retval += list(q)
else:
if gender == None:
q1l = self.query(ModelProbeLink).join(Protocol).filter(and_(ModelProbeLink.model_id.in_(model_ids), Protocol.name.in_(protocol) )).distinct().all()
q1l = self.query(ModelProbeLink).join(Protocol).filter(and_(
ModelProbeLink.model_id.in_(model_ids), Protocol.name.in_(protocol))).distinct().all()
else:
q1l = self.query(ModelProbeLink).join(Protocol).filter(and_(ModelProbeLink.model_id.in_(model_ids), Protocol.name.in_(protocol), Model.gender == gender )).distinct().all()
if len(q1l)>0:
file_ids_big = list(set( [x.file_id for x in q1l] ))
q1l = self.query(ModelProbeLink).join(Protocol).filter(and_(ModelProbeLink.model_id.in_(
model_ids), Protocol.name.in_(protocol), Model.gender == gender)).distinct().all()
if len(q1l) > 0:
file_ids_big = list(set([x.file_id for x in q1l]))
length = len(file_ids_big)
batches = int(length / 999) + 1 # 999 is the limit of sqlite in in_
batches = int(length / 999) + 1 # 999 is the limit of sqlite in in_
for i in range(batches):
logger.info('querying batch {} of {} batches'.format(i+1, batches))
file_ids = file_ids_big[i*999:(i+1)*999]
logger.info(
'querying batch {} of {} batches'.format(i + 1, batches))
file_ids = file_ids_big[i * 999:(i + 1) * 999]
if not file_ids:
continue
q = self.query(File).filter(File.id.in_(file_ids)).order_by(File.id)
if q.count()>0:
q = self.query(File).filter(
File.id.in_(file_ids)).order_by(File.id)
if q.count() > 0:
retval += list(q)
return list(set(retval)) # To remove duplicates
return list(set(retval)) # To remove duplicates
def protocol_names(self):
"""Returns all registered protocol names"""
return [str(p.name) for p in self.protocols()]
def protocols(self):
"""Returns all registered protocols"""
return list(self.query(Protocol))
def has_protocol(self, name):
"""Tells if a certain protocol is available"""
return self.query(Protocol).filter(Protocol.name==name).count() != 0
return self.query(Protocol).filter(Protocol.name == name).count() != 0
def protocol(self, name):
"""Returns the protocol object in the database given a certain name. Raises
an error if that does not exist."""
return self.query(Protocol).filter(Protocol.name==name).one()
return self.query(Protocol).filter(Protocol.name == name).one()
def protocol_purposes(self):
"""Returns all registered protocol purposes"""
return list(self.query(ProtocolPurpose))
def purposes(self):
"""Returns the list of allowed purposes"""
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment