diff --git a/bob/bio/base/database/database.py b/bob/bio/base/database/database.py index 5bfb7978bacf2a0c9f4daf3b66065be9bd1bfd51..2ab5156477c2803242ece20cb1bb8ad27640eabc 100644 --- a/bob/bio/base/database/database.py +++ b/bob/bio/base/database/database.py @@ -370,7 +370,7 @@ class BioDatabase(six.with_metaclass(abc.ABCMeta, bob.db.base.Database)): """ return sorted(self.model_ids_with_protocol(groups=groups, protocol=self.protocol)) - def all_files(self, groups=None): + def all_files(self, groups=None, **kwargs): """all_files(groups=None) -> files Returns all files of the database, respecting the current protocol. @@ -382,6 +382,8 @@ class BioDatabase(six.with_metaclass(abc.ABCMeta, bob.db.base.Database)): The groups to get the data for. If ``None``, data for all groups is returned. + kwargs: ignored + **Returns:** files : [:py:class:`bob.bio.base.database.BioFile`] @@ -640,7 +642,7 @@ class ZTBioDatabase(BioDatabase): """ raise NotImplementedError("This function must be implemented in your derived class.") - def all_files(self, groups=['dev']): + def all_files(self, groups=['dev'], add_zt_files=True): """all_files(groups=None) -> files Returns all files of the database, including those for ZT norm, respecting the current protocol. @@ -652,6 +654,9 @@ class ZTBioDatabase(BioDatabase): The groups to get the data for. If ``None``, data for all groups is returned. + add_zt_files: bool + If set (the default), files for ZT score normalization are added. + **Returns:** files : [:py:class:`bob.bio.base.database.BioFile`] @@ -660,11 +665,12 @@ class ZTBioDatabase(BioDatabase): files = self.objects(protocol=self.protocol, groups=groups, **self.all_files_options) # add all files that belong to the ZT-norm - for group in groups: - if group == 'world': - continue - files += self.tobjects(protocol=self.protocol, groups=group, model_ids=None) - files += self.zobjects(protocol=self.protocol, groups=group, **self.z_probe_options) + if add_zt_files: + for group in groups: + if group == 'world': + continue + files += self.tobjects(protocol=self.protocol, groups=group, model_ids=None) + files += self.zobjects(protocol=self.protocol, groups=group, **self.z_probe_options) return self.sort(files) @abc.abstractmethod diff --git a/bob/bio/base/database/filelist/query.py b/bob/bio/base/database/filelist/query.py index dcaad9320fa0e0acca892cf75fb20a096de6b4c6..3d3c4f115f29f1c0fd32192ebe4cea30dd6602bc 100644 --- a/bob/bio/base/database/filelist/query.py +++ b/bob/bio/base/database/filelist/query.py @@ -198,13 +198,15 @@ class FileListBioDatabase(ZTBioDatabase): def _make_bio(self, files): return [self.bio_file_class(client_id=f.client_id, path=f.path, file_id=f.id) for f in files] - def all_files(self, groups=['dev']): + def all_files(self, groups=['dev'], add_zt_files=True): files = self.objects(groups, self.protocol, **self.all_files_options) # add all files that belong to the ZT-norm for group in groups: if group == 'world': continue - if self.implements_zt(self.protocol, group): + if add_zt_files: + if not self.implements_zt(self.protocol, group): + raise ValueError("ZT score files are requested, but no such files are defined in group %s for protocol %s", group, self.protocol) files += self.tobjects(group, self.protocol) files += self.zobjects(group, self.protocol, **self.z_probe_options) return self.sort(self._make_bio(files)) diff --git a/bob/bio/base/test/test_scripts.py b/bob/bio/base/test/test_scripts.py index 3d3b6e14f4cff46489b16bfe92e0b491bfc6515e..7b647065b95b4f29f7982e502649992cc2f0a809 100644 --- a/bob/bio/base/test/test_scripts.py +++ b/bob/bio/base/test/test_scripts.py @@ -21,24 +21,26 @@ regenerate_reference = False dummy_dir = pkg_resources.resource_filename('bob.bio.base', 'test/dummy') data_dir = pkg_resources.resource_filename('bob.bio.base', 'test/data') -def _verify(parameters, test_dir, sub_dir, ref_modifier="", score_modifier=('scores',''), counts=3): +def _verify(parameters, test_dir, sub_dir, ref_modifier="", score_modifier=('scores',''), counts=3, check_zt=True): from bob.bio.base.script.verify import main try: main(parameters) + Range = (0,1) if check_zt else (0,) + # assert that the score file exists score_files = [os.path.join(test_dir, sub_dir, 'Default', norm, '%s-dev%s'%score_modifier) for norm in ('nonorm', 'ztnorm')] - assert os.path.exists(score_files[0]), "Score file %s does not exist" % score_files[0] - assert os.path.exists(score_files[1]), "Score file %s does not exist" % score_files[1] + for i in Range: + assert os.path.exists(score_files[i]), "Score file %s does not exist" % score_files[i] # also assert that the scores are still the same -- though they have no real meaning reference_files = [os.path.join(data_dir, 'scores-%s%s-dev'%(norm, ref_modifier)) for norm in ('nonorm', 'ztnorm')] if regenerate_reference: - for i in (0,1): + for i in Range: shutil.copy(score_files[i], reference_files[i]) - for i in (0,1): + for i in Range: d = [] # read reference and new data for score_file in (score_files[i], reference_files[i]): @@ -95,6 +97,23 @@ def test_verify_algorithm_noprojection(): _verify(parameters, test_dir, 'algorithm_noprojection') +def test_verify_no_ztnorm(): + test_dir = tempfile.mkdtemp(prefix='bobtest_') + # define dummy parameters + parameters = [ + '-d', os.path.join(dummy_dir, 'database.py'), + '-p', os.path.join(dummy_dir, 'preprocessor.py'), + '-e', os.path.join(dummy_dir, 'extractor.py'), + '-a', os.path.join(dummy_dir, 'algorithm_noprojection.py'), + '-vs', 'test_nozt', + '--temp-directory', test_dir, + '--result-directory', test_dir + ] + + _verify(parameters, test_dir, 'test_nozt', check_zt=False) + + + def test_verify_resources(): test_dir = tempfile.mkdtemp(prefix='bobtest_') # define dummy parameters diff --git a/bob/bio/base/tools/FileSelector.py b/bob/bio/base/tools/FileSelector.py index 0870ee71ff27160b37251a3e17f7b38d5e60f9f7..3f1a9baead720035be523c5970bc99251cf386a0 100644 --- a/bob/bio/base/tools/FileSelector.py +++ b/bob/bio/base/tools/FileSelector.py @@ -68,7 +68,8 @@ class FileSelector(object): score_directories, zt_score_directories = None, default_extension = '.hdf5', - compressed_extension = '' + compressed_extension = '', + zt_norm = False ): """Initialize the file selector object with the current configuration.""" @@ -89,6 +90,7 @@ class FileSelector(object): 'extracted' : extracted_directory, 'projected' : projected_directory } + self.zt_norm = zt_norm def uses_probe_file_sets(self): @@ -108,7 +110,7 @@ class FileSelector(object): ### List of files that will be used for all files def original_data_list(self, groups = None): """Returns the list of original ``BioFile`` objects that can be used for preprocessing.""" - return self.database.all_files(groups=groups) + return self.database.all_files(groups=groups,add_zt_files=self.zt_norm) def original_directory_and_extension(self): """Returns the directory and extension of the original files.""" @@ -116,7 +118,7 @@ class FileSelector(object): def annotation_list(self, groups = None): """Returns the list of annotations objects.""" - return self.database.all_files(groups=groups) + return self.database.all_files(groups=groups,add_zt_files=self.zt_norm) def get_annotations(self, annotation_file): """Returns the annotations of the given file.""" @@ -124,15 +126,15 @@ class FileSelector(object): def preprocessed_data_list(self, groups = None): """Returns the list of preprocessed data files.""" - return self.get_paths(self.database.all_files(groups=groups), "preprocessed") + return self.get_paths(self.database.all_files(groups=groups,add_zt_files=self.zt_norm), "preprocessed") def feature_list(self, groups = None): """Returns the list of extracted feature files.""" - return self.get_paths(self.database.all_files(groups=groups), "extracted") + return self.get_paths(self.database.all_files(groups=groups,add_zt_files=self.zt_norm), "extracted") def projected_list(self, groups = None): """Returns the list of projected feature files.""" - return self.get_paths(self.database.all_files(groups=groups), "projected") + return self.get_paths(self.database.all_files(groups=groups,add_zt_files=self.zt_norm), "projected") ### Training lists diff --git a/bob/bio/base/tools/command_line.py b/bob/bio/base/tools/command_line.py index eb2dc4312f75fee8959d3efbadacec5a31a1b5fa..431947f1b97c0ca638a528e8764d042b8892ea66 100644 --- a/bob/bio/base/tools/command_line.py +++ b/bob/bio/base/tools/command_line.py @@ -460,6 +460,7 @@ def initialize(parsers, command_line_parameters=None, skips=[]): zt_score_directories=[os.path.join(args.temp_directory, protocol, s) for s in args.zt_directories], compressed_extension='.tar.bz2' if args.write_compressed_score_files else '', default_extension='.hdf5', + zt_norm = args.zt_norm ) return args