Commit 1c1b3b85 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

The SQLiteDatabase now accepts original_directory and original_extension

parent b5509085
Pipeline #9176 passed with stages
in 26 minutes and 28 seconds
......@@ -29,6 +29,7 @@ import bob.db.base
SQLITE_FILE = Interface().files()[0]
class Database(bob.db.base.SQLiteDatabase):
"""The dataset class opens and maintains a connection opened to the Database.
......@@ -36,25 +37,28 @@ class Database(bob.db.base.SQLiteDatabase):
and for the data itself inside the database.
"""
def __init__(self, original_directory = None, original_extension = None):
def __init__(self, original_directory=None, original_extension=None):
# call base class constructors
# copy original file name and extension
super(Database, self).__init__(SQLITE_FILE, File)
self.original_directory = original_directory
self.original_extension = original_extension
super(Database, self).__init__(SQLITE_FILE, File,
original_directory, original_extension)
def __group_replace_alias__(self, l):
"""Replace 'dev' by 'g1' and 'eval' by 'g2' in a list of groups, and
returns the new list"""
if not l: return l
elif isinstance(l, six.string_types): return self.__group_replace_alias__((l,))
if not l:
return l
elif isinstance(l, six.string_types):
return self.__group_replace_alias__((l,))
l2 = []
for val in l:
if(val == 'dev'): l2.append('g1')
elif(val == 'eval'): l2.append('g2')
else: l2.append(val)
if(val == 'dev'):
l2.append('g1')
elif(val == 'eval'):
l2.append('g2')
else:
l2.append(val)
return tuple(l2)
def groups(self, protocol=None):
......@@ -93,7 +97,7 @@ class Database(bob.db.base.SQLiteDatabase):
def has_subworld(self, name):
"""Tells if a certain subworld is available"""
return self.query(Subworld).filter(Subworld.name==name).count() != 0
return self.query(Subworld).filter(Subworld.name == name).count() != 0
def clients(self, protocol=None, groups=None, genders=None, languages=None, subworld=None):
"""Returns a set of clients for the specific query by the user.
......@@ -123,29 +127,34 @@ class Database(bob.db.base.SQLiteDatabase):
"""
groups = self.__group_replace_alias__(groups)
groups = self.check_parameters_for_validity(groups, "group", self.client_groups())
genders = self.check_parameters_for_validity(genders, "gender", self.genders())
languages = self.check_parameters_for_validity(languages, "language", self.languages())
subworld = self.check_parameters_for_validity(subworld, "subworld", self.subworld_names())
groups = self.check_parameters_for_validity(
groups, "group", self.client_groups())
genders = self.check_parameters_for_validity(
genders, "gender", self.genders())
languages = self.check_parameters_for_validity(
languages, "language", self.languages())
subworld = self.check_parameters_for_validity(
subworld, "subworld", self.subworld_names())
retval = []
# List of the clients
if "world" in groups:
if len(subworld)==1:
q = self.query(Client).join((Subworld,Client.subworld)).filter(Subworld.name.in_(subworld))
if len(subworld) == 1:
q = self.query(Client).join((Subworld, Client.subworld)
).filter(Subworld.name.in_(subworld))
else:
q = self.query(Client).filter(Client.sgroup == 'world')
q = q.filter(Client.gender.in_(genders)).\
filter(Client.language.in_(languages)).\
filter(Client.language.in_(languages)).\
order_by(Client.id)
retval += list(q)
if 'g1' in groups or 'g2' in groups:
q = self.query(Client).filter(Client.sgroup != 'world').\
filter(Client.sgroup.in_(groups)).\
filter(Client.gender.in_(genders)).\
filter(Client.language.in_(languages)).\
order_by(Client.id)
filter(Client.sgroup.in_(groups)).\
filter(Client.gender.in_(genders)).\
filter(Client.language.in_(languages)).\
order_by(Client.id)
retval += list(q)
return retval
......@@ -200,7 +209,6 @@ class Database(bob.db.base.SQLiteDatabase):
zgroups.append('g1')
return self.clients(protocol, zgroups)
def models(self, protocol=None, groups=None):
"""Returns a set of models for the specific query by the user.
......@@ -270,13 +278,13 @@ class Database(bob.db.base.SQLiteDatabase):
def has_client_id(self, id):
"""Returns True if we have a client with a certain integer identifier"""
return self.query(Client).filter(Client.id==id).count() != 0
return self.query(Client).filter(Client.id == id).count() != 0
def client(self, id):
"""Returns the client object in the database given a certain id. Raises
an error if that does not exist."""
return self.query(Client).filter(Client.id==id).one()
return self.query(Client).filter(Client.id == id).one()
def get_client_id_from_model_id(self, model_id, **kwargs):
"""Returns the client_id attached to the given model_id
......@@ -303,7 +311,7 @@ class Database(bob.db.base.SQLiteDatabase):
return tmodel_id
def objects(self, protocol=None, purposes=None, model_ids=None, groups=None,
classes=None, languages=None, subworld=None):
classes=None, languages=None, subworld=None):
"""Returns a set of Files for the specific query by the user.
Keyword Parameters:
......@@ -346,62 +354,76 @@ class Database(bob.db.base.SQLiteDatabase):
Returns: A list of files which have the given properties.
"""
protocol = self.check_parameters_for_validity(protocol, "protocol", self.protocol_names())
purposes = self.check_parameters_for_validity(purposes, "purpose", self.purposes())
protocol = self.check_parameters_for_validity(
protocol, "protocol", self.protocol_names())
purposes = self.check_parameters_for_validity(
purposes, "purpose", self.purposes())
groups = self.check_parameters_for_validity(groups, "group", self.groups())
languages = self.check_parameters_for_validity(languages, "language", self.languages())
classes = self.check_parameters_for_validity(classes, "class", ('client', 'impostor'))
subworld = self.check_parameters_for_validity(subworld, "subworld", self.subworld_names())
languages = self.check_parameters_for_validity(
languages, "language", self.languages())
classes = self.check_parameters_for_validity(
classes, "class", ('client', 'impostor'))
subworld = self.check_parameters_for_validity(
subworld, "subworld", self.subworld_names())
import collections
if(model_ids is None):
model_ids = ()
elif(not isinstance(model_ids,collections.Iterable)):
elif(not isinstance(model_ids, collections.Iterable)):
model_ids = (model_ids,)
# Now query the database
retval = []
if 'world' in groups:
q = self.query(File).join(Client).join((ProtocolPurpose, File.protocolPurposes)).join(Protocol)
q = self.query(File).join(Client).join(
(ProtocolPurpose, File.protocolPurposes)).join(Protocol)
if len(subworld) == 1:
q = q.join((Subworld,Client.subworld)).filter(Subworld.name.in_(subworld))
q = q.join((Subworld, Client.subworld)).filter(
Subworld.name.in_(subworld))
q = q.filter(Client.sgroup == 'world').\
filter(and_(Protocol.name.in_(protocol), ProtocolPurpose.sgroup == 'world')).\
filter(Client.language.in_(languages))
filter(and_(Protocol.name.in_(protocol), ProtocolPurpose.sgroup == 'world')).\
filter(Client.language.in_(languages))
if model_ids:
q = q.filter(Client.id.in_(model_ids))
q = q.order_by(File.client_id, File.session_id, File.claimed_id, File.shot_id)
q = q.order_by(File.client_id, File.session_id,
File.claimed_id, File.shot_id)
retval += list(q)
if ('dev' in groups or 'eval' in groups):
if('enroll' in purposes):
q = self.query(File).join(Client).join((ProtocolPurpose, File.protocolPurposes)).join(Protocol).\
filter(and_(Protocol.name.in_(protocol), ProtocolPurpose.sgroup.in_(groups), ProtocolPurpose.purpose == 'enroll'))
filter(and_(Protocol.name.in_(protocol), ProtocolPurpose.sgroup.in_(
groups), ProtocolPurpose.purpose == 'enroll'))
if model_ids:
q = q.filter(Client.id.in_(model_ids))
q = q.order_by(File.client_id, File.session_id, File.claimed_id, File.shot_id)
q = q.order_by(File.client_id, File.session_id,
File.claimed_id, File.shot_id)
retval += list(q)
if('probe' in purposes):
if('client' in classes):
q = self.query(File).join(Client).join((ProtocolPurpose, File.protocolPurposes)).join(Protocol).\
filter(File.client_id == File.claimed_id).\
filter(and_(Protocol.name.in_(protocol), ProtocolPurpose.sgroup.in_(groups), ProtocolPurpose.purpose == 'probe'))
filter(File.client_id == File.claimed_id).\
filter(and_(Protocol.name.in_(protocol), ProtocolPurpose.sgroup.in_(
groups), ProtocolPurpose.purpose == 'probe'))
if model_ids:
q = q.filter(Client.id.in_(model_ids))
q = q.order_by(File.client_id, File.session_id, File.claimed_id, File.shot_id)
q = q.order_by(File.client_id, File.session_id,
File.claimed_id, File.shot_id)
retval += list(q)
if('impostor' in classes):
q = self.query(File).join(Client).join((ProtocolPurpose, File.protocolPurposes)).join(Protocol).\
filter(File.client_id != File.claimed_id).\
filter(and_(Protocol.name.in_(protocol), ProtocolPurpose.sgroup.in_(groups), ProtocolPurpose.purpose == 'probe'))
filter(File.client_id != File.claimed_id).\
filter(and_(Protocol.name.in_(protocol), ProtocolPurpose.sgroup.in_(
groups), ProtocolPurpose.purpose == 'probe'))
if model_ids:
q = q.filter(File.claimed_id.in_(model_ids))
q = q.order_by(File.client_id, File.session_id, File.claimed_id, File.shot_id)
q = q.order_by(File.client_id, File.session_id,
File.claimed_id, File.shot_id)
retval += list(q)
return list(set(retval)) # To remove duplicates
return list(set(retval)) # To remove duplicates
def tobjects(self, protocol=None, model_ids=None, groups=None, languages=None):
"""Returns a set of Files for enrolling T-norm models for score
......@@ -429,7 +451,8 @@ class Database(bob.db.base.SQLiteDatabase):
Returns: A list of Files which have the given properties.
"""
groups = self.check_parameters_for_validity(groups, "group", ('dev', 'eval'))
groups = self.check_parameters_for_validity(
groups, "group", ('dev', 'eval'))
# g2 clients are used for normalizing g1 ones, etc.
tgroups = []
if 'dev' in groups:
......@@ -463,7 +486,8 @@ class Database(bob.db.base.SQLiteDatabase):
Returns: A list of Files which have the given properties.
"""
groups = self.check_parameters_for_validity(groups, "group", ('dev', 'eval'))
groups = self.check_parameters_for_validity(
groups, "group", ('dev', 'eval'))
# g2 clients are used for normalizing g1 ones, etc.
zgroups = []
if 'dev' in groups:
......@@ -484,10 +508,10 @@ class Database(bob.db.base.SQLiteDatabase):
"""
self.assert_validity()
# return the annotations as returned by the call function of the Annotation object
# return the annotations as returned by the call function of the
# Annotation object
return file.annotation()
def protocol_names(self):
"""Returns all registered protocol names"""
......@@ -503,13 +527,13 @@ class Database(bob.db.base.SQLiteDatabase):
def has_protocol(self, name):
"""Tells if a certain protocol is available"""
return self.query(Protocol).filter(Protocol.name==name).count() != 0
return self.query(Protocol).filter(Protocol.name == name).count() != 0
def protocol(self, name):
"""Returns the protocol object in the database given a certain name. Raises
an error if that does not exist."""
return self.query(Protocol).filter(Protocol.name==name).one()
return self.query(Protocol).filter(Protocol.name == name).one()
def protocol_purposes(self):
"""Returns all registered protocol purposes"""
......@@ -521,18 +545,17 @@ class Database(bob.db.base.SQLiteDatabase):
return ProtocolPurpose.purpose_choices
def t_model_ids(self, protocol, groups = 'dev', **kwargs):
def t_model_ids(self, protocol, groups='dev', **kwargs):
"""Returns the list of model ids used for T-Norm of the given protocol for the given group that satisfy your query.
For possible keyword arguments, please check the :py:meth:`tmodel_ids` function."""
return self.uniquify(self.tmodel_ids(protocol=protocol, groups=groups, **kwargs))
def t_enroll_files(self, protocol, model_id, groups = 'dev', **kwargs):
def t_enroll_files(self, protocol, model_id, groups='dev', **kwargs):
"""Returns the list of T-Norm model enrollment File objects from the given model id of the given protocol for the given group that satisfy your query.
For possible keyword arguments, please check the :py:meth:`tobjects` function."""
return self.uniquify(self.tobjects(protocol=protocol, groups=groups, model_ids=(model_id,), **kwargs))
def z_probe_files(self, protocol, groups = 'dev', **kwargs):
def z_probe_files(self, protocol, groups='dev', **kwargs):
"""Returns the list of Z-Norm probe File objects to probe the model with the given model id of the given protocol for the given group that satisfy your query.
For possible keyword arguments, please check the :py:meth:`zobjects` function."""
return self.uniquify(self.zobjects(protocol=protocol, groups=groups, **kwargs))
py:class bob.db.base.file.File
py:class bob.db.base.database.SQLiteDatabase
py:class sqlalchemy.ext.declarative.api.Base
py:exc ValueError
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment