Skip to content
Snippets Groups Projects
Commit b6c09174 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Fixed issue with join

parent d25aa24e
No related branches found
No related tags found
No related merge requests found
Pipeline #57503 failed
...@@ -20,60 +20,51 @@ class Database(bob.db.base.SQLiteDatabase): ...@@ -20,60 +20,51 @@ class Database(bob.db.base.SQLiteDatabase):
and for the data itself inside the database. and for the data itself inside the database.
""" """
def __init__(self, original_directory=None, original_extension=None): def __init__(self, original_directory=None, original_extension=None):
super(Database, self).__init__(SQLITE_FILE, File, super(Database, self).__init__(
original_directory, original_extension) SQLITE_FILE, File, original_directory, original_extension
)
def protocol_names(self): def protocol_names(self):
"""Returns a list of all supported protocols""" """Returns a list of all supported protocols"""
return tuple([k.name for k in self.query(Protocol).order_by(Protocol.name)]) return tuple([k.name for k in self.query(Protocol).order_by(Protocol.name)])
def purposes(self): def purposes(self):
"""Returns a list of all supported purposes""" """Returns a list of all supported purposes"""
return Subset.purpose_choices return Subset.purpose_choices
def groups(self): def groups(self):
"""Returns a list of all supported groups""" """Returns a list of all supported groups"""
return Subset.group_choices return Subset.group_choices
def genders(self): def genders(self):
"""Returns a list of all supported gender values""" """Returns a list of all supported gender values"""
return Client.gender_choices return Client.gender_choices
def finger_names(self): def finger_names(self):
"""Returns a list of all supported finger name values""" """Returns a list of all supported finger name values"""
return Finger.name_choices return Finger.name_choices
def sessions(self): def sessions(self):
"""Returns a list of all supported session values""" """Returns a list of all supported session values"""
return File.session_choices return File.session_choices
def file_from_model_id(self, model_id): def file_from_model_id(self, model_id):
"""Returns the file in the database given a ``model_id``""" """Returns the file in the database given a ``model_id``"""
return self.query(File).filter(File.model_id == model_id).one() return self.query(File).filter(File.model_id == model_id).one()
def finger_name_from_model_id(self, model_id): def finger_name_from_model_id(self, model_id):
"""Returns the unique finger name in the database given a ``model_id``""" """Returns the unique finger name in the database given a ``model_id``"""
return self.file_from_model_id(model_id).unique_finger_name return self.file_from_model_id(model_id).unique_finger_name
def model_ids(self, protocol=None, groups=None): def model_ids(self, protocol=None, groups=None):
"""Returns a set of models for a given protocol/group """Returns a set of models for a given protocol/group
...@@ -99,13 +90,13 @@ class Database(bob.db.base.SQLiteDatabase): ...@@ -99,13 +90,13 @@ class Database(bob.db.base.SQLiteDatabase):
protocols = None protocols = None
if protocol: if protocol:
valid_protocols = self.protocol_names() valid_protocols = self.protocol_names()
protocols = self.check_parameters_for_validity(protocol, "protocol", protocols = self.check_parameters_for_validity(
valid_protocols) protocol, "protocol", valid_protocols
)
if groups: if groups:
valid_groups = self.groups() valid_groups = self.groups()
groups = self.check_parameters_for_validity(groups, "group", groups = self.check_parameters_for_validity(groups, "group", valid_groups)
valid_groups)
retval = self.query(File) retval = self.query(File)
...@@ -122,18 +113,25 @@ class Database(bob.db.base.SQLiteDatabase): ...@@ -122,18 +113,25 @@ class Database(bob.db.base.SQLiteDatabase):
if groups: if groups:
subfilters.append(Subset.group.in_(groups)) subfilters.append(Subset.group.in_(groups))
subfilters.append(Subset.purpose == 'enroll') subfilters.append(Subset.purpose == "enroll")
subsets = subquery.filter(*subfilters) subsets = subquery.filter(*subfilters)
filters.append(File.subsets.any(Subset.id.in_([k.id for k in subsets]))) filters.append(File.subsets.any(Subset.id.in_([k.id for k in subsets])))
retval = retval.join(*joins).filter(*filters).distinct().order_by('id') retval = retval.join(*joins).filter(*filters).distinct().order_by("id")
return sorted(set([k.model_id for k in retval.distinct()])) return sorted(set([k.model_id for k in retval.distinct()]))
def objects(
def objects(self, protocol=None, groups=None, purposes=None, self,
model_ids=None, genders=None, finger_names=None, sessions=None): protocol=None,
groups=None,
purposes=None,
model_ids=None,
genders=None,
finger_names=None,
sessions=None,
):
"""Returns objects filtered by criteria """Returns objects filtered by criteria
...@@ -169,45 +167,51 @@ class Database(bob.db.base.SQLiteDatabase): ...@@ -169,45 +167,51 @@ class Database(bob.db.base.SQLiteDatabase):
""" """
protocols = None protocols = None
if protocol: if protocol:
valid_protocols = self.protocol_names() valid_protocols = self.protocol_names()
protocols = self.check_parameters_for_validity(protocol, "protocol", protocols = self.check_parameters_for_validity(
valid_protocols) protocol, "protocol", valid_protocols
)
if groups: if groups:
valid_groups = self.groups() valid_groups = self.groups()
groups = self.check_parameters_for_validity( groups = self.check_parameters_for_validity(groups, "group", valid_groups)
groups, "group", valid_groups)
if purposes: if purposes:
valid_purposes = self.purposes() valid_purposes = self.purposes()
purposes = self.check_parameters_for_validity(purposes, "purpose", purposes = self.check_parameters_for_validity(
valid_purposes) purposes, "purpose", valid_purposes
)
# if only asking for 'probes', then ignore model_ids as all of our # if only asking for 'probes', then ignore model_ids as all of our
# protocols do a full probe-model scan # protocols do a full probe-model scan
if purposes and len(purposes) == 1 and 'probe' in purposes: if purposes and len(purposes) == 1 and "probe" in purposes:
model_ids = None model_ids = None
if model_ids: if model_ids:
valid_model_ids = self.model_ids(protocol, groups) valid_model_ids = self.model_ids(protocol, groups)
model_ids = self.check_parameters_for_validity(model_ids, "model_ids", model_ids = self.check_parameters_for_validity(
valid_model_ids) model_ids, "model_ids", valid_model_ids
)
if genders: if genders:
valid_genders = self.genders() valid_genders = self.genders()
genders = self.check_parameters_for_validity(genders, "genders", genders = self.check_parameters_for_validity(
valid_genders) genders, "genders", valid_genders
)
if finger_names: if finger_names:
valid_finger_names = self.finger_names() valid_finger_names = self.finger_names()
finger_names = self.check_parameters_for_validity(finger_names, finger_names = self.check_parameters_for_validity(
"finger_names", valid_finger_names) finger_names, "finger_names", valid_finger_names
)
if sessions: if sessions:
valid_sessions = self.sessions() valid_sessions = self.sessions()
sessions = self.check_parameters_for_validity(sessions, "sessions", sessions = self.check_parameters_for_validity(
valid_sessions) sessions, "sessions", valid_sessions
)
retval = self.query(File) retval = self.query(File)
...@@ -232,12 +236,16 @@ class Database(bob.db.base.SQLiteDatabase): ...@@ -232,12 +236,16 @@ class Database(bob.db.base.SQLiteDatabase):
filters.append(File.subsets.any(Subset.id.in_([k.id for k in subsets]))) filters.append(File.subsets.any(Subset.id.in_([k.id for k in subsets])))
if genders or finger_names: #import ipdb
#ipdb.set_trace()
joins.append(Finger) joins.append(Finger)
if genders or finger_names:
if genders: if genders:
fingers = self.query(Finger).join( fingers = (
Client).filter(Client.gender.in_(genders)) self.query(Finger).join(Client).filter(Client.gender.in_(genders))
)
filters.append(Finger.id.in_([k.id for k in fingers])) filters.append(Finger.id.in_([k.id for k in fingers]))
if finger_names: if finger_names:
...@@ -251,11 +259,15 @@ class Database(bob.db.base.SQLiteDatabase): ...@@ -251,11 +259,15 @@ class Database(bob.db.base.SQLiteDatabase):
# special case for 1vsall protocol: if only one model id given, returns # special case for 1vsall protocol: if only one model id given, returns
# all but the sample for the model id in the list # all but the sample for the model id in the list
if model_ids and len(model_ids) == 1 and \ if (
protocols and len(protocols) == 1 and \ model_ids
protocols[0] == '1vsall': and len(model_ids) == 1
and protocols
and len(protocols) == 1
and protocols[0] == "1vsall"
):
filters.append(~self.file_from_model_id(model_ids[0])) filters.append(~self.file_from_model_id(model_ids[0]))
retval = retval.join(*joins).filter(*filters).distinct().order_by('id') retval = retval.join(*joins).filter(*filters).distinct().order_by("id")
return list(retval) return list(retval)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment