From 87b41c41fa90ccccb6c5a6e886ce81f58fa3d776 Mon Sep 17 00:00:00 2001
From: Manuel Gunther <siebenkopf@googlemail.com>
Date: Tue, 4 Apr 2017 19:23:53 -0600
Subject: [PATCH] Implemented skipping of ZT files if not wanted

---
 bob/bio/base/database/database.py       | 20 +++++++++++++-------
 bob/bio/base/database/filelist/query.py |  6 ++++--
 bob/bio/base/tools/FileSelector.py      | 14 ++++++++------
 bob/bio/base/tools/command_line.py      |  1 +
 4 files changed, 26 insertions(+), 15 deletions(-)

diff --git a/bob/bio/base/database/database.py b/bob/bio/base/database/database.py
index 5bfb7978..2ab51564 100644
--- a/bob/bio/base/database/database.py
+++ b/bob/bio/base/database/database.py
@@ -370,7 +370,7 @@ class BioDatabase(six.with_metaclass(abc.ABCMeta, bob.db.base.Database)):
         """
         return sorted(self.model_ids_with_protocol(groups=groups, protocol=self.protocol))
 
-    def all_files(self, groups=None):
+    def all_files(self, groups=None, **kwargs):
         """all_files(groups=None) -> files
 
         Returns all files of the database, respecting the current protocol.
@@ -382,6 +382,8 @@ class BioDatabase(six.with_metaclass(abc.ABCMeta, bob.db.base.Database)):
           The groups to get the data for.
           If ``None``, data for all groups is returned.
 
+        kwargs: ignored
+
         **Returns:**
 
         files : [:py:class:`bob.bio.base.database.BioFile`]
@@ -640,7 +642,7 @@ class ZTBioDatabase(BioDatabase):
         """
         raise NotImplementedError("This function must be implemented in your derived class.")
 
-    def all_files(self, groups=['dev']):
+    def all_files(self, groups=['dev'], add_zt_files=True):
         """all_files(groups=None) -> files
 
         Returns all files of the database, including those for ZT norm, respecting the current protocol.
@@ -652,6 +654,9 @@ class ZTBioDatabase(BioDatabase):
           The groups to get the data for.
           If ``None``, data for all groups is returned.
 
+        add_zt_files: bool
+          If set (the default), files for ZT score normalization are added.
+
         **Returns:**
 
         files : [:py:class:`bob.bio.base.database.BioFile`]
@@ -660,11 +665,12 @@ class ZTBioDatabase(BioDatabase):
         files = self.objects(protocol=self.protocol, groups=groups, **self.all_files_options)
 
         # add all files that belong to the ZT-norm
-        for group in groups:
-            if group == 'world':
-                continue
-            files += self.tobjects(protocol=self.protocol, groups=group, model_ids=None)
-            files += self.zobjects(protocol=self.protocol, groups=group, **self.z_probe_options)
+        if add_zt_files:
+            for group in groups:
+                if group == 'world':
+                    continue
+                files += self.tobjects(protocol=self.protocol, groups=group, model_ids=None)
+                files += self.zobjects(protocol=self.protocol, groups=group, **self.z_probe_options)
         return self.sort(files)
 
     @abc.abstractmethod
diff --git a/bob/bio/base/database/filelist/query.py b/bob/bio/base/database/filelist/query.py
index dcaad932..3d3c4f11 100644
--- a/bob/bio/base/database/filelist/query.py
+++ b/bob/bio/base/database/filelist/query.py
@@ -198,13 +198,15 @@ class FileListBioDatabase(ZTBioDatabase):
     def _make_bio(self, files):
         return [self.bio_file_class(client_id=f.client_id, path=f.path, file_id=f.id) for f in files]
 
-    def all_files(self, groups=['dev']):
+    def all_files(self, groups=['dev'], add_zt_files=True):
         files = self.objects(groups, self.protocol, **self.all_files_options)
         # add all files that belong to the ZT-norm
         for group in groups:
             if group == 'world':
                 continue
-            if self.implements_zt(self.protocol, group):
+            if add_zt_files:
+                if not self.implements_zt(self.protocol, group):
+                    raise ValueError("ZT score files are requested, but no such files are defined in group %s for protocol %s", group, self.protocol)
                 files += self.tobjects(group, self.protocol)
                 files += self.zobjects(group, self.protocol, **self.z_probe_options)
         return self.sort(self._make_bio(files))
diff --git a/bob/bio/base/tools/FileSelector.py b/bob/bio/base/tools/FileSelector.py
index 0870ee71..3f1a9bae 100644
--- a/bob/bio/base/tools/FileSelector.py
+++ b/bob/bio/base/tools/FileSelector.py
@@ -68,7 +68,8 @@ class FileSelector(object):
     score_directories,
     zt_score_directories = None,
     default_extension = '.hdf5',
-    compressed_extension = ''
+    compressed_extension = '',
+    zt_norm = False
   ):
 
     """Initialize the file selector object with the current configuration."""
@@ -89,6 +90,7 @@ class FileSelector(object):
       'extracted'    : extracted_directory,
       'projected'    : projected_directory
     }
+    self.zt_norm = zt_norm
 
 
   def uses_probe_file_sets(self):
@@ -108,7 +110,7 @@ class FileSelector(object):
   ### List of files that will be used for all files
   def original_data_list(self, groups = None):
     """Returns the list of original ``BioFile`` objects that can be used for preprocessing."""
-    return self.database.all_files(groups=groups)
+    return self.database.all_files(groups=groups,add_zt_files=self.zt_norm)
 
   def original_directory_and_extension(self):
     """Returns the directory and extension of the original files."""
@@ -116,7 +118,7 @@ class FileSelector(object):
 
   def annotation_list(self, groups = None):
     """Returns the list of annotations objects."""
-    return self.database.all_files(groups=groups)
+    return self.database.all_files(groups=groups,add_zt_files=self.zt_norm)
 
   def get_annotations(self, annotation_file):
     """Returns the annotations of the given file."""
@@ -124,15 +126,15 @@ class FileSelector(object):
 
   def preprocessed_data_list(self, groups = None):
     """Returns the list of preprocessed data files."""
-    return self.get_paths(self.database.all_files(groups=groups), "preprocessed")
+    return self.get_paths(self.database.all_files(groups=groups,add_zt_files=self.zt_norm), "preprocessed")
 
   def feature_list(self, groups = None):
     """Returns the list of extracted feature files."""
-    return self.get_paths(self.database.all_files(groups=groups), "extracted")
+    return self.get_paths(self.database.all_files(groups=groups,add_zt_files=self.zt_norm), "extracted")
 
   def projected_list(self, groups = None):
     """Returns the list of projected feature files."""
-    return self.get_paths(self.database.all_files(groups=groups), "projected")
+    return self.get_paths(self.database.all_files(groups=groups,add_zt_files=self.zt_norm), "projected")
 
 
   ### Training lists
diff --git a/bob/bio/base/tools/command_line.py b/bob/bio/base/tools/command_line.py
index 5ddd7150..2cdc11f9 100644
--- a/bob/bio/base/tools/command_line.py
+++ b/bob/bio/base/tools/command_line.py
@@ -460,6 +460,7 @@ def initialize(parsers, command_line_parameters=None, skips=[]):
         zt_score_directories=[os.path.join(args.temp_directory, protocol, s) for s in args.zt_directories],
         compressed_extension='.tar.bz2' if args.write_compressed_score_files else '',
         default_extension='.hdf5',
+    zt_norm = args.zt_norm
     )
 
     return args
-- 
GitLab