From 1a500e1591f73e8e392c09af3dc66fb6adfc002f Mon Sep 17 00:00:00 2001
From: Amir MOHAMMADI <amir.mohammadi@idiap.ch>
Date: Tue, 7 Nov 2017 10:49:18 +0100
Subject: [PATCH] Make the code more DRY

---
 bob/bio/base/extractor/stacks.py | 69 +++++++++++++++++++-------------
 1 file changed, 41 insertions(+), 28 deletions(-)

diff --git a/bob/bio/base/extractor/stacks.py b/bob/bio/base/extractor/stacks.py
index 6bc3cb5f..0021d4ad 100644
--- a/bob/bio/base/extractor/stacks.py
+++ b/bob/bio/base/extractor/stacks.py
@@ -7,14 +7,15 @@ class MultipleExtractor(Extractor):
     """Base class for SequentialExtractor and ParallelExtractor. This class is
     not meant to be used directly."""
 
-    def get_attributes(self, processors):
+    @staticmethod
+    def get_attributes(processors):
         requires_training = any(p.requires_training for p in processors)
         split_training_data_by_client = any(p.split_training_data_by_client for
                                             p in processors)
         min_extractor_file_size = min(p.min_extractor_file_size for p in
                                       processors)
-        min_feature_file_size = min(
-            p.min_feature_file_size for p in processors)
+        min_feature_file_size = min(p.min_feature_file_size for p in
+                                    processors)
         return (requires_training, split_training_data_by_client,
                 min_extractor_file_size, min_feature_file_size)
 
@@ -23,38 +24,54 @@ class MultipleExtractor(Extractor):
         return groups
 
     def train_one(self, e, training_data, extractor_file, apply=False):
+        """Trains one extractor and optionally applies the extractor on the
+        training data after training.
+
+        Parameters
+        ----------
+        e : :any:`Extractor`
+            The extractor to train. The extractor should be able to save
+            itself in an open HDF5 file.
+        training_data : [object] or [[object]]
+            The data to be used for training.
+        extractor_file : :any:`bob.io.base.HDF5File`
+            The open HDF5 file in which the trained extractor is saved.
+        apply : :obj:`bool`, optional
+            If ``True``, the extractor is applied to the training data after
+            training and the transformed data is returned.
+
+        Returns
+        -------
+        None or [object] or [[object]]
+            Returns ``None`` if ``apply`` is ``False``. Otherwise, returns the
+            transformed ``training_data``.
+        """
         if not e.requires_training:
-            if not apply:
-                return
-            if self.split_training_data_by_client:
-                training_data = [[e(d) for d in datalist]
-                                 for datalist in training_data]
-            else:
-                training_data = [e(d) for d in training_data]
+            # do nothing since e does not require training!
+            pass
         # if any of the extractors require splitting the data, the
         # split_training_data_by_client is True.
         elif e.split_training_data_by_client:
             e.train(training_data, extractor_file)
-            if not apply:
-                return
-            training_data = [[e(d) for d in datalist]
-                             for datalist in training_data]
         # when no extractor needs splitting
         elif not self.split_training_data_by_client:
             e.train(training_data, extractor_file)
-            if not apply:
-                return
-            training_data = [e(d) for d in training_data]
         # when e here wants it flat but the data is split
         else:
             # make training_data flat
-            aligned_training_data = [d for datalist in training_data for d in
-                                     datalist]
-            e.train(aligned_training_data, extractor_file)
-            if not apply:
-                return
+            flat_training_data = [d for datalist in training_data for d in
+                                  datalist]
+            e.train(flat_training_data, extractor_file)
+
+        if not apply:
+            return
+
+        # prepare the training data for the next extractor
+        if self.split_training_data_by_client:
             training_data = [[e(d) for d in datalist]
                              for datalist in training_data]
+        else:
+            training_data = [e(d) for d in training_data]
         return training_data
 
     def load(self, extractor_file):
@@ -62,8 +79,7 @@ class MultipleExtractor(Extractor):
             groups = self.get_extractor_groups()
             for e, group in zip(self.processors, groups):
                 f.cd(group)
-                if e.requires_training:
-                    e.load(f)
+                e.load(f)
                 f.cd('..')
 
 
@@ -112,10 +128,7 @@ class SequentialExtractor(SequentialProcessor, MultipleExtractor):
         with HDF5File(extractor_file, 'w') as f:
             groups = self.get_extractor_groups()
             for i, (e, group) in enumerate(zip(self.processors, groups)):
-                if i == len(self.processors) - 1:
-                    apply = False
-                else:
-                    apply = True
+                apply = i != len(self.processors) - 1
                 f.create_group(group)
                 f.cd(group)
                 training_data = self.train_one(e, training_data, f,
-- 
GitLab
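
For context, below is a minimal usage sketch of the training flow this patch
refactors and documents. It is not part of the patch; the DummyExtractor
class, the output filename, and the assumptions that SequentialExtractor is
constructed from a plain list of extractors and that the base Extractor
constructor accepts a requires_training flag are illustrative only.

    from bob.bio.base.extractor import Extractor
    from bob.bio.base.extractor.stacks import SequentialExtractor


    class DummyExtractor(Extractor):
        """A hypothetical extractor that does not require training."""

        def __init__(self):
            super(DummyExtractor, self).__init__(requires_training=False)

        def __call__(self, data):
            # identity transform, purely for illustration
            return data


    # assumed constructor: a list of extractors applied in sequence
    extractor = SequentialExtractor([DummyExtractor(), DummyExtractor()])

    # train() creates one HDF5 group per extractor. train_one() skips
    # training for extractors that do not require it, but (except for the
    # last extractor in the chain) still applies each one to prepare the
    # training data for the next stage.
    extractor.train([1.0, 2.0, 3.0], 'Extractor.hdf5')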