Skip to content
Snippets Groups Projects
Commit 90e2ea05 authored by Manuel Günther's avatar Manuel Günther
Browse files

Now creating one list reader per protocol

parent 7910a29c
No related branches found
No related tags found
1 merge request!59Resolve "strange behaviour when retrieving objects from bob.bio.base filelist database query"
Pipeline #
......@@ -151,10 +151,8 @@ class FileListBioDatabase(ZTBioDatabase):
# self.original_directory = original_directory
# self.original_extension = original_extension
self.bio_file_class = bio_file_class
self.m_annotation_directory = annotation_directory
self.m_annotation_extension = annotation_extension
self.m_annotation_type = annotation_type
self.keep_read_lists_in_memory=keep_read_lists_in_memory
self.list_readers = {}
self.m_base_dir = os.path.abspath(filelists_directory)
if not os.path.isdir(self.m_base_dir):
......@@ -228,7 +226,11 @@ class FileListBioDatabase(ZTBioDatabase):
raise ValueError(
"Unable to determine, which way of probing should be used, since this is not consistent accross protocols. Please specify.")
self.m_list_reader = ListReader(keep_read_lists_in_memory)
def _list_reader(self, protocol):
if protocol not in self.list_readers:
self.list_readers[protocol] = ListReader(self.keep_read_lists_in_memory)
return self.list_readers[protocol]
def _make_bio(self, files):
return [self.bio_file_class(client_id=f.client_id, path=f.path, file_id=f.id) for f in files]
......@@ -362,7 +364,7 @@ class FileListBioDatabase(ZTBioDatabase):
protocol = self.protocol
for group in groups:
model_dict = self.m_list_reader.read_models(self.get_list_file(group, 'for_models', protocol), group,
model_dict = self._list_reader(protocol).read_models(self.get_list_file(group, 'for_models', protocol), group,
'for_models')
if model_id in model_dict:
return model_dict[model_id]
......@@ -392,7 +394,7 @@ class FileListBioDatabase(ZTBioDatabase):
protocol = self.protocol
for group in groups:
model_dict = self.m_list_reader.read_models(self.get_list_file(group, 'for_tnorm', protocol), group,
model_dict = self._list_reader(protocol).read_models(self.get_list_file(group, 'for_tnorm', protocol), group,
'for_tnorm')
if t_model_id in model_dict:
return model_dict[t_model_id]
......@@ -457,7 +459,7 @@ class FileListBioDatabase(ZTBioDatabase):
protocol = protocol or self.protocol
# read all lists for all groups and extract the model ids
for group in groups:
files = self.m_list_reader.read_list(self.get_list_file(group, type, protocol), group, type)
files = self._list_reader(protocol).read_list(self.get_list_file(group, type, protocol), group, type)
for file in files:
ids.add(file.client_id)
return ids
......@@ -526,7 +528,7 @@ class FileListBioDatabase(ZTBioDatabase):
protocol = protocol or self.protocol
# read all lists for all groups and extract the model ids
for group in groups:
dict = self.m_list_reader.read_models(self.get_list_file(group, type, protocol), group, type)
dict = self._list_reader(protocol).read_models(self.get_list_file(group, type, protocol), group, type)
ids.update(dict.keys())
return list(ids)
......@@ -620,28 +622,28 @@ class FileListBioDatabase(ZTBioDatabase):
lists = []
probe_lists = []
if 'world' in groups:
lists.append(self.m_list_reader.read_list(self.get_list_file('world', protocol=protocol), 'world'))
lists.append(self._list_reader(protocol).read_list(self.get_list_file('world', protocol=protocol), 'world'))
if 'optional_world_1' in groups:
lists.append(self.m_list_reader.read_list(self.get_list_file('optional_world_1', protocol=protocol),
lists.append(self._list_reader(protocol).read_list(self.get_list_file('optional_world_1', protocol=protocol),
'optional_world_1'))
if 'optional_world_2' in groups:
lists.append(self.m_list_reader.read_list(self.get_list_file('optional_world_2', protocol=protocol),
lists.append(self._list_reader(protocol).read_list(self.get_list_file('optional_world_2', protocol=protocol),
'optional_world_2'))
for group in ('dev', 'eval'):
if group in groups:
if 'enroll' in purposes:
lists.append(
self.m_list_reader.read_list(self.get_list_file(group, 'for_models', protocol=protocol), group,
self._list_reader(protocol).read_list(self.get_list_file(group, 'for_models', protocol=protocol), group,
'for_models'))
if 'probe' in purposes:
if self.m_use_dense_probes:
probe_lists.append(
self.m_list_reader.read_list(self.get_list_file(group, 'for_probes', protocol=protocol),
self._list_reader(protocol).read_list(self.get_list_file(group, 'for_probes', protocol=protocol),
group, 'for_probes'))
else:
probe_lists.append(
self.m_list_reader.read_list(self.get_list_file(group, 'for_scores', protocol=protocol),
self._list_reader(protocol).read_list(self.get_list_file(group, 'for_scores', protocol=protocol),
group, 'for_scores'))
# now, go through the lists and filter the elements
......@@ -713,7 +715,7 @@ class FileListBioDatabase(ZTBioDatabase):
# we assume that there is no duplicate file here...
retval = []
for group in groups:
for file in self.m_list_reader.read_list(self.get_list_file(group, 'for_tnorm', protocol), group,
for file in self._list_reader(protocol).read_list(self.get_list_file(group, 'for_tnorm', protocol), group,
'for_tnorm'):
if model_ids is None or file._model_id in model_ids:
retval.append(file)
......@@ -742,7 +744,7 @@ class FileListBioDatabase(ZTBioDatabase):
retval = []
for group in groups:
retval.extend([file for file in
self.m_list_reader.read_list(self.get_list_file(group, 'for_znorm', protocol), group,
self._list_reader(protocol).read_list(self.get_list_file(group, 'for_znorm', protocol), group,
'for_znorm')])
return self._make_bio(retval)
......@@ -760,14 +762,14 @@ class FileListBioDatabase(ZTBioDatabase):
Return value
The annotations as a dictionary: {'reye':(re_y,re_x), 'leye':(le_y,le_x)}
"""
if self.m_annotation_directory is None:
if self.annotation_directory is None:
return None
# since the file id is equal to the file name, we can simply use it
annotation_file = os.path.join(self.m_annotation_directory, file.id + self.m_annotation_extension)
annotation_file = os.path.join(self.annotation_directory, file.id + self.annotation_extension)
# return the annotations as read from file
return bob.db.base.read_annotation_file(annotation_file, self.m_annotation_type)
return bob.db.base.read_annotation_file(annotation_file, self.annotation_type)
def original_file_name(self, file, check_existence=True):
"""Returns the original file name of the given file.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment