Commit 68235666 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

The SQLiteDatabase now accepts original_directory and original_extension

parent 12b71979
Pipeline #9190 passed with stages
in 28 minutes and 8 seconds
......@@ -39,7 +39,8 @@ class Database(bob.db.base.SQLiteDatabase):
def __init__(self, original_directory=None, original_extension=None, annotation_directory=None, annotation_extension='.pos'):
# call base class constructors to open a session to the database
super(Database, self).__init__(SQLITE_FILE, File, original_directory, original_extension)
super(Database, self).__init__(SQLITE_FILE, File,
original_directory, original_extension)
self.annotation_directory = annotation_directory
self.annotation_extension = annotation_extension
......@@ -120,17 +121,22 @@ class Database(bob.db.base.SQLiteDatabase):
"""
protocol = self._replace_protocols_alias(protocol)
protocol = self.check_parameters_for_validity(protocol, "protocol", self.protocol_names(), [])
groups = self.check_parameters_for_validity(groups, "group", self.groups(), self.groups())
subworld = self.check_parameters_for_validity(subworld, "subworld", self.subworld_names(), [])
gender = self.check_parameters_for_validity(gender, "gender", self.genders(), [])
protocol = self.check_parameters_for_validity(
protocol, "protocol", self.protocol_names(), [])
groups = self.check_parameters_for_validity(
groups, "group", self.groups(), self.groups())
subworld = self.check_parameters_for_validity(
subworld, "subworld", self.subworld_names(), [])
gender = self.check_parameters_for_validity(
gender, "gender", self.genders(), [])
# List of the clients
retval = []
if 'world' in groups:
q = self.query(Client).filter(Client.sgroup == 'world')
if subworld:
q = q.join((Subworld, Client.subworld)).filter(Subworld.name.in_(subworld))
q = q.join((Subworld, Client.subworld)).filter(
Subworld.name.in_(subworld))
if gender:
q = q.filter(Client.gender.in_(gender))
q = q.order_by(Client.id)
......@@ -301,14 +307,19 @@ class Database(bob.db.base.SQLiteDatabase):
"""
protocol = self._replace_protocols_alias(protocol)
protocol = self.check_parameters_for_validity(protocol, "protocol", self.protocol_names())
subworld = self.check_parameters_for_validity(subworld, "subworld", self.subworld_names(), [])
gender = self.check_parameters_for_validity(gender, "gender", self.genders(), [])
protocol = self.check_parameters_for_validity(
protocol, "protocol", self.protocol_names())
subworld = self.check_parameters_for_validity(
subworld, "subworld", self.subworld_names(), [])
gender = self.check_parameters_for_validity(
gender, "gender", self.genders(), [])
# List of the clients
q = self.query(TModel).join(Client).join(Protocol).filter(Protocol.name.in_(protocol))
q = self.query(TModel).join(Client).join(
Protocol).filter(Protocol.name.in_(protocol))
if subworld:
q = q.join((Subworld, Client.subworld)).filter(Subworld.name.in_(subworld))
q = q.join((Subworld, Client.subworld)).filter(
Subworld.name.in_(subworld))
if gender:
q = q.filter(Client.gender.in_(gender))
q = q.order_by(TModel.id)
......@@ -399,13 +410,19 @@ class Database(bob.db.base.SQLiteDatabase):
"""
protocol = self._replace_protocols_alias(protocol)
protocol = self.check_parameters_for_validity(protocol, "protocol", self.protocol_names())
purposes = self.check_parameters_for_validity(purposes, "purpose", self.purposes())
protocol = self.check_parameters_for_validity(
protocol, "protocol", self.protocol_names())
purposes = self.check_parameters_for_validity(
purposes, "purpose", self.purposes())
groups = self.check_parameters_for_validity(groups, "group", self.groups())
classes = self.check_parameters_for_validity(classes, "class", ('client', 'impostor'))
subworld = self.check_parameters_for_validity(subworld, "subworld", self.subworld_names(), [])
gender = self.check_parameters_for_validity(gender, "gender", self.genders(), [])
device = self.check_parameters_for_validity(device, "device", File.device_choices, [])
classes = self.check_parameters_for_validity(
classes, "class", ('client', 'impostor'))
subworld = self.check_parameters_for_validity(
subworld, "subworld", self.subworld_names(), [])
gender = self.check_parameters_for_validity(
gender, "gender", self.genders(), [])
device = self.check_parameters_for_validity(
device, "device", File.device_choices, [])
import collections
if(model_ids is None):
......@@ -417,54 +434,63 @@ class Database(bob.db.base.SQLiteDatabase):
retval = []
if 'world' in groups and 'train' in purposes:
q = self.query(File).join(Client).filter(Client.sgroup == 'world').join((ProtocolPurpose, File.protocol_purposes)).join(Protocol).\
filter(and_(Protocol.name.in_(protocol), ProtocolPurpose.sgroup == 'world'))
filter(and_(Protocol.name.in_(protocol),
ProtocolPurpose.sgroup == 'world'))
if subworld:
q = q.join((Subworld, File.subworld)).filter(Subworld.name.in_(subworld))
q = q.join((Subworld, File.subworld)).filter(
Subworld.name.in_(subworld))
if gender:
q = q.filter(Client.gender.in_(gender))
if device:
q = q.filter(File.device.in_(device))
if model_ids:
q = q.filter(File.client_id.in_(model_ids))
q = q.order_by(File.client_id, File.session_id, File.speech_type, File.shot_id, File.device)
q = q.order_by(File.client_id, File.session_id,
File.speech_type, File.shot_id, File.device)
retval += list(q)
if ('dev' in groups or 'eval' in groups):
if('enroll' in purposes):
q = self.query(File).join(Client).join((ProtocolPurpose, File.protocol_purposes)).join(Protocol).\
filter(and_(Protocol.name.in_(protocol), ProtocolPurpose.sgroup.in_(groups), ProtocolPurpose.purpose == 'enroll'))
filter(and_(Protocol.name.in_(protocol), ProtocolPurpose.sgroup.in_(
groups), ProtocolPurpose.purpose == 'enroll'))
if gender:
q = q.filter(Client.gender.in_(gender))
if device:
q = q.filter(File.device.in_(device))
if model_ids:
q = q.filter(Client.id.in_(model_ids))
q = q.order_by(File.client_id, File.session_id, File.speech_type, File.shot_id, File.device)
q = q.order_by(File.client_id, File.session_id,
File.speech_type, File.shot_id, File.device)
retval += list(q)
if('probe' in purposes):
if('client' in classes):
q = self.query(File).join(Client).join((ProtocolPurpose, File.protocol_purposes)).join(Protocol).\
filter(and_(Protocol.name.in_(protocol), ProtocolPurpose.sgroup.in_(groups), ProtocolPurpose.purpose == 'probe'))
filter(and_(Protocol.name.in_(protocol), ProtocolPurpose.sgroup.in_(
groups), ProtocolPurpose.purpose == 'probe'))
if gender:
q = q.filter(Client.gender.in_(gender))
if device:
q = q.filter(File.device.in_(device))
if model_ids:
q = q.filter(Client.id.in_(model_ids))
q = q.order_by(File.client_id, File.session_id, File.speech_type, File.shot_id, File.device)
q = q.order_by(File.client_id, File.session_id,
File.speech_type, File.shot_id, File.device)
retval += list(q)
if('impostor' in classes):
q = self.query(File).join(Client).join((ProtocolPurpose, File.protocol_purposes)).join(Protocol).\
filter(and_(Protocol.name.in_(protocol), ProtocolPurpose.sgroup.in_(groups), ProtocolPurpose.purpose == 'probe'))
filter(and_(Protocol.name.in_(protocol), ProtocolPurpose.sgroup.in_(
groups), ProtocolPurpose.purpose == 'probe'))
if gender:
q = q.filter(Client.gender.in_(gender))
if device:
q = q.filter(File.device.in_(device))
if len(model_ids) == 1:
q = q.filter(not_(File.client_id.in_(model_ids)))
q = q.order_by(File.client_id, File.session_id, File.speech_type, File.shot_id, File.device)
q = q.order_by(File.client_id, File.session_id,
File.speech_type, File.shot_id, File.device)
retval += list(q)
return list(set(retval)) # To remove duplicates
......@@ -507,9 +533,12 @@ class Database(bob.db.base.SQLiteDatabase):
"""
protocol = self._replace_protocols_alias(protocol)
protocol = self.check_parameters_for_validity(protocol, "protocol", self.protocol_names())
subworld = self.check_parameters_for_validity(subworld, "subworld", self.subworld_names(), [])
gender = self.check_parameters_for_validity(gender, "gender", self.genders(), [])
protocol = self.check_parameters_for_validity(
protocol, "protocol", self.protocol_names())
subworld = self.check_parameters_for_validity(
subworld, "subworld", self.subworld_names(), [])
gender = self.check_parameters_for_validity(
gender, "gender", self.genders(), [])
if(model_ids is None):
model_ids = ()
......@@ -517,10 +546,12 @@ class Database(bob.db.base.SQLiteDatabase):
model_ids = (model_ids,)
# Now query the database
q = self.query(File, Protocol).filter(Protocol.name.in_(protocol)).join(Client)
q = self.query(File, Protocol).filter(
Protocol.name.in_(protocol)).join(Client)
if subworld:
q = q.join((Subworld, File.subworld)).filter(Subworld.name.in_(subworld))
q = q.join((TModel, File.tmodels)).filter(TModel.protocol_id == Protocol.id)
q = q.join((TModel, File.tmodels)).filter(
TModel.protocol_id == Protocol.id)
if model_ids:
q = q.filter(TModel.mid.in_(model_ids))
if gender:
......@@ -529,7 +560,8 @@ class Database(bob.db.base.SQLiteDatabase):
q = q.filter(File.speech_type.in_(speech_type))
if device:
q = q.filter(File.device.in_(device))
q = q.order_by(File.client_id, File.session_id, File.speech_type, File.shot_id, File.device)
q = q.order_by(File.client_id, File.session_id,
File.speech_type, File.shot_id, File.device)
retval = [v[0] for v in q]
return list(retval)
......@@ -570,12 +602,17 @@ class Database(bob.db.base.SQLiteDatabase):
"""
protocol = self._replace_protocols_alias(protocol)
protocol = self.check_parameters_for_validity(protocol, "protocol", self.protocol_names())
protocol = self.check_parameters_for_validity(
protocol, "protocol", self.protocol_names())
groups = self.check_parameters_for_validity(groups, "group", self.groups())
subworld = self.check_parameters_for_validity(subworld, "subworld", self.subworld_names(), [])
gender = self.check_parameters_for_validity(gender, "gender", self.genders(), [])
speech_type = self.check_parameters_for_validity(speech_type, "speech_type", File.speech_type_choices)
device = self.check_parameters_for_validity(device, "device", File.device_choices)
subworld = self.check_parameters_for_validity(
subworld, "subworld", self.subworld_names(), [])
gender = self.check_parameters_for_validity(
gender, "gender", self.genders(), [])
speech_type = self.check_parameters_for_validity(
speech_type, "speech_type", File.speech_type_choices)
device = self.check_parameters_for_validity(
device, "device", File.device_choices)
import collections
if(model_ids is None):
......@@ -585,7 +622,8 @@ class Database(bob.db.base.SQLiteDatabase):
# Now query the database
q = self.query(File).join(Client).filter(Client.sgroup == 'world').join((ProtocolPurpose, File.protocol_purposes)).join(Protocol).\
filter(and_(Protocol.name.in_(protocol), ProtocolPurpose.sgroup == 'world'))
filter(and_(Protocol.name.in_(protocol),
ProtocolPurpose.sgroup == 'world'))
if subworld:
q = q.join((Subworld, File.subworld)).filter(Subworld.name.in_(subworld))
if gender:
......@@ -596,7 +634,8 @@ class Database(bob.db.base.SQLiteDatabase):
q = q.filter(File.device.in_(device))
if model_ids:
q = q.filter(File.client_id.in_(model_ids))
q = q.order_by(File.client_id, File.session_id, File.speech_type, File.shot_id, File.device)
q = q.order_by(File.client_id, File.session_id,
File.speech_type, File.shot_id, File.device)
return list(q)
def annotations(self, file):
......@@ -616,7 +655,8 @@ class Database(bob.db.base.SQLiteDatabase):
return None
self.assert_validity()
annotation_file = file.make_path(self.annotation_directory, self.annotation_extension)
annotation_file = file.make_path(
self.annotation_directory, self.annotation_extension)
# return the annotations as read from file
return bob.db.base.read_annotation_file(annotation_file, 'eyecenter')
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment