From 036e437585a39c8e0df502d3a57de5958b7ff06d Mon Sep 17 00:00:00 2001
From: Amir Mohammadi <183.amir@gmail.com>
Date: Mon, 3 Apr 2017 17:17:11 +0200
Subject: [PATCH] Use preprocessing and extraction logic from bob.bio.base

---
 bob/pad/base/test/dummy/extractor.py |  3 +
 bob/pad/base/tools/FileSelector.py   | 30 +++++----
 bob/pad/base/tools/algorithm.py      |  2 +-
 bob/pad/base/tools/extractor.py      | 95 +++++++++++++++-------------
 bob/pad/base/tools/preprocessor.py   | 85 +++++++++++++------------
 5 files changed, 119 insertions(+), 96 deletions(-)
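
In short, the patch drops the local copies of ``read_features`` and
``read_preprocessed_data`` from bob.pad.base and delegates to the bob.bio.base
helpers it imports in the hunks below:

    from bob.bio.base.tools.extractor import read_features
    from bob.bio.base.tools import read_preprocessed_data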

diff --git a/bob/pad/base/test/dummy/extractor.py b/bob/pad/base/test/dummy/extractor.py
index 8e65a74..39959a5 100644
--- a/bob/pad/base/test/dummy/extractor.py
+++ b/bob/pad/base/test/dummy/extractor.py
@@ -18,4 +18,7 @@ class DummyExtractor(Extractor):
         assert (data in _data)
         return data + 1.0
 
+    def train(self, training_data, extractor_file):
+        pass
+
 extractor = DummyExtractor()
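
Note on the hunk above: the shared training path now calls
``extractor.train(train_data, fs.extractor_file)`` whenever training is
requested, so even the dummy needs the hook. A hedged sketch of what a
genuinely trainable extractor could look like (hypothetical class, not part
of the patch; ``requires_training`` as documented by bob.bio.base):

    import numpy
    from bob.bio.base.extractor import Extractor

    class TrainableDummyExtractor(Extractor):
        """Hypothetical extractor that fits a trivial model."""

        def __init__(self):
            # requires_training=True makes train_extractor() call train()
            super(TrainableDummyExtractor, self).__init__(requires_training=True)

        def train(self, training_data, extractor_file):
            # the "model" is just the global mean of all training samples
            mean = numpy.mean([numpy.mean(d) for d in training_data])
            numpy.savetxt(extractor_file, [mean])

        def __call__(self, data):
            return data + 1.0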
diff --git a/bob/pad/base/tools/FileSelector.py b/bob/pad/base/tools/FileSelector.py
index 201877f..d011e78 100644
--- a/bob/pad/base/tools/FileSelector.py
+++ b/bob/pad/base/tools/FileSelector.py
@@ -99,17 +99,22 @@ class FileSelector(object):
             return [realpaths, attackpaths]
 
     # List of files that will be used for all files
-    def original_data_list(self, groups=None):
-        """Returns the the joint list of original (real and attack) file names."""
-        return self.database.original_file_names(self.database.all_files(groups=groups))
-
-    def original_data_list_files(self, groups=None):
-        """Returns the joint list of original (real and attack) data files that can be used for preprocessing."""
+    def original_data_list(self, groups=None):
+        """Returns the list of original ``PadFile`` objects that can be used for preprocessing."""
         files = self.database.all_files(groups=groups)
         if len(files) != 2:
             fileset = files
         else:
             fileset = files[0]+files[1]
+        return fileset
+
+    def original_directory_and_extension(self):
+        """Returns the directory and extension of the original files."""
+        return self.database.original_directory, self.database.original_extension
+
+    def original_data_list_files(self, groups=None):
+        """Returns the joint list of original (real and attack) data files that can be used for preprocessing."""
+        fileset = self.original_data_list(groups=groups)
         return fileset, self.database.original_directory, self.database.original_extension
 
     def preprocessed_data_list(self, groups=None):
@@ -125,12 +130,15 @@ class FileSelector(object):
         return self.get_paths(self.database.all_files(groups=groups), "projected")
 
     # Training lists
-    def training_list(self, directory_type, step):
-        """Returns the tuple of lists (real, attacks) of features that should be used for projector training.
-        The directory_type might be any of 'preprocessed', 'extracted', or 'projected'.
-        The step might by any of 'train_extractor', 'train_projector', or 'train_enroller'.
+    def training_list(self, directory_type, step, combined=False):
+        """
+        Returns either a pair of lists (real, attacks) or, when ``combined``
+        is True, a single flat list of all real and attack features that
+        should be used for projector training. The directory_type might be
+        any of 'preprocessed', 'extracted', or 'projected'. The step might be
+        any of 'train_extractor', 'train_projector', or 'train_enroller'.
         """
-        return self.get_paths(self.database.training_files(step), directory_type, False)
+        return self.get_paths(self.database.training_files(step), directory_type, combined)
 
     def toscore_objects(self, group):
         """Returns the File objects used to compute the raw scores."""
diff --git a/bob/pad/base/tools/algorithm.py b/bob/pad/base/tools/algorithm.py
index 0bd4cc3..6a52b9b 100644
--- a/bob/pad/base/tools/algorithm.py
+++ b/bob/pad/base/tools/algorithm.py
@@ -52,7 +52,7 @@ def train_projector(algorithm, extractor, allow_missing_files=False, force=False
         # train projector
         logger.info("- Projection: loading training data")
         train_files = fs.training_list('extracted', 'train_projector')
-        train_features = read_features(train_files, extractor)
+        train_features = read_features(train_files, extractor, split_by_client=True, allow_missing_files=allow_missing_files)
         logger.info("- Projection: training projector '%s' using %d training files: ", fs.projector_file,
                     len(train_files))
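
With ``split_by_client=True``, the bob.bio.base reader keeps the input
grouping, so ``train_features`` above should mirror the ``[real, attacks]``
shape of ``train_files`` (a hedged reading of the helper's documented
behavior):

    real_feats, attack_feats = train_features
    assert len(real_feats) == len(train_files[0])
    assert len(attack_feats) == len(train_files[1])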
 
diff --git a/bob/pad/base/tools/extractor.py b/bob/pad/base/tools/extractor.py
index f67783a..b48872e 100644
--- a/bob/pad/base/tools/extractor.py
+++ b/bob/pad/base/tools/extractor.py
@@ -16,6 +16,7 @@ logger = logging.getLogger("bob.pad.base")
 from .FileSelector import FileSelector
 from bob.bio.base import utils
 from .preprocessor import read_preprocessed_data
+from bob.bio.base.tools.extractor import read_features
 
 
 def train_extractor(extractor, preprocessor, allow_missing_files=False, force=False):
@@ -41,7 +42,6 @@ def train_extractor(extractor, preprocessor, allow_missing_files=False, force=Fa
     force : bool
       If given, the extractor file is regenerated, even if it already exists.
     """
-
     if not extractor.requires_training:
         logger.warn(
             "The train_extractor function should not have been called, since the extractor does not need training.")
@@ -52,17 +52,25 @@ def train_extractor(extractor, preprocessor, allow_missing_files=False, force=Fa
     # the file to write
     if utils.check_file(fs.extractor_file, force,
                         extractor.min_extractor_file_size):
-        logger.info("- Extraction: extractor '%s' already exists.", fs.extractor_file)
+        logger.info("- Extraction: extractor '%s' already exists.",
+                    fs.extractor_file)
     else:
         bob.io.base.create_directories_safe(os.path.dirname(fs.extractor_file))
         # read training files
-        train_files = fs.training_list('preprocessed', 'train_extractor')
-        train_data = read_preprocessed_data(train_files, preprocessor)
-        logger.info("- Extraction: training extractor '%s' using %d training files:", fs.extractor_file,
-                    len(train_files))
+        train_files = fs.training_list(
+            'preprocessed', 'train_extractor', combined=not extractor.split_training_data_by_client)
+        train_data = read_preprocessed_data(
+            train_files, preprocessor, extractor.split_training_data_by_client, allow_missing_files)
+        if extractor.split_training_data_by_client:
+            logger.info("- Extraction: training extractor '%s' using %d classes:",
+                        fs.extractor_file, len(train_files))
+        else:
+            logger.info("- Extraction: training extractor '%s' using %d training files:",
+                        fs.extractor_file, len(train_files))
         # train model
         extractor.train(train_data, fs.extractor_file)
 
+
 def extract(extractor, preprocessor, groups=None, indices=None, allow_missing_files=False, force=False):
     """Extracts features from the preprocessed data using the given extractor.
 
@@ -87,6 +95,9 @@ def extract(extractor, preprocessor, groups=None, indices=None, allow_missing_fi
       If specified, only the features for the given index range ``range(begin, end)`` should be extracted.
       This is usually given, when parallel threads are executed.
 
+    allow_missing_files : bool
+      If set to ``True``, preprocessed data files that are not found are silently ignored.
+
     force : bool
       If given, files are regenerated, even if they already exist.
     """
@@ -97,7 +108,7 @@ def extract(extractor, preprocessor, groups=None, indices=None, allow_missing_fi
     feature_files = fs.feature_list(groups=groups)
 
     # select a subset of indices to iterate
-    if indices != None:
+    if indices is not None:
         index_range = range(indices[0], indices[1])
         logger.info("- Extraction: splitting of index range %s" % str(indices))
     else:
@@ -106,44 +117,42 @@ def extract(extractor, preprocessor, groups=None, indices=None, allow_missing_fi
     logger.info("- Extraction: extracting %d features from directory '%s' to directory '%s'", len(index_range),
                 fs.directories['preprocessed'], fs.directories['extracted'])
     for i in index_range:
-        data_file = str(data_files[i])
-        feature_file = str(feature_files[i])
-
-        if not utils.check_file(feature_file, force, 1000):
+        data_file = data_files[i]
+        feature_file = feature_files[i]
+
+        if not os.path.exists(data_file) and preprocessor.writes_data:
+            if allow_missing_files:
+                logger.debug(
+                    "... Cannot find preprocessed data file %s; skipping", data_file)
+                continue
+            else:
+                logger.error(
+                    "Cannot find preprocessed data file %s", data_file)
+
+        if not utils.check_file(feature_file, force,
+                                extractor.min_feature_file_size):
+            logger.debug(
+                "... Extracting features for data file '%s'", data_file)
+            # create the output directory before reading the data file (this
+            # is sometimes required when relative directories, especially
+            # ones containing '..', are specified)
+            bob.io.base.create_directories_safe(os.path.dirname(feature_file))
             # load data
             data = preprocessor.read_data(data_file)
             # extract feature
-            try:
-                logger.info("- Extraction: extracting from file: %s", data_file)
-                feature = extractor(data)
-            except ValueError:
-                logger.warn("WARNING: empty data in file %s", data_file)
-                feature = 0
-            # write feature
-            if feature is not None:
-                bob.io.base.create_directories_safe(os.path.dirname(feature_file))
-                extractor.write_feature(feature, feature_file)
-
-
-def read_features(file_names, extractor):
-    """read_features(file_names, extractor) -> extracted
+            feature = extractor(data)
 
-    Reads the extracted features from ``file_names`` using the given ``extractor``.
+            if feature is None:
+                if allow_missing_files:
+                    logger.debug(
+                        "... Feature extraction for data file %s failed; skipping", data_file)
+                    continue
+                else:
+                    raise RuntimeError(
+                        "Feature extraction  of file '%s' was not successful", data_file)
 
-    **Parameters:**
-
-    file_names : [[str], [str]]
-      A list of lists of file names (real, attack) to be read.
-
-    extractor : py:class:`bob.bio.base.extractor.Extractor` or derived
-      The extractor, used for reading the extracted features.
-
-    **Returns:**
-
-    extracted : [object] or [[object]]
-      The list of extracted features, in the same order as in the ``file_names``.
-    """
-    real_files = file_names[0]
-    attack_files = file_names[1]
-    return [[extractor.read_feature(str(f)) for f in real_files],
-            [extractor.read_feature(str(f)) for f in attack_files]]
+            # write feature
+            extractor.write_feature(feature, feature_file)
+        else:
+            logger.debug(
+                "... Skipping preprocessed data '%s' since feature file '%s' exists", data_file, feature_file)
diff --git a/bob/pad/base/tools/preprocessor.py b/bob/pad/base/tools/preprocessor.py
index 4df9476..0a4784f 100644
--- a/bob/pad/base/tools/preprocessor.py
+++ b/bob/pad/base/tools/preprocessor.py
@@ -11,10 +11,11 @@ import os
 
 import logging
 
-logger = logging.getLogger("bob.pad.base")
-
 from .FileSelector import FileSelector
 from bob.bio.base import utils
+from bob.bio.base.tools import read_preprocessed_data
+
+logger = logging.getLogger("bob.pad.base")
 
 
 def preprocess(preprocessor, groups=None, indices=None, allow_missing_files=False, force=False):
@@ -26,7 +27,7 @@ def preprocess(preprocessor, groups=None, indices=None, allow_missing_files=Fals
 
     **Parameters:**
 
-    preprocessor : py:class:`bob.bio.base.preprocessor.Preprocessor` or derived.
+    preprocessor : :py:class:`bob.bio.base.preprocessor.Preprocessor` or derived
       The preprocessor, which should be applied to all data.
 
     groups : some of ``('train', 'dev', 'eval')`` or ``None``
@@ -36,42 +37,58 @@ def preprocess(preprocessor, groups=None, indices=None, allow_missing_files=Fals
       If specified, only the data for the given index range ``range(begin, end)`` should be preprocessed.
       This is usually given, when parallel threads are executed.
 
+    allow_missing_files : bool
+      If set to ``True``, files for which the preprocessor returns ``None`` are silently ignored.
+
     force : bool
       If given, files are regenerated, even if they already exist.
     """
+    if not preprocessor.writes_data:
+        # The preprocessor does not write anything, so no need to call it
+        logger.info(
+            "Skipping preprocessing as preprocessor does not write any data")
+        return
+
     # the file selector object
     fs = FileSelector.instance()
 
     # get the file lists
-    data_files, original_directory, original_extension = fs.original_data_list_files(groups=groups)
+    data_files = fs.original_data_list(groups=groups)
+    original_directory, original_extension = fs.original_directory_and_extension()
     preprocessed_data_files = fs.preprocessed_data_list(groups=groups)
 
-    # read annotation files
-    annotation_list = fs.annotation_list(groups=groups)
-
     # select a subset of keys to iterate
     if indices is not None:
         index_range = range(indices[0], indices[1])
-        logger.info("- Preprocessing: splitting of index range %s", str(indices))
+        logger.info(
+            "- Preprocessing: splitting of index range %s", str(indices))
     else:
         index_range = range(len(data_files))
 
-    logger.info("- Preprocessing: processing %d data files from directory '%s' to directory '%s'", len(index_range),
-                fs.directories['original'], fs.directories['preprocessed'])
+    logger.info("- Preprocessing: processing %d data files from directory '%s' to directory '%s'",
+                len(index_range), fs.directories['original'], fs.directories['preprocessed'])
 
+    # read annotation files
+    annotation_list = fs.annotation_list(groups=groups)
 
     # iterate over the selected files
     for i in index_range:
-        preprocessed_data_file = str(preprocessed_data_files[i])
+        preprocessed_data_file = preprocessed_data_files[i]
         file_object = data_files[i]
-        file_name = file_object.make_path(original_directory, original_extension)
+        file_name = file_object.make_path(
+            original_directory, original_extension)
 
         # check for existence
-        if not utils.check_file(preprocessed_data_file, force, 1000):
-            logger.info("... Processing original data file '%s'", file_name)
-            data = preprocessor.read_original_data(file_object, original_directory, original_extension)
-            # create output directory before reading the data file (is sometimes required, when relative directories are specified, especially, including a .. somewhere)
-            bob.io.base.create_directories_safe(os.path.dirname(preprocessed_data_file))
+        if not utils.check_file(preprocessed_data_file, force,
+                                preprocessor.min_preprocessed_file_size):
+            logger.debug("... Processing original data file '%s'", file_name)
+            data = preprocessor.read_original_data(
+                file_object, original_directory, original_extension)
+            # create the output directory before reading the data file (this
+            # is sometimes required when relative directories, especially
+            # ones containing '..', are specified)
+            bob.io.base.create_directories_safe(
+                os.path.dirname(preprocessed_data_file))
 
             # get the annotations; might be None
             annotations = fs.get_annotations(annotation_list[i])
@@ -79,31 +96,17 @@ def preprocess(preprocessor, groups=None, indices=None, allow_missing_files=Fals
             # call the preprocessor
             preprocessed_data = preprocessor(data, annotations)
             if preprocessed_data is None:
-                logger.error("Preprocessing of file '%s' was not successful", file_name)
-                continue
+                if allow_missing_files:
+                    logger.debug(
+                        "... Preprocessing of file '%s' failed; skipping", file_name)
+                    continue
+                else:
+                    raise RuntimeError(
+                        "Preprocessing of file '%s' was not successful" % file_name)
 
             # write the data
             preprocessor.write_data(preprocessed_data, preprocessed_data_file)
 
-
-def read_preprocessed_data(file_names, preprocessor):
-    """read_preprocessed_data(file_names, preprocessor, split_by_client = False) -> preprocessed
-
-    Reads the preprocessed data from ``file_names`` using the given preprocessor.
-    If ``split_by_client`` is set to ``True``, it is assumed that the ``file_names`` are already sorted by client.
-
-    **Parameters:**
-
-    file_names : [str] or [[str]]
-      A list of names of files to be read.
-      If ``split_by_client = True``, file names are supposed to be split into groups.
-
-    preprocessor : py:class:`bob.bio.base.preprocessor.Preprocessor` or derived
-      The preprocessor, which can read the preprocessed data.
-
-    **Returns:**
-
-    preprocessed : [object] or [[object]]
-      The list of preprocessed data, in the same order as in the ``file_names``.
-    """
-    return [preprocessor.read_data(str(f)) for f in file_names]
+        else:
+            logger.debug("... Skipping original data file '%s' since preprocessed data '%s' exists",
+                         file_name, preprocessed_data_file)
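
The early ``writes_data`` return added above means preprocessors can opt out
of the disk round-trip entirely. A hedged sketch of such a preprocessor
(hypothetical class; ``writes_data`` as documented by bob.bio.base):

    from bob.bio.base.preprocessor import Preprocessor

    class NoWritePreprocessor(Preprocessor):
        """Hypothetical preprocessor that never writes files to disk."""

        def __init__(self):
            # writes_data=False makes preprocess() return immediately
            super(NoWritePreprocessor, self).__init__(writes_data=False)

        def __call__(self, data, annotations=None):
            return data  # hand the original data through unchanged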
-- 
GitLab