From 219546128d59fe5b36c12a645a9e46515f982709 Mon Sep 17 00:00:00 2001 From: Manuel Gunther <siebenkopf@googlemail.com> Date: Mon, 23 Jan 2017 16:16:43 -0700 Subject: [PATCH] Implemented better behavior when protocol is present; added possibility to only have 'dev' group --- bob/bio/base/database/filelist/query.py | 202 ++++++++---------- .../data/example_filelist2/dev/for_models.lst | 12 ++ .../data/example_filelist2/dev/for_scores.lst | 14 ++ bob/bio/base/test/test_filelist.py | 12 +- 4 files changed, 125 insertions(+), 115 deletions(-) create mode 100644 bob/bio/base/test/data/example_filelist2/dev/for_models.lst create mode 100644 bob/bio/base/test/data/example_filelist2/dev/for_scores.lst diff --git a/bob/bio/base/database/filelist/query.py b/bob/bio/base/database/filelist/query.py index fe034c02..09ae85f8 100644 --- a/bob/bio/base/database/filelist/query.py +++ b/bob/bio/base/database/filelist/query.py @@ -182,81 +182,45 @@ class FileListBioDatabase(ZTBioDatabase): # Z-Norm files format: filename client_id self.m_znorm_filename = znorm_filename if znorm_filename is not None else 'for_znorm.lst' - # decide, which scoring type we have: - if probes_filename is not None and scores_filename is None: - self.m_use_dense_probes = True - elif probes_filename is None and scores_filename is not None: - self.m_use_dense_probes = False - elif use_dense_probe_file_list is not None: - self.m_use_dense_probes = use_dense_probe_file_list - # Then direct path to a given protocol - elif os.path.isdir(os.path.join(self.get_base_directory(), self.m_dev_subdir)) or os.path.isfile( - os.path.join(self.get_base_directory(), self.m_world_filename)): - if os.path.exists(self.get_list_file('dev', 'for_probes')) and not os.path.exists( - self.get_list_file('dev', 'for_scores')): - self.m_use_dense_probes = True - elif not os.path.exists(self.get_list_file('dev', 'for_probes')) and os.path.exists( - self.get_list_file('dev', 'for_scores')): - self.m_use_dense_probes = False - else: - 
raise ValueError("Unable to determine, which way of probing should be used. Please specify.") - # Then path to a directory that contains several subdirectories (one for each protocol) - else: - # Look at subdirectories for each protocol - protocols = [p for p in os.listdir(self.get_base_directory()) if - os.path.isdir(os.path.join(self.get_base_directory(), p))] - if len(protocols) == 0: - raise ValueError( - "Unable to determine, which way of probing should be used (no protocol directories found). Please specify.") - list_use_dense_probes = [] - for p in protocols: - if os.path.exists(self.get_list_file('dev', 'for_probes', p)) and not os.path.exists( - self.get_list_file('dev', 'for_scores', p)): - use_dense_probes = True - elif not os.path.exists(self.get_list_file('dev', 'for_probes', p)) and os.path.exists( - self.get_list_file('dev', 'for_scores', p)): - use_dense_probes = False - else: - raise ValueError( - "Unable to determine, which way of probing should be used, looking at the protocol (directory) '%s'. Please specify." % p) - list_use_dense_probes.append(use_dense_probes) - if len(set(list_use_dense_probes)) == 1: - self.m_use_dense_probes = list_use_dense_probes[0] - else: - raise ValueError( - "Unable to determine, which way of probing should be used, since this is not consistent accross protocols. 
Please specify.") + self.m_use_dense_probe_file_list = use_dense_probe_file_list def _list_reader(self, protocol): - if protocol not in self.list_readers: - if protocol is not None: - protocol_dir = os.path.join(self.get_base_directory(), protocol) - if not os.path.isdir(protocol_dir): - raise ValueError("The directory %s for the given protocol '%s' does not exist" % (protocol_dir, protocol)) - self.list_readers[protocol] = ListReader(self.keep_read_lists_in_memory) + if protocol not in self.list_readers: + if protocol is not None: + protocol_dir = os.path.join(self.get_base_directory(), protocol) + if not os.path.isdir(protocol_dir): + raise ValueError("The directory %s for the given protocol '%s' does not exist" % (protocol_dir, protocol)) + self.list_readers[protocol] = ListReader(self.keep_read_lists_in_memory) - return self.list_readers[protocol] + return self.list_readers[protocol] def _make_bio(self, files): return [self.bio_file_class(client_id=f.client_id, path=f.path, file_id=f.id) for f in files] def all_files(self, groups=['dev']): - files = self.objects(groups, self.protocol, None, None, **self.all_files_options) + files = self.objects(groups, self.protocol, **self.all_files_options) # add all files that belong to the ZT-norm for group in groups: if group == 'world': continue if self.implements_zt(self.protocol, group): - files += self.tobjects(group, self.protocol, None) + files += self.tobjects(group, self.protocol) files += self.zobjects(group, self.protocol, **self.z_probe_options) return self.sort(self._make_bio(files)) - def groups(self, protocol=None): + def groups(self, protocol=None, add_world=True, add_subworld=True): """This function returns the list of groups for this database. protocol : str or ``None`` The protocol for which the groups should be retrieved. + add_world : bool + Add the world groups? + + add_subworld : bool + Add the sub-world groups? 
Only valid, when ``add_world=True`` + Returns: a list of groups """ @@ -267,23 +231,27 @@ class FileListBioDatabase(ZTBioDatabase): groups.append('dev') if os.path.isdir(os.path.join(self.get_base_directory(), protocol, self.m_eval_subdir)): groups.append('eval') - if os.path.isfile(os.path.join(self.get_base_directory(), protocol, self.m_world_filename)): - groups.append('world') - if os.path.isfile(os.path.join(self.get_base_directory(), protocol, self.m_optional_world_1_filename)): - groups.append('optional_world_1') - if os.path.isfile(os.path.join(self.get_base_directory(), protocol, self.m_optional_world_2_filename)): - groups.append('optional_world_2') + if add_world: + if os.path.isfile(os.path.join(self.get_base_directory(), protocol, self.m_world_filename)): + groups.append('world') + if add_world and add_subworld: + if os.path.isfile(os.path.join(self.get_base_directory(), protocol, self.m_optional_world_1_filename)): + groups.append('optional_world_1') + if os.path.isfile(os.path.join(self.get_base_directory(), protocol, self.m_optional_world_2_filename)): + groups.append('optional_world_2') else: if os.path.isdir(os.path.join(self.get_base_directory(), self.m_dev_subdir)): groups.append('dev') if os.path.isdir(os.path.join(self.get_base_directory(), self.m_eval_subdir)): groups.append('eval') - if os.path.isfile(os.path.join(self.get_base_directory(), self.m_world_filename)): - groups.append('world') - if os.path.isfile(os.path.join(self.get_base_directory(), self.m_optional_world_1_filename)): - groups.append('optional_world_1') - if os.path.isfile(os.path.join(self.get_base_directory(), self.m_optional_world_2_filename)): - groups.append('optional_world_2') + if add_world: + if os.path.isfile(os.path.join(self.get_base_directory(), self.m_world_filename)): + groups.append('world') + if add_world and add_subworld: + if os.path.isfile(os.path.join(self.get_base_directory(), self.m_optional_world_1_filename)): + groups.append('optional_world_1') + if 
os.path.isfile(os.path.join(self.get_base_directory(), self.m_optional_world_2_filename)): + groups.append('optional_world_2') return groups def implements_zt(self, protocol=None, groups=None): @@ -300,16 +268,36 @@ class FileListBioDatabase(ZTBioDatabase): Returns: ``True`` if the all file lists for ZT score normalization exist, otherwise ``False``. """ - groups = self.check_parameters_for_validity(groups, "group", ('dev', 'eval')) - protocol = protocol or self.protocol + groups = self.check_parameters_for_validity(groups, "group", self.groups(protocol, add_world=False)) + for group in groups: for t in ['for_tnorm', 'for_znorm']: - if not os.path.exists(self.get_list_file(group, t, protocol)): + if not os.path.exists(self._get_list_file(group, t, protocol)): return False # all files exist return True + def uses_dense_probe_file(self, protocol): + """Determines if a dense probe file list is used: returns the constructor setting if given, else checks which of the 'for_probes'/'for_scores' list files exist in all groups; raises ValueError if ambiguous.""" + # return, whatever was specified in constructor, if not None + if self.m_use_dense_probe_file_list is not None: + return self.m_use_dense_probe_file_list + + # check the existence of the files + probes = True + scores = True + for group in self.groups(protocol, add_world=False): + probes = probes and os.path.exists(self._get_list_file(group, type='for_probes', protocol=protocol)) + scores = scores and os.path.exists(self._get_list_file(group, type='for_scores', protocol=protocol)) + # decide, which score files are available + if probes and not scores: + return True + if not probes and scores: + return False + raise ValueError("Unable to determine, which way of probing should be used. Please specify.") + + def get_base_directory(self): """Returns the base directory where the filelists defining the database are located.""" @@ -322,7 +310,7 @@ class FileListBioDatabase(ZTBioDatabase): if not os.path.isdir(self.filelists_directory): raise RuntimeError('Invalid directory specified %s.'
% (self.filelists_directory)) - def get_list_file(self, group, type=None, protocol=None): + def _get_list_file(self, group, type=None, protocol=None): if protocol: base_directory = os.path.join(self.get_base_directory(), protocol) else: @@ -361,15 +349,13 @@ class FileListBioDatabase(ZTBioDatabase): Returns: The client id for the given model id, if found. """ - # compatibility reasons - groups = group - groups = self.check_parameters_for_validity(groups, "group", - ('dev', 'eval', 'world', 'optional_world_1', 'optional_world_2'), - default_parameters=('dev', 'eval', 'world')) - protocol = self.protocol + groups = self.check_parameters_for_validity(group, "group", + self.groups(protocol), + default_parameters=self.groups(protocol, add_subworld=False)) + for group in groups: - model_dict = self._list_reader(protocol).read_models(self.get_list_file(group, 'for_models', protocol), group, + model_dict = self._list_reader(protocol).read_models(self._get_list_file(group, 'for_models', protocol), group, 'for_models') if model_id in model_dict: return model_dict[model_id] @@ -394,12 +380,11 @@ class FileListBioDatabase(ZTBioDatabase): Returns: The client id for the given model id of a T-Norm model, if found. 
""" - groups = group - groups = self.check_parameters_for_validity(groups, "group", ('dev', 'eval')) - protocol = self.protocol + groups = self.check_parameters_for_validity(group, "group", self.groups(protocol, add_world=False)) + for group in groups: - model_dict = self._list_reader(protocol).read_models(self.get_list_file(group, 'for_tnorm', protocol), group, + model_dict = self._list_reader(protocol).read_models(self._get_list_file(group, 'for_tnorm', protocol), group, 'for_tnorm') if t_model_id in model_dict: return model_dict[t_model_id] @@ -464,7 +449,7 @@ class FileListBioDatabase(ZTBioDatabase): protocol = protocol or self.protocol # read all lists for all groups and extract the model ids for group in groups: - files = self._list_reader(protocol).read_list(self.get_list_file(group, type, protocol), group, type) + files = self._list_reader(protocol).read_list(self._get_list_file(group, type, protocol), group, type) for file in files: ids.add(file.client_id) return ids @@ -485,8 +470,8 @@ class FileListBioDatabase(ZTBioDatabase): protocol = protocol or self.protocol groups = self.check_parameters_for_validity(groups, "group", - ('dev', 'eval', 'world', 'optional_world_1', 'optional_world_2'), - default_parameters=('dev', 'eval', 'world')) + self.groups(protocol), + default_parameters=self.groups(protocol, add_subworld=False)) return self.__client_id_list__(groups, 'for_models', protocol) @@ -505,7 +490,7 @@ class FileListBioDatabase(ZTBioDatabase): """ protocol = protocol or self.protocol - groups = self.check_parameters_for_validity(groups, "group", ('dev', 'eval')) + groups = self.check_parameters_for_validity(groups, "group", self.groups(protocol, add_world=False)) return self.__client_id_list__(groups, 'for_tnorm', protocol) @@ -524,7 +509,7 @@ class FileListBioDatabase(ZTBioDatabase): """ protocol = protocol or self.protocol - groups = self.check_parameters_for_validity(groups, "group", ('dev', 'eval')) + groups = 
self.check_parameters_for_validity(groups, "group", self.groups(protocol, add_world=False)) return self.__client_id_list__(groups, 'for_znorm', protocol) @@ -533,7 +518,7 @@ class FileListBioDatabase(ZTBioDatabase): protocol = protocol or self.protocol # read all lists for all groups and extract the model ids for group in groups: - dict = self._list_reader(protocol).read_models(self.get_list_file(group, type, protocol), group, type) + dict = self._list_reader(protocol).read_models(self._get_list_file(group, type, protocol), group, type) ids.update(dict.keys()) return list(ids) @@ -551,10 +536,7 @@ class FileListBioDatabase(ZTBioDatabase): Returns: A list containing all the model ids which have the given properties. """ protocol = protocol or self.protocol - - groups = self.check_parameters_for_validity(groups, "group", - ('dev', 'eval', 'world', 'optional_world_1', 'optional_world_2'), - default_parameters=('dev', 'eval', 'world')) + groups = self.check_parameters_for_validity(groups, "group", self.groups(protocol=protocol)) return self.__model_id_list__(groups, 'for_models', protocol) @@ -572,8 +554,7 @@ class FileListBioDatabase(ZTBioDatabase): Returns: A list containing all the T-Norm model ids belonging to the given group. """ protocol = protocol or self.protocol - - groups = self.check_parameters_for_validity(groups, "group", ('dev', 'eval')) + groups = self.check_parameters_for_validity(groups, "group", self.groups(protocol, add_world=False)) return self.__model_id_list__(groups, 'for_tnorm', protocol) @@ -598,8 +579,7 @@ class FileListBioDatabase(ZTBioDatabase): groups : str or [str] or ``None`` One of the groups ("dev", "eval", "world", "optional_world_1", "optional_world_2") or a tuple with several of them. - If 'None' is given (this is the default), it is considered the same as a - tuple with all possible values. + If 'None' is given (this is the default), it is considered to be the existing subset of ``("world", "dev", "eval")``. 
classes : str or [str] or ``None`` The classes (types of accesses) to be retrieved ('client', 'impostor') @@ -611,13 +591,13 @@ class FileListBioDatabase(ZTBioDatabase): """ protocol = protocol or self.protocol - if self.m_use_dense_probes and classes is not None: + if self.uses_dense_probe_file(protocol) and classes is not None: raise ValueError("To be able to use the 'classes' keyword, please use the 'for_scores.lst' list file.") purposes = self.check_parameters_for_validity(purposes, "purpose", ('enroll', 'probe')) groups = self.check_parameters_for_validity(groups, "group", - ('dev', 'eval', 'world', 'optional_world_1', 'optional_world_2'), - default_parameters=('dev', 'eval', 'world')) + self.groups(protocol), + default_parameters=self.groups(protocol, add_subworld=False)) classes = self.check_parameters_for_validity(classes, "class", ('client', 'impostor')) if isinstance(model_ids, six.string_types): @@ -627,29 +607,26 @@ class FileListBioDatabase(ZTBioDatabase): lists = [] probe_lists = [] if 'world' in groups: - lists.append(self._list_reader(protocol).read_list(self.get_list_file('world', protocol=protocol), 'world')) + lists.append(self._list_reader(protocol).read_list(self._get_list_file('world', protocol=protocol), 'world')) if 'optional_world_1' in groups: - lists.append(self._list_reader(protocol).read_list(self.get_list_file('optional_world_1', protocol=protocol), + lists.append(self._list_reader(protocol).read_list(self._get_list_file('optional_world_1', protocol=protocol), 'optional_world_1')) if 'optional_world_2' in groups: - lists.append(self._list_reader(protocol).read_list(self.get_list_file('optional_world_2', protocol=protocol), + lists.append(self._list_reader(protocol).read_list(self._get_list_file('optional_world_2', protocol=protocol), 'optional_world_2')) for group in ('dev', 'eval'): if group in groups: if 'enroll' in purposes: lists.append( - self._list_reader(protocol).read_list(self.get_list_file(group, 'for_models', 
protocol=protocol), group, - 'for_models')) + self._list_reader(protocol).read_list(self._get_list_file(group, 'for_models', protocol=protocol), group, 'for_models')) if 'probe' in purposes: - if self.m_use_dense_probes: + if self.uses_dense_probe_file(protocol): probe_lists.append( - self._list_reader(protocol).read_list(self.get_list_file(group, 'for_probes', protocol=protocol), - group, 'for_probes')) + self._list_reader(protocol).read_list(self._get_list_file(group, 'for_probes', protocol=protocol), group, 'for_probes')) else: probe_lists.append( - self._list_reader(protocol).read_list(self.get_list_file(group, 'for_scores', protocol=protocol), - group, 'for_scores')) + self._list_reader(protocol).read_list(self._get_list_file(group, 'for_scores', protocol=protocol), group, 'for_scores')) # now, go through the lists and filter the elements @@ -668,7 +645,7 @@ class FileListBioDatabase(ZTBioDatabase): # probe files; filter by model id and by class for list in probe_lists: - if self.m_use_dense_probes: + if self.uses_dense_probe_file(protocol): # dense probing is used; do not filter over the model ids and not over the classes # -> just add all probe files for file in list: @@ -710,8 +687,7 @@ class FileListBioDatabase(ZTBioDatabase): Returns: A list of :py:class:`BioFile` objects considering all the filtering criteria. """ protocol = protocol or self.protocol - - groups = self.check_parameters_for_validity(groups, "group", ('dev', 'eval')) + groups = self.check_parameters_for_validity(groups, "group", self.groups(protocol, add_world=False)) if (isinstance(model_ids, six.string_types)): model_ids = (model_ids,) @@ -720,7 +696,7 @@ class FileListBioDatabase(ZTBioDatabase): # we assume that there is no duplicate file here... 
retval = [] for group in groups: - for file in self._list_reader(protocol).read_list(self.get_list_file(group, 'for_tnorm', protocol), group, + for file in self._list_reader(protocol).read_list(self._get_list_file(group, 'for_tnorm', protocol), group, 'for_tnorm'): if model_ids is None or file._model_id in model_ids: retval.append(file) @@ -742,14 +718,14 @@ class FileListBioDatabase(ZTBioDatabase): """ protocol = protocol or self.protocol - groups = self.check_parameters_for_validity(groups, "group", ('dev', 'eval')) + groups = self.check_parameters_for_validity(groups, "group", self.groups(protocol, add_world=False)) # iterate over the lists and extract the files # we assume that there is no duplicate file here... retval = [] for group in groups: retval.extend([file for file in - self._list_reader(protocol).read_list(self.get_list_file(group, 'for_znorm', protocol), group, + self._list_reader(protocol).read_list(self._get_list_file(group, 'for_znorm', protocol), group, 'for_znorm')]) return self._make_bio(retval) diff --git a/bob/bio/base/test/data/example_filelist2/dev/for_models.lst b/bob/bio/base/test/data/example_filelist2/dev/for_models.lst new file mode 100644 index 00000000..7bdbf751 --- /dev/null +++ b/bob/bio/base/test/data/example_filelist2/dev/for_models.lst @@ -0,0 +1,12 @@ +data/model3_session1_sample1 3 3 +data/model3_session1_sample2 3 3 +data/model3_session1_sample3 3 3 +data/model3_session2_sample1 3 3 +data/model4_session1_sample1 4 4 +data/model4_session1_sample2 4 4 +data/model4_session1_sample3 4 4 +data/model4_session2_sample1 4 4 +data/model5_session1_sample1 5 5 +data/model5_session1_sample2 5 5 +data/model5_session1_sample3 5 5 +data/model5_session2_sample1 5 5 diff --git a/bob/bio/base/test/data/example_filelist2/dev/for_scores.lst b/bob/bio/base/test/data/example_filelist2/dev/for_scores.lst new file mode 100644 index 00000000..340e7ec1 --- /dev/null +++ b/bob/bio/base/test/data/example_filelist2/dev/for_scores.lst @@ -0,0 +1,14 @@ 
+data/model3_session3_sample1 3 3 3 +data/model3_session3_sample2 3 3 3 +data/model3_session3_sample3 3 3 3 +data/model3_session4_sample1 3 3 3 +data/model4_session3_sample1 3 3 4 +data/model4_session3_sample2 3 3 4 +data/model4_session3_sample1 4 4 4 +data/model4_session3_sample2 4 4 4 +data/model4_session3_sample3 4 4 4 +data/model4_session4_sample1 4 4 4 +data/model3_session3_sample1 4 4 3 +data/model3_session3_sample2 4 4 3 +data/model5_session3_sample1 5 3 5 +data/model5_session3_sample1 5 5 5 diff --git a/bob/bio/base/test/test_filelist.py b/bob/bio/base/test/test_filelist.py index fcaa4627..405591a8 100644 --- a/bob/bio/base/test/test_filelist.py +++ b/bob/bio/base/test/test_filelist.py @@ -125,11 +125,19 @@ def test_query_protocol(): assert db.client_id_from_model_id('6', group=None) == '6' assert db.client_id_from_t_model_id('7', group=None) == '7' - nose.tools.assert_raises(ValueError, db.objects, protocol='non-existent') + + # check other protocols + assert len(db.objects(protocol='non-existent')) == 0 + + prot = 'example_filelist2' + assert len(db.model_ids_with_protocol(protocol=prot)) == 3 # 3 model ids for dev only + nose.tools.assert_raises(ValueError, db.model_ids_with_protocol, protocol=prot, groups='eval') # eval does not exist for this protocol + assert len(db.objects(protocol=prot, groups='dev', purposes='enroll')) == 12 + assert len(db.objects(protocol=prot, groups='dev', purposes='probe')) == 9 def test_query_dense(): - db = FileListBioDatabase(example_dir, 'test', probes_filename='for_probes.lst') + db = FileListBioDatabase(example_dir, 'test', use_dense_probe_file_list=True) assert len(db.objects(groups='world')) == 8 # 8 samples in the world set -- GitLab