diff --git a/bob/bio/base/database/filelist/query.py b/bob/bio/base/database/filelist/query.py index 8ddc505e65a8be472c1fb216670102c72745866b..5e5cdf8c2c800585ca961927aa1380e888cc820d 100644 --- a/bob/bio/base/database/filelist/query.py +++ b/bob/bio/base/database/filelist/query.py @@ -151,10 +151,8 @@ class FileListBioDatabase(ZTBioDatabase): # self.original_directory = original_directory # self.original_extension = original_extension self.bio_file_class = bio_file_class - - self.m_annotation_directory = annotation_directory - self.m_annotation_extension = annotation_extension - self.m_annotation_type = annotation_type + self.keep_read_lists_in_memory=keep_read_lists_in_memory + self.list_readers = {} self.m_base_dir = os.path.abspath(filelists_directory) if not os.path.isdir(self.m_base_dir): @@ -228,7 +226,11 @@ class FileListBioDatabase(ZTBioDatabase): raise ValueError( "Unable to determine, which way of probing should be used, since this is not consistent accross protocols. Please specify.") - self.m_list_reader = ListReader(keep_read_lists_in_memory) + + def _list_reader(self, protocol): + if protocol not in self.list_readers: + self.list_readers[protocol] = ListReader(self.keep_read_lists_in_memory) + return self.list_readers[protocol] def _make_bio(self, files): return [self.bio_file_class(client_id=f.client_id, path=f.path, file_id=f.id) for f in files] @@ -362,7 +364,7 @@ class FileListBioDatabase(ZTBioDatabase): protocol = self.protocol for group in groups: - model_dict = self.m_list_reader.read_models(self.get_list_file(group, 'for_models', protocol), group, + model_dict = self._list_reader(protocol).read_models(self.get_list_file(group, 'for_models', protocol), group, 'for_models') if model_id in model_dict: return model_dict[model_id] @@ -392,7 +394,7 @@ class FileListBioDatabase(ZTBioDatabase): protocol = self.protocol for group in groups: - model_dict = self.m_list_reader.read_models(self.get_list_file(group, 'for_tnorm', protocol), group, + model_dict = self._list_reader(protocol).read_models(self.get_list_file(group, 'for_tnorm', protocol), group, 'for_tnorm') if t_model_id in model_dict: return model_dict[t_model_id] @@ -457,7 +459,7 @@ class FileListBioDatabase(ZTBioDatabase): protocol = protocol or self.protocol # read all lists for all groups and extract the model ids for group in groups: - files = self.m_list_reader.read_list(self.get_list_file(group, type, protocol), group, type) + files = self._list_reader(protocol).read_list(self.get_list_file(group, type, protocol), group, type) for file in files: ids.add(file.client_id) return ids @@ -526,7 +528,7 @@ class FileListBioDatabase(ZTBioDatabase): protocol = protocol or self.protocol # read all lists for all groups and extract the model ids for group in groups: - dict = self.m_list_reader.read_models(self.get_list_file(group, type, protocol), group, type) + dict = self._list_reader(protocol).read_models(self.get_list_file(group, type, protocol), group, type) ids.update(dict.keys()) return list(ids) @@ -620,28 +622,28 @@ class FileListBioDatabase(ZTBioDatabase): lists = [] probe_lists = [] if 'world' in groups: - lists.append(self.m_list_reader.read_list(self.get_list_file('world', protocol=protocol), 'world')) + lists.append(self._list_reader(protocol).read_list(self.get_list_file('world', protocol=protocol), 'world')) if 'optional_world_1' in groups: - lists.append(self.m_list_reader.read_list(self.get_list_file('optional_world_1', protocol=protocol), + lists.append(self._list_reader(protocol).read_list(self.get_list_file('optional_world_1', protocol=protocol), 'optional_world_1')) if 'optional_world_2' in groups: - lists.append(self.m_list_reader.read_list(self.get_list_file('optional_world_2', protocol=protocol), + lists.append(self._list_reader(protocol).read_list(self.get_list_file('optional_world_2', protocol=protocol), 'optional_world_2')) for group in ('dev', 'eval'): if group in groups: if 'enroll' in purposes: lists.append( - self.m_list_reader.read_list(self.get_list_file(group, 'for_models', protocol=protocol), group, + self._list_reader(protocol).read_list(self.get_list_file(group, 'for_models', protocol=protocol), group, 'for_models')) if 'probe' in purposes: if self.m_use_dense_probes: probe_lists.append( - self.m_list_reader.read_list(self.get_list_file(group, 'for_probes', protocol=protocol), + self._list_reader(protocol).read_list(self.get_list_file(group, 'for_probes', protocol=protocol), group, 'for_probes')) else: probe_lists.append( - self.m_list_reader.read_list(self.get_list_file(group, 'for_scores', protocol=protocol), + self._list_reader(protocol).read_list(self.get_list_file(group, 'for_scores', protocol=protocol), group, 'for_scores')) # now, go through the lists and filter the elements @@ -713,7 +715,7 @@ class FileListBioDatabase(ZTBioDatabase): # we assume that there is no duplicate file here... retval = [] for group in groups: - for file in self.m_list_reader.read_list(self.get_list_file(group, 'for_tnorm', protocol), group, + for file in self._list_reader(protocol).read_list(self.get_list_file(group, 'for_tnorm', protocol), group, 'for_tnorm'): if model_ids is None or file._model_id in model_ids: retval.append(file) @@ -742,7 +744,7 @@ class FileListBioDatabase(ZTBioDatabase): retval = [] for group in groups: retval.extend([file for file in - self.m_list_reader.read_list(self.get_list_file(group, 'for_znorm', protocol), group, + self._list_reader(protocol).read_list(self.get_list_file(group, 'for_znorm', protocol), group, 'for_znorm')]) return self._make_bio(retval) @@ -760,14 +762,14 @@ class FileListBioDatabase(ZTBioDatabase): Return value The annotations as a dictionary: {'reye':(re_y,re_x), 'leye':(le_y,le_x)} """ - if self.m_annotation_directory is None: + if self.annotation_directory is None: return None # since the file id is equal to the file name, we can simply use it - annotation_file = os.path.join(self.m_annotation_directory, file.id + self.m_annotation_extension) + annotation_file = os.path.join(self.annotation_directory, file.id + self.annotation_extension) # return the annotations as read from file - return bob.db.base.read_annotation_file(annotation_file, self.m_annotation_type) + return bob.db.base.read_annotation_file(annotation_file, self.annotation_type) def original_file_name(self, file, check_existence=True): """Returns the original file name of the given file.