Commit 95b84a9a authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Fixed bugs

parent 77f4d48e
......@@ -192,14 +192,14 @@ class Database(bob.db.verification.utils.Database):
objects.extend(self.objects(groups=g, protocol=protocol))
else:
self._load_data(protocol, "dev", "")
objects.extend([o for t in self.memory_db[protocol]['comparison-templates'] for o in self.memory_db[protocol]['comparison-templates'][t] ])
objects.extend([o for t in self.memory_db[protocol]['comparison-templates'] for o in self.memory_db[protocol]['comparison-templates'][t].files ])
ids = list(set([o.client_id for o in objects ]))
return ids
def model_ids(self, groups=None, protocol='search_split1'):
def model_ids(self, groups=None, protocol='search_split1', purposes='enroll', model_ids=None):
"""Returns a list of model ids for the specific query by the user.
Keyword Parameters:
......@@ -210,24 +210,36 @@ class Database(bob.db.verification.utils.Database):
protocol
One of the available protocol names, see :py:meth:`protocol_names`.
purposes
Returns: A list containing all the model ids for the given protocol.
"""
protocol = self.check_parameter_for_validity(protocol, "protocol", self.protocol_names())
groups = self.check_parameters_for_validity(groups, "group", self.groups())
#Just filling up the memory_db
purposes = self.check_parameters_for_validity(purposes, "purpose", ["enroll","probe"])
ids = []
if "search" in protocol:
self._load_data(protocol, "dev", "enroll")
self._load_data(protocol, "dev", "probe")
ids = [t for t in self.memory_db[protocol]['enroll']]
ids.extend([t for t in self.memory_db[protocol]['probe']])
for p in purposes:
self._load_data(protocol, "dev", p)
ids.extend([t for t in self.memory_db[protocol][p]])
else:
self._load_data(protocol, "dev", "")
ids = [t for t in self.memory_db[protocol]['comparison-templates']]
for p in purposes:
if p == "enroll":
for c in self.memory_db[protocol]['comparisons']:
ids.append(c)
else:
if(model_ids is None):
for c in self.memory_db[protocol]['comparisons']:
for probe in self.memory_db[protocol]['comparisons'][c]:
ids.append(probe)
else:
for c in model_ids:
for probe in self.memory_db[protocol]['comparisons'][c]:
ids.append(probe)
return ids
......@@ -280,7 +292,7 @@ class Database(bob.db.verification.utils.Database):
objects = []
if 'world' in groups:
self._load_data(protocol, "world", "train")
objects.extend([o for t in self.memory_db[protocol]['train'] for o in self.memory_db[protocol]['train'][t] ])
objects.extend([o for t in self.memory_db[protocol]['train'] for o in self.memory_db[protocol]['train'][t].files ])
if 'dev' in groups:
......@@ -290,16 +302,16 @@ class Database(bob.db.verification.utils.Database):
self._load_data(protocol, "dev", "enroll")
if(model_ids is None):
objects.extend([o for t in self.memory_db[protocol]['enroll'] for o in self.memory_db[protocol]['enroll'][t]])
objects.extend([o for t in self.memory_db[protocol]['enroll'] for o in self.memory_db[protocol]['enroll'][t].files])
else:
objects.extend([o for t in model_ids for o in self.memory_db[protocol]['enroll'][t]])
objects.extend([o for t in model_ids for o in self.memory_db[protocol]['enroll'][t].files])
if 'probe' in purposes:
self._load_data(protocol, "dev", "probe")
#The probes for the search are the same for all users
objects.extend([o for t in self.memory_db[protocol]['probe'] for o in self.memory_db[protocol]['probe'][t]])
objects.extend([o for t in self.memory_db[protocol]['probe'] for o in self.memory_db[protocol]['probe'][t].files])
#Dealing with comparisons
......@@ -311,19 +323,22 @@ class Database(bob.db.verification.utils.Database):
if model_ids is None:
for c in self.memory_db[protocol]['comparisons']:
objects.extend(self.memory_db[protocol]['comparison-templates'][c])
objects.extend(self.memory_db[protocol]['comparison-templates'][c].files)
else:
for m in model_ids:
objects.extend(self.memory_db[protocol]['comparison-templates'][m])
objects.extend(self.memory_db[protocol]['comparison-templates'][m].files)
if 'probe' in purposes:
if(model_ids is None):
raise ValueError("`model_ids` parameter required for the protocol `{0}`. For the comparison protocols, each model has an specific set of probes.".format(protocol))
for t in self.memory_db[protocol]['comparison-templates']:
objects.extend(self.memory_db[protocol]['comparison-templates'][t].files)
#import ipdb; ipdb.set_trace();
#raise ValueError("`model_ids` parameter required for the protocol `{0}`. For the comparison protocols, each model has an specific set of probes.".format(protocol))
else:
for c in model_ids:
for probe in self.memory_db[protocol]['comparisons'][c]:
objects.extend(self.memory_db[protocol]['comparison-templates'][probe])
objects.extend(self.memory_db[protocol]['comparison-templates'][probe].files)
# we have collected all queries, now extract the File objects
......@@ -355,15 +370,29 @@ class Database(bob.db.verification.utils.Database):
Note that the images of the database will be ignored, when this option is selected.
"""
#TODO: SET purposes and group as IGNORED
# check that every parameter is as expected
#groups = self.check_parameters_for_validity(groups, "group", ["dev","world"])
#purposes = self.check_parameters_for_validity(purposes, "purpose", ["enroll","probe"])
protocol = self.check_parameter_for_validity(protocol, "protocol", self.protocol_names())
return model_ids(self, protocol='search_split1')
purposes = self.check_parameters_for_validity(purposes, "purpose", ["enroll","probe"])
protocol = self.check_parameter_for_validity(protocol, "protocol", self.protocol_names())
templates = []
self._load_data(protocol, "dev", "enroll")
self._load_data(protocol, "dev", "probe")
for p in purposes:
for m in model_ids:
if "probe" in p:
template_ids = self.model_ids(groups="dev", protocol=protocol, purposes=p, model_ids=[m])
if "search" in protocol:
for t in template_ids:
templates.append(self.memory_db[protocol][p][t])
else:
for t in template_ids:
templates.append(self.memory_db[protocol]['comparison-templates'][t])
else:
templates.extend(self.memory_db[protocol][p][m])
return templates
......
......@@ -26,6 +26,24 @@ import os
from bob.db.verification.utils import File
class Template():
"""A ``Template`` contains a list of :py:class:`File` objects belonging to the same subject (there might be several templates per subject).
These are listed in the ``self.files`` field.
A ``Template`` can serve for training, model enrollment, or for probing.
Each template belongs specifically to a certain protocol, as the template_id in the original file lists might differ for different protocols.
The according :py:class:`ProtocolPurpose` can be obtained using the ``self.protocol_purpose`` after creation of the database.
Note that the ``template_id`` corresponds to the template_id of the file lists, while the ``id`` is only used as a un
ique key for querying the database.
For convenience, the template also contains a ``path``, which is a concatenation of the first :py:attr:`File.media_id
` of the first file, and the ``self.template_id``, making it unique (at least per protocol).
"""
def __init__(self, template_id, subject_id, files):
self.id = template_id
self.client_id = subject_id
assert isinstance(files,list)
self.files = files
self.path = "%s-%s" % (files[0].media_id, template_id)
def read_file(filename):
"""Reads the given file and yields the template id, the subject id and path_id (path + sighting_id)"""
......@@ -53,8 +71,9 @@ def read_file(filename):
file_obj.annotations = annotations
file_obj.extension = extension
yield template_id, file_obj
file_obj.media_id = splits[3]
yield template_id, client_id, file_obj
def get_comparisons(filename):
......@@ -96,13 +115,13 @@ def get_templates(filename, verbose=True):
"""
templates = {}
for template_id, file_obj in read_file(filename):
for template_id, client_id, file_obj in read_file(filename):
# create template with given IDs
if template_id not in templates:
templates[template_id] = [file_obj]
templates[template_id] = Template(template_id,client_id,[file_obj])
else:
templates[template_id].append(file_obj)
templates[template_id].files.append(file_obj)
return templates
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment