From 87b41c41fa90ccccb6c5a6e886ce81f58fa3d776 Mon Sep 17 00:00:00 2001 From: Manuel Gunther <siebenkopf@googlemail.com> Date: Tue, 4 Apr 2017 19:23:53 -0600 Subject: [PATCH] Implemented skipping of ZT files if not wanted --- bob/bio/base/database/database.py | 20 +++++++++++++------- bob/bio/base/database/filelist/query.py | 6 ++++-- bob/bio/base/tools/FileSelector.py | 14 ++++++++------ bob/bio/base/tools/command_line.py | 1 + 4 files changed, 26 insertions(+), 15 deletions(-) diff --git a/bob/bio/base/database/database.py b/bob/bio/base/database/database.py index 5bfb7978..2ab51564 100644 --- a/bob/bio/base/database/database.py +++ b/bob/bio/base/database/database.py @@ -370,7 +370,7 @@ class BioDatabase(six.with_metaclass(abc.ABCMeta, bob.db.base.Database)): """ return sorted(self.model_ids_with_protocol(groups=groups, protocol=self.protocol)) - def all_files(self, groups=None): + def all_files(self, groups=None, **kwargs): """all_files(groups=None) -> files Returns all files of the database, respecting the current protocol. @@ -382,6 +382,8 @@ class BioDatabase(six.with_metaclass(abc.ABCMeta, bob.db.base.Database)): The groups to get the data for. If ``None``, data for all groups is returned. + kwargs: ignored + **Returns:** files : [:py:class:`bob.bio.base.database.BioFile`] @@ -640,7 +642,7 @@ class ZTBioDatabase(BioDatabase): """ raise NotImplementedError("This function must be implemented in your derived class.") - def all_files(self, groups=['dev']): + def all_files(self, groups=['dev'], add_zt_files=True): """all_files(groups=None) -> files Returns all files of the database, including those for ZT norm, respecting the current protocol. @@ -652,6 +654,9 @@ class ZTBioDatabase(BioDatabase): The groups to get the data for. If ``None``, data for all groups is returned. + add_zt_files: bool + If set (the default), files for ZT score normalization are added. + **Returns:** files : [:py:class:`bob.bio.base.database.BioFile`] @@ -660,11 +665,12 @@ class ZTBioDatabase(BioDatabase): files = self.objects(protocol=self.protocol, groups=groups, **self.all_files_options) # add all files that belong to the ZT-norm - for group in groups: - if group == 'world': - continue - files += self.tobjects(protocol=self.protocol, groups=group, model_ids=None) - files += self.zobjects(protocol=self.protocol, groups=group, **self.z_probe_options) + if add_zt_files: + for group in groups: + if group == 'world': + continue + files += self.tobjects(protocol=self.protocol, groups=group, model_ids=None) + files += self.zobjects(protocol=self.protocol, groups=group, **self.z_probe_options) return self.sort(files) @abc.abstractmethod diff --git a/bob/bio/base/database/filelist/query.py b/bob/bio/base/database/filelist/query.py index dcaad932..3d3c4f11 100644 --- a/bob/bio/base/database/filelist/query.py +++ b/bob/bio/base/database/filelist/query.py @@ -198,13 +198,15 @@ class FileListBioDatabase(ZTBioDatabase): def _make_bio(self, files): return [self.bio_file_class(client_id=f.client_id, path=f.path, file_id=f.id) for f in files] - def all_files(self, groups=['dev']): + def all_files(self, groups=['dev'], add_zt_files=True): files = self.objects(groups, self.protocol, **self.all_files_options) # add all files that belong to the ZT-norm for group in groups: if group == 'world': continue - if self.implements_zt(self.protocol, group): + if add_zt_files: + if not self.implements_zt(self.protocol, group): + raise ValueError("ZT score files are requested, but no such files are defined in group %s for protocol %s", group, self.protocol) files += self.tobjects(group, self.protocol) files += self.zobjects(group, self.protocol, **self.z_probe_options) return self.sort(self._make_bio(files)) diff --git a/bob/bio/base/tools/FileSelector.py b/bob/bio/base/tools/FileSelector.py index 0870ee71..3f1a9bae 100644 --- a/bob/bio/base/tools/FileSelector.py +++ b/bob/bio/base/tools/FileSelector.py @@ -68,7 +68,8 @@ class FileSelector(object): score_directories, zt_score_directories = None, default_extension = '.hdf5', - compressed_extension = '' + compressed_extension = '', + zt_norm = False ): """Initialize the file selector object with the current configuration.""" @@ -89,6 +90,7 @@ class FileSelector(object): 'extracted' : extracted_directory, 'projected' : projected_directory } + self.zt_norm = zt_norm def uses_probe_file_sets(self): @@ -108,7 +110,7 @@ class FileSelector(object): ### List of files that will be used for all files def original_data_list(self, groups = None): """Returns the list of original ``BioFile`` objects that can be used for preprocessing.""" - return self.database.all_files(groups=groups) + return self.database.all_files(groups=groups,add_zt_files=self.zt_norm) def original_directory_and_extension(self): """Returns the directory and extension of the original files.""" @@ -116,7 +118,7 @@ class FileSelector(object): def annotation_list(self, groups = None): """Returns the list of annotations objects.""" - return self.database.all_files(groups=groups) + return self.database.all_files(groups=groups,add_zt_files=self.zt_norm) def get_annotations(self, annotation_file): """Returns the annotations of the given file.""" @@ -124,15 +126,15 @@ class FileSelector(object): def preprocessed_data_list(self, groups = None): """Returns the list of preprocessed data files.""" - return self.get_paths(self.database.all_files(groups=groups), "preprocessed") + return self.get_paths(self.database.all_files(groups=groups,add_zt_files=self.zt_norm), "preprocessed") def feature_list(self, groups = None): """Returns the list of extracted feature files.""" - return self.get_paths(self.database.all_files(groups=groups), "extracted") + return self.get_paths(self.database.all_files(groups=groups,add_zt_files=self.zt_norm), "extracted") def projected_list(self, groups = None): """Returns the list of projected feature files.""" - return self.get_paths(self.database.all_files(groups=groups), "projected") + return self.get_paths(self.database.all_files(groups=groups,add_zt_files=self.zt_norm), "projected") ### Training lists diff --git a/bob/bio/base/tools/command_line.py b/bob/bio/base/tools/command_line.py index 5ddd7150..2cdc11f9 100644 --- a/bob/bio/base/tools/command_line.py +++ b/bob/bio/base/tools/command_line.py @@ -460,6 +460,7 @@ def initialize(parsers, command_line_parameters=None, skips=[]): zt_score_directories=[os.path.join(args.temp_directory, protocol, s) for s in args.zt_directories], compressed_extension='.tar.bz2' if args.write_compressed_score_files else '', default_extension='.hdf5', + zt_norm = args.zt_norm ) return args -- GitLab