diff --git a/bob/bio/base/database/database.py b/bob/bio/base/database/database.py index 01ec10908db4fbf00e70c88ab02304995e1de111..cf9ad3664b04de305bec24de975dd23413f21e85 100644 --- a/bob/bio/base/database/database.py +++ b/bob/bio/base/database/database.py @@ -367,7 +367,7 @@ class BioDatabase(six.with_metaclass(abc.ABCMeta, bob.db.base.Database)): """ return sorted(self.model_ids_with_protocol(groups=groups, protocol=self.protocol)) - def all_files(self, groups=None): + def all_files(self, groups=None, **kwargs): """all_files(groups=None) -> files Returns all files of the database, respecting the current protocol. @@ -379,6 +379,8 @@ class BioDatabase(six.with_metaclass(abc.ABCMeta, bob.db.base.Database)): The groups to get the data for. If ``None``, data for all groups is returned. + kwargs: ignored + **Returns:** files : [:py:class:`bob.bio.base.database.BioFile`] @@ -637,7 +639,7 @@ class ZTBioDatabase(BioDatabase): """ raise NotImplementedError("This function must be implemented in your derived class.") - def all_files(self, groups=['dev']): + def all_files(self, groups=['dev'], add_zt_files=True): """all_files(groups=None) -> files Returns all files of the database, including those for ZT norm, respecting the current protocol. @@ -649,6 +651,9 @@ class ZTBioDatabase(BioDatabase): The groups to get the data for. If ``None``, data for all groups is returned. + add_zt_files: bool + If set (the default), files for ZT score normalization are added. + **Returns:** files : [:py:class:`bob.bio.base.database.BioFile`] @@ -657,11 +662,12 @@ class ZTBioDatabase(BioDatabase): files = self.objects(protocol=self.protocol, groups=groups, **self.all_files_options) # add all files that belong to the ZT-norm - for group in groups: - if group == 'world': - continue - files += self.tobjects(protocol=self.protocol, groups=group, model_ids=None) - files += self.zobjects(protocol=self.protocol, groups=group, **self.z_probe_options) + if add_zt_files: + for group in groups: + if group == 'world': + continue + files += self.tobjects(protocol=self.protocol, groups=group, model_ids=None) + files += self.zobjects(protocol=self.protocol, groups=group, **self.z_probe_options) return self.sort(files) @abc.abstractmethod diff --git a/bob/bio/base/database/filelist/query.py b/bob/bio/base/database/filelist/query.py index 2447c71aa213c9da463aa365d895aae3bbabcd7f..8085b0d40eda2ea395d38eacb61f0e962ce254f3 100644 --- a/bob/bio/base/database/filelist/query.py +++ b/bob/bio/base/database/filelist/query.py @@ -198,13 +198,15 @@ class FileListBioDatabase(ZTBioDatabase): def _make_bio(self, files): return [self.bio_file_class(client_id=f.client_id, path=f.path, file_id=f.id) for f in files] - def all_files(self, groups=['dev']): + def all_files(self, groups=['dev'], add_zt_files=True): files = self.objects(groups, self.protocol, **self.all_files_options) # add all files that belong to the ZT-norm for group in groups: if group == 'world': continue - if self.implements_zt(self.protocol, group): + if add_zt_files: + if not self.implements_zt(self.protocol, group): + raise ValueError("ZT score files are requested, but no such files are defined in group %s for protocol %s", group, self.protocol) files += self.tobjects(group, self.protocol) files += self.zobjects(group, self.protocol, **self.z_probe_options) return self.sort(self._make_bio(files)) diff --git a/bob/bio/base/tools/FileSelector.py b/bob/bio/base/tools/FileSelector.py index 0870ee71ff27160b37251a3e17f7b38d5e60f9f7..3f1a9baead720035be523c5970bc99251cf386a0 100644 --- a/bob/bio/base/tools/FileSelector.py +++ b/bob/bio/base/tools/FileSelector.py @@ -68,7 +68,8 @@ class FileSelector(object): score_directories, zt_score_directories = None, default_extension = '.hdf5', - compressed_extension = '' + compressed_extension = '', + zt_norm = False ): """Initialize the file selector object with the current configuration.""" @@ -89,6 +90,7 @@ class FileSelector(object): 'extracted' : extracted_directory, 'projected' : projected_directory } + self.zt_norm = zt_norm def uses_probe_file_sets(self): @@ -108,7 +110,7 @@ class FileSelector(object): ### List of files that will be used for all files def original_data_list(self, groups = None): """Returns the list of original ``BioFile`` objects that can be used for preprocessing.""" - return self.database.all_files(groups=groups) + return self.database.all_files(groups=groups,add_zt_files=self.zt_norm) def original_directory_and_extension(self): """Returns the directory and extension of the original files.""" @@ -116,7 +118,7 @@ class FileSelector(object): def annotation_list(self, groups = None): """Returns the list of annotations objects.""" - return self.database.all_files(groups=groups) + return self.database.all_files(groups=groups,add_zt_files=self.zt_norm) def get_annotations(self, annotation_file): """Returns the annotations of the given file.""" @@ -124,15 +126,15 @@ class FileSelector(object): def preprocessed_data_list(self, groups = None): """Returns the list of preprocessed data files.""" - return self.get_paths(self.database.all_files(groups=groups), "preprocessed") + return self.get_paths(self.database.all_files(groups=groups,add_zt_files=self.zt_norm), "preprocessed") def feature_list(self, groups = None): """Returns the list of extracted feature files.""" - return self.get_paths(self.database.all_files(groups=groups), "extracted") + return self.get_paths(self.database.all_files(groups=groups,add_zt_files=self.zt_norm), "extracted") def projected_list(self, groups = None): """Returns the list of projected feature files.""" - return self.get_paths(self.database.all_files(groups=groups), "projected") + return self.get_paths(self.database.all_files(groups=groups,add_zt_files=self.zt_norm), "projected") ### Training lists diff --git a/bob/bio/base/tools/command_line.py b/bob/bio/base/tools/command_line.py index cf8010c159094aed99a1de3d29e0bc2c44e64d3c..0f42c5790e505c0eaaa79517385b0fd070730d4b 100644 --- a/bob/bio/base/tools/command_line.py +++ b/bob/bio/base/tools/command_line.py @@ -364,6 +364,7 @@ def initialize(parsers, command_line_parameters = None, skips = []): zt_score_directories = [os.path.join(args.temp_directory, protocol, s) for s in args.zt_directories], compressed_extension = '.tar.bz2' if args.write_compressed_score_files else '', default_extension = '.hdf5', + zt_norm = args.zt_norm ) return args