Commit c8890f7c authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

last sql version commit

parent 73e151dc
......@@ -48,8 +48,9 @@ class Client(Base):
self.institute = institute
def __repr__(self):
return "Client(id={}, orig_id={}, group={}, institute={})".format(
self.id, self.sgroup, self.gender, self.institute)
return "Client(id={}, orig_id={}, group={}, gender={}, institute={})"\
.format(self.id, self.orig_id, self.sgroup, self.gender,
self.institute)
class File(Base, bob.db.base.File):
......
......@@ -25,7 +25,10 @@ class Database(bob.db.base.SQLiteDatabase):
annotation_directory=None, annotation_extension='.pos'):
# call base class constructors to open a session to the database
super(Database, self).__init__(
SQLITE_FILE, File, original_directory, original_extension)
SQLITE_FILE, File)
self.original_directory = original_directory
self.original_extension = original_extension
self.annotation_directory = annotation_directory
self.annotation_extension = annotation_extension
......@@ -102,6 +105,8 @@ class Database(bob.db.base.SQLiteDatabase):
groups, "group", self.groups(), self.groups())
gender = self.check_parameters_for_validity(
gender, "gender", self.genders(), [])
institute = self.check_parameters_for_validity(
institute, "institute", self.institutes(), [])
# List of the clients
retval = []
......@@ -109,6 +114,8 @@ class Database(bob.db.base.SQLiteDatabase):
q = self.query(Client).filter(Client.sgroup == 'world')
if gender:
q = q.filter(Client.gender.in_(gender))
if institute:
q = q.filter(Client.institute.in_(institute))
q = q.order_by(Client.id)
retval += list(q)
......@@ -118,14 +125,10 @@ class Database(bob.db.base.SQLiteDatabase):
if 'eval' in groups:
dev_eval.append('eval')
if dev_eval:
protocol_gender = None
if protocol:
q = self.query(Protocol).filter(
Protocol.name.in_(protocol)).one()
protocol_gender = [q.gender]
q = self.query(Client).filter(Client.sgroup.in_(dev_eval))
if protocol_gender:
q = q.filter(Client.gender.in_(protocol_gender))
if gender:
q = q.filter(Client.gender.in_(gender))
q = q.order_by(Client.id)
......@@ -134,7 +137,8 @@ class Database(bob.db.base.SQLiteDatabase):
return retval
def has_client_id(self, id):
"""Returns True if we have a client with a certain integer identifier"""
"""Returns True if we have a client with a certain integer
identifier"""
return self.query(Client).filter(Client.id == id).count() != 0
......@@ -171,7 +175,8 @@ class Database(bob.db.base.SQLiteDatabase):
return self.clients(protocol, groups, subworld, gender)
def model_ids(self, protocol=None, groups=None, subworld=None, gender=None):
def model_ids_with_protocol(self, groups=None, protocol=None, gender=None,
institute=None):
"""Returns a set of models ids for the specific query by the user.
Keyword Parameters:
......@@ -196,7 +201,8 @@ class Database(bob.db.base.SQLiteDatabase):
Returns: A list containing the ids of all models belonging to the given group.
"""
return [client.id for client in self.clients(protocol, groups, subworld, gender)]
return [client.id for client in self.clients(
protocol, groups, gender, institute)]
def get_client_id_from_model_id(self, model_id, **kwargs):
"""Returns the client_id attached to the given model_id
......@@ -210,8 +216,8 @@ class Database(bob.db.base.SQLiteDatabase):
"""
return model_id
def objects(self, protocol=None, purposes=None, model_ids=None,
groups=None, gender=None, device=None):
def objects(self, groups=None, protocol=None, purposes=None,
model_ids=None):
"""Returns a set of Files for the specific query by the user.
Keyword Parameters:
......@@ -252,10 +258,6 @@ class Database(bob.db.base.SQLiteDatabase):
purposes, "purpose", self.purposes())
groups = self.check_parameters_for_validity(
groups, "group", self.groups())
gender = self.check_parameters_for_validity(
gender, "gender", self.genders(), [])
device = self.check_parameters_for_validity(
device, "device", File.device_choices, [])
import collections
if(model_ids is None):
......@@ -266,13 +268,10 @@ class Database(bob.db.base.SQLiteDatabase):
# Now query the database
retval = []
if 'world' in groups and 'train' in purposes:
q = self.query(File).join(Client).filter(Client.sgroup == 'world').join((ProtocolPurpose, File.protocol_purposes)).join(Protocol).\
filter(and_(Protocol.name.in_(protocol),
ProtocolPurpose.sgroup == 'world'))
if gender:
q = q.filter(Client.gender.in_(gender))
if device:
q = q.filter(File.device.in_(device))
q = self.query(File).join(Client).filter(Client.sgroup == 'world')\
.join((ProtocolPurpose, File.protocol_purposes)).\
join(Protocol).filter(and_(Protocol.name.in_(protocol),
ProtocolPurpose.sgroup == 'world'))
if model_ids:
q = q.filter(File.client_id.in_(model_ids))
q = q.order_by(File.client_id, File.session, File.device)
......@@ -283,10 +282,6 @@ class Database(bob.db.base.SQLiteDatabase):
q = self.query(File).join(Client).join((ProtocolPurpose, File.protocol_purposes)).join(Protocol).\
filter(and_(Protocol.name.in_(protocol), ProtocolPurpose.sgroup.in_(
groups), ProtocolPurpose.purpose == 'enroll'))
if gender:
q = q.filter(Client.gender.in_(gender))
if device:
q = q.filter(File.device.in_(device))
if model_ids:
q = q.filter(Client.id.in_(model_ids))
q = q.order_by(File.client_id, File.session, File.device)
......@@ -296,10 +291,6 @@ class Database(bob.db.base.SQLiteDatabase):
q = self.query(File).join(Client).join((ProtocolPurpose, File.protocol_purposes)).join(Protocol).\
filter(and_(Protocol.name.in_(protocol), ProtocolPurpose.sgroup.in_(
groups), ProtocolPurpose.purpose == 'probe'))
if gender:
q = q.filter(Client.gender.in_(gender))
if device:
q = q.filter(File.device.in_(device))
if model_ids:
q = q.filter(Client.id.in_(model_ids))
q = q.order_by(File.client_id, File.session, File.device)
......@@ -308,10 +299,6 @@ class Database(bob.db.base.SQLiteDatabase):
q = self.query(File).join(Client).join((ProtocolPurpose, File.protocol_purposes)).join(Protocol).\
filter(and_(Protocol.name.in_(protocol), ProtocolPurpose.sgroup.in_(
groups), ProtocolPurpose.purpose == 'probe'))
if gender:
q = q.filter(Client.gender.in_(gender))
if device:
q = q.filter(File.device.in_(device))
if len(model_ids) == 1:
q = q.filter(not_(File.client_id.in_(model_ids)))
q = q.order_by(File.client_id, File.session, File.device)
......
......@@ -21,10 +21,10 @@ def test_idiap0_audio():
assert set(f.client.institute for f in files) == set(['IDIAP'])
assert all(f.client.orig_id < 25 for f in files)
files = db.objects(protocol=protocol, groups='dev', purposes='enroll')
assert len(files) == 15 * 8 * 1 * 1, len(files)
assert len(files) == 15 * 8 * 2 * 1, len(files)
assert len(set(f.client.id for f in files)) == 15
assert len(set(f.nrecording for f in files)) == 8
assert len(set(f.device for f in files)) == 1
assert len(set(f.device for f in files)) == 2
assert all(f.session == 1 for f in files)
assert set(f.client.institute for f in files) == set(['IDIAP'])
assert all(f.client.orig_id >= 25 and f.client.orig_id < 41 for f in files)
......@@ -38,10 +38,10 @@ def test_idiap0_audio():
assert set(f.client.institute for f in files) == set(['IDIAP'])
assert all(f.client.orig_id >= 25 and f.client.orig_id < 41 for f in files)
files = db.objects(protocol=protocol, groups='eval', purposes='enroll')
assert len(files) == 15 * 8 * 1 * 1, len(files)
assert len(files) == 15 * 8 * 2 * 1, len(files)
assert len(set(f.client.id for f in files)) == 15
assert len(set(f.nrecording for f in files)) == 8
assert len(set(f.device for f in files)) == 1
assert len(set(f.device for f in files)) == 2
assert all(f.session == 1 for f in files)
assert set(f.client.institute for f in files) == set(['IDIAP'])
assert all(f.client.orig_id >= 41 and f.client.orig_id < 61 for f in files)
......@@ -54,3 +54,12 @@ def test_idiap0_audio():
assert all(f.session > 1 for f in files)
assert set(f.client.institute for f in files) == set(['IDIAP'])
assert all(f.client.orig_id >= 41 and f.client.orig_id < 61 for f in files)
model_ids = db.model_ids_with_protocol(groups='world', protocol=protocol)
assert len(model_ids) == 20, len(model_ids)
model_ids = db.model_ids_with_protocol(groups='dev', protocol=protocol)
assert len(model_ids) == 15, len(model_ids)
model_ids = db.model_ids_with_protocol(groups='eval', protocol=protocol)
assert len(model_ids) == 15, len(model_ids)
assert db.annotations(files[0]) is None
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment