Commit 1a500e15 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Make the code more DRY

parent 31bdc2e9
Pipeline #13817 passed with stages
in 13 minutes and 25 seconds
......@@ -7,14 +7,15 @@ class MultipleExtractor(Extractor):
"""Base class for SequentialExtractor and ParallelExtractor. This class is
not meant to be used directly."""
def get_attributes(self, processors):
@staticmethod
def get_attributes(processors):
requires_training = any(p.requires_training for p in processors)
split_training_data_by_client = any(p.split_training_data_by_client for
p in processors)
min_extractor_file_size = min(p.min_extractor_file_size for p in
processors)
min_feature_file_size = min(
p.min_feature_file_size for p in processors)
min_feature_file_size = min(p.min_feature_file_size for p in
processors)
return (requires_training, split_training_data_by_client,
min_extractor_file_size, min_feature_file_size)
......@@ -23,38 +24,54 @@ class MultipleExtractor(Extractor):
return groups
def train_one(self, e, training_data, extractor_file, apply=False):
"""Trains one extractor and optionally applies the extractor on the
training data after training.
Parameters
----------
e : :any:`Extractor`
The extractor to train. The extractor should be able to save itself
in an opened hdf5 file.
training_data : [object] or [[object]]
The data to be used for training.
extractor_file : :any:`bob.io.base.HDF5File`
The opened hdf5 file to save the trained extractor inside.
apply : :obj:`bool`, optional
If ``True``, the extractor is applied to the training data after it
is trained and the data is returned.
Returns
-------
None or [object] or [[object]]
Returns ``None`` if ``apply`` is ``False``. Otherwise, returns the
transformed ``training_data``.
"""
if not e.requires_training:
if not apply:
return
if self.split_training_data_by_client:
training_data = [[e(d) for d in datalist]
for datalist in training_data]
else:
training_data = [e(d) for d in training_data]
# do nothing since e does not require training!
pass
# if any of the extractors require splitting the data, the
# split_training_data_by_client is True.
elif e.split_training_data_by_client:
e.train(training_data, extractor_file)
if not apply:
return
training_data = [[e(d) for d in datalist]
for datalist in training_data]
# when no extractor needs splitting
elif not self.split_training_data_by_client:
e.train(training_data, extractor_file)
if not apply:
return
training_data = [e(d) for d in training_data]
# when e here wants it flat but the data is split
else:
# make training_data flat
aligned_training_data = [d for datalist in training_data for d in
datalist]
e.train(aligned_training_data, extractor_file)
if not apply:
return
flat_training_data = [d for datalist in training_data for d in
datalist]
e.train(flat_training_data, extractor_file)
if not apply:
return
# prepare the training data for the next extractor
if self.split_training_data_by_client:
training_data = [[e(d) for d in datalist]
for datalist in training_data]
else:
training_data = [e(d) for d in training_data]
return training_data
def load(self, extractor_file):
......@@ -62,8 +79,7 @@ class MultipleExtractor(Extractor):
groups = self.get_extractor_groups()
for e, group in zip(self.processors, groups):
f.cd(group)
if e.requires_training:
e.load(f)
e.load(f)
f.cd('..')
......@@ -112,10 +128,7 @@ class SequentialExtractor(SequentialProcessor, MultipleExtractor):
with HDF5File(extractor_file, 'w') as f:
groups = self.get_extractor_groups()
for i, (e, group) in enumerate(zip(self.processors, groups)):
if i == len(self.processors) - 1:
apply = False
else:
apply = True
apply = i != len(self.processors) - 1
f.create_group(group)
f.cd(group)
training_data = self.train_one(e, training_data, f,
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment