Commit ef9b1c84 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

The SQLiteDatabase now accepts original_directory and original_extension

parent 16cb7f30
Pipeline #9191 passed with stages
in 40 minutes and 29 seconds
......@@ -15,6 +15,7 @@ import bob.db.base
SQLITE_FILE = Interface().files()[0]
class Database(bob.db.base.SQLiteDatabase):
"""The dataset class opens and maintains a connection opened to the Database.
......@@ -22,16 +23,17 @@ class Database(bob.db.base.SQLiteDatabase):
and for the data itself inside the database.
"""
def __init__(self):
bob.db.base.SQLiteDatabase.__init__(self, SQLITE_FILE, File)
def __init__(self, original_directory=None, original_extension=None):
bob.db.base.SQLiteDatabase.__init__(
self, SQLITE_FILE, File, original_directory, original_extension)
def groups(self, protocol=None):
"""Returns the names of all registered groups"""
if protocol == '1vsall': return ('world', 'dev')
else: return ('world', 'dev', 'eval')
if protocol == '1vsall':
return ('world', 'dev')
else:
return ('world', 'dev', 'eval')
def clients(self, protocol=None, groups=None):
"""Returns a set of clients for the specific query by the user.
......@@ -60,18 +62,21 @@ class Database(bob.db.base.SQLiteDatabase):
"""
protocols = self.check_parameters_for_validity(protocol, "protocol", self.protocol_names())
protocols = self.check_parameters_for_validity(
protocol, "protocol", self.protocol_names())
groups = self.check_parameters_for_validity(groups, "group", self.groups())
retval = []
# List of the clients
if 'world' in groups:
q = self.query(Client).join((File, Client.files)).join((Protocol, File.protocols_train)).filter(Protocol.name.in_(protocols))
q = self.query(Client).join((File, Client.files)).join(
(Protocol, File.protocols_train)).filter(Protocol.name.in_(protocols))
q = q.order_by(Client.id)
retval += list(q)
if 'dev' in groups or 'eval' in groups:
q = self.query(Client).join((Model, Client.models)).join((Protocol, Model.protocol)).filter(Protocol.name.in_(protocols))
q = self.query(Client).join((Model, Client.models)).join(
(Protocol, Model.protocol)).filter(Protocol.name.in_(protocols))
q = q.filter(Model.sgroup.in_(groups))
q = q.order_by(Client.id)
retval += list(q)
......@@ -81,7 +86,6 @@ class Database(bob.db.base.SQLiteDatabase):
return retval
def client_ids(self, protocol=None, groups=None):
"""Returns a set of client ids for the specific query by the user.
......@@ -114,7 +118,6 @@ class Database(bob.db.base.SQLiteDatabase):
return [client.id for client in self.clients(protocol, groups)]
def models(self, protocol=None, groups=None):
"""Returns a set of models for the specific query by the user.
......@@ -144,19 +147,20 @@ class Database(bob.db.base.SQLiteDatabase):
"""
protocols = self.check_parameters_for_validity(protocol, "protocol", self.protocol_names())
protocols = self.check_parameters_for_validity(
protocol, "protocol", self.protocol_names())
groups = self.check_parameters_for_validity(groups, "group", self.groups())
retval = []
if 'dev' in groups or 'eval' in groups:
# List of the clients
q = self.query(Model).join((Protocol, Model.protocol)).filter(Protocol.name.in_(protocols))
q = self.query(Model).join((Protocol, Model.protocol)
).filter(Protocol.name.in_(protocols))
q = q.filter(Model.sgroup.in_(groups)).order_by(Model.name)
retval += list(q)
return retval
def model_ids(self, protocol=None, groups=None):
"""Returns a set of models ids for the specific query by the user.
......@@ -189,19 +193,16 @@ class Database(bob.db.base.SQLiteDatabase):
return [model.name for model in self.models(protocol, groups)]
def has_client_id(self, id):
"""Returns True if we have a client with a certain integer identifier"""
return self.query(Client).filter(Client.id==id).count() != 0
return self.query(Client).filter(Client.id == id).count() != 0
def client(self, id):
"""Returns the client object in the database given a certain id. Raises
an error if that does not exist."""
return self.query(Client).filter(Client.id==id).one()
return self.query(Client).filter(Client.id == id).one()
def get_client_id_from_model_id(self, model_id):
"""Returns the client_id attached to the given model_id
......@@ -217,11 +218,10 @@ class Database(bob.db.base.SQLiteDatabase):
"""
return self.query(Model).filter(Model.name==model_id).first().client_id
return self.query(Model).filter(Model.name == model_id).first().client_id
def objects(self, protocol=None, purposes=None, model_ids=None, groups=None,
classes=None, finger_ids=None, session_ids=None):
classes=None, finger_ids=None, session_ids=None):
"""Returns a set of Files for the specific query by the user.
......@@ -277,10 +277,13 @@ class Database(bob.db.base.SQLiteDatabase):
"""
protocols = self.check_parameters_for_validity(protocol, "protocol", self.protocol_names())
purposes = self.check_parameters_for_validity(purposes, "purpose", self.purposes())
protocols = self.check_parameters_for_validity(
protocol, "protocol", self.protocol_names())
purposes = self.check_parameters_for_validity(
purposes, "purpose", self.purposes())
groups = self.check_parameters_for_validity(groups, "group", self.groups())
classes = self.check_parameters_for_validity(classes, "class", ('client', 'impostor'))
classes = self.check_parameters_for_validity(
classes, "class", ('client', 'impostor'))
from six import string_types
if model_ids is None:
......@@ -301,50 +304,60 @@ class Database(bob.db.base.SQLiteDatabase):
retval = []
if 'world' in groups:
q = self.query(File).join((Protocol, File.protocols_train)).\
filter(Protocol.name.in_(protocols))
if finger_ids: q = q.filter(File.finger_id.in_(finger_ids))
if session_ids: q = q.filter(File.session_id.in_(session_ids))
filter(Protocol.name.in_(protocols))
if finger_ids:
q = q.filter(File.finger_id.in_(finger_ids))
if session_ids:
q = q.filter(File.session_id.in_(session_ids))
q = q.order_by(File.client_id, File.finger_id, File.session_id)
retval += list(q)
if 'dev' in groups or 'eval' in groups:
sgroups = []
if 'dev' in groups: sgroups.append('dev')
if 'eval' in groups: sgroups.append('eval')
if 'dev' in groups:
sgroups.append('dev')
if 'eval' in groups:
sgroups.append('eval')
if 'enroll' in purposes:
q = self.query(File).join(Client).join((Model, File.models_enroll)).join((Protocol, Model.protocol)).\
filter(and_(Protocol.name.in_(protocols), Model.sgroup.in_(sgroups)))
filter(and_(Protocol.name.in_(protocols), Model.sgroup.in_(sgroups)))
if model_ids:
q = q.filter(Model.name.in_(model_ids))
if finger_ids: q = q.filter(File.finger_id.in_(finger_ids))
if session_ids: q = q.filter(File.session_id.in_(session_ids))
if finger_ids:
q = q.filter(File.finger_id.in_(finger_ids))
if session_ids:
q = q.filter(File.session_id.in_(session_ids))
q = q.order_by(File.client_id, File.finger_id, File.session_id)
retval += list(q)
if 'probe' in purposes:
if 'client' in classes:
q = self.query(File).join(Client).join((Model, File.models_probe)).join((Protocol, Model.protocol)).\
filter(and_(Protocol.name.in_(protocols), Model.sgroup.in_(sgroups), File.client_id == Model.client_id))
filter(and_(Protocol.name.in_(protocols), Model.sgroup.in_(
sgroups), File.client_id == Model.client_id))
if model_ids:
q = q.filter(Model.name.in_(model_ids))
if finger_ids: q = q.filter(File.finger_id.in_(finger_ids))
if session_ids: q = q.filter(File.session_id.in_(session_ids))
if finger_ids:
q = q.filter(File.finger_id.in_(finger_ids))
if session_ids:
q = q.filter(File.session_id.in_(session_ids))
q = q.order_by(File.client_id, File.finger_id, File.session_id)
retval += list(q)
if 'impostor' in classes:
q = self.query(File).join(Client).join((Model, File.models_probe)).join((Protocol, Model.protocol)).\
filter(and_(Protocol.name.in_(protocols), Model.sgroup.in_(sgroups), File.client_id != Model.client_id))
filter(and_(Protocol.name.in_(protocols), Model.sgroup.in_(
sgroups), File.client_id != Model.client_id))
if len(model_ids) != 0:
q = q.filter(Model.name.in_(model_ids))
if finger_ids: q = q.filter(File.finger_id.in_(finger_ids))
if session_ids: q = q.filter(File.session_id.in_(session_ids))
if finger_ids:
q = q.filter(File.finger_id.in_(finger_ids))
if session_ids:
q = q.filter(File.session_id.in_(session_ids))
q = q.order_by(File.client_id, File.finger_id, File.session_id)
retval += list(q)
return list(set(retval)) # To remove duplicates
return list(set(retval)) # To remove duplicates
def protocol_names(self):
"""Returns all registered protocol names"""
......@@ -353,25 +366,21 @@ class Database(bob.db.base.SQLiteDatabase):
retval = [str(k.name) for k in l]
return retval
def protocols(self):
"""Returns all registered protocols"""
return list(self.query(Protocol))
def has_protocol(self, name):
"""Tells if a certain protocol is available"""
return self.query(Protocol).filter(Protocol.name==name).count() != 0
return self.query(Protocol).filter(Protocol.name == name).count() != 0
def protocol(self, name):
"""Returns the protocol object in the database given a certain name. Raises
an error if that does not exist."""
return self.query(Protocol).filter(Protocol.name==name).one()
return self.query(Protocol).filter(Protocol.name == name).one()
def purposes(self):
return ('train', 'enroll', 'probe')
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment