Skip to content
Snippets Groups Projects
Commit 1a500e15 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Make the code more DRY

parent 31bdc2e9
No related branches found
No related tags found
1 merge request!112SequentialExtractor: Apply extractor on training data always when apply=True
Pipeline #
......@@ -7,14 +7,15 @@ class MultipleExtractor(Extractor):
"""Base class for SequentialExtractor and ParallelExtractor. This class is
not meant to be used directly."""
def get_attributes(self, processors):
@staticmethod
def get_attributes(processors):
requires_training = any(p.requires_training for p in processors)
split_training_data_by_client = any(p.split_training_data_by_client for
p in processors)
min_extractor_file_size = min(p.min_extractor_file_size for p in
processors)
min_feature_file_size = min(
p.min_feature_file_size for p in processors)
min_feature_file_size = min(p.min_feature_file_size for p in
processors)
return (requires_training, split_training_data_by_client,
min_extractor_file_size, min_feature_file_size)
......@@ -23,38 +24,54 @@ class MultipleExtractor(Extractor):
return groups
def train_one(self, e, training_data, extractor_file, apply=False):
    """Trains one extractor and optionally applies the extractor on the
    training data after training.

    Parameters
    ----------
    e : :any:`Extractor`
        The extractor to train. The extractor should be able to save itself
        in an opened hdf5 file.
    training_data : [object] or [[object]]
        The data to be used for training.
    extractor_file : :any:`bob.io.base.HDF5File`
        The opened hdf5 file to save the trained extractor inside.
    apply : :obj:`bool`, optional
        If ``True``, the extractor is applied to the training data after it
        is trained and the data is returned.

    Returns
    -------
    None or [object] or [[object]]
        Returns ``None`` if ``apply`` is ``False``. Otherwise, returns the
        transformed ``training_data``.
    """
    if not e.requires_training:
        # do nothing since e does not require training!
        pass
    # if any of the extractors require splitting the data, the
    # split_training_data_by_client is True.
    elif e.split_training_data_by_client:
        e.train(training_data, extractor_file)
    # when no extractor needs splitting
    elif not self.split_training_data_by_client:
        e.train(training_data, extractor_file)
    # when e here wants it flat but the data is split
    else:
        # make training_data flat
        flat_training_data = [d for datalist in training_data
                              for d in datalist]
        e.train(flat_training_data, extractor_file)

    if not apply:
        return

    # prepare the training data for the next extractor: keep the
    # per-client nesting if the chain as a whole splits by client
    if self.split_training_data_by_client:
        training_data = [[e(d) for d in datalist]
                         for datalist in training_data]
    else:
        training_data = [e(d) for d in training_data]
    return training_data
def load(self, extractor_file):
    """Loads all sub-extractors from their groups inside ``extractor_file``.

    Parameters
    ----------
    extractor_file : str
        Path of the hdf5 file holding one group per processor.
    """
    # NOTE(review): the line opening the file was lost in the diff scrape;
    # reconstructed as a read-mode HDF5File context -- confirm against the
    # repository.
    with HDF5File(extractor_file) as f:
        groups = self.get_extractor_groups()
        for e, group in zip(self.processors, groups):
            f.cd(group)
            # load unconditionally (the diff removed the former
            # ``requires_training`` guard around this call)
            e.load(f)
            f.cd('..')
......@@ -112,10 +128,7 @@ class SequentialExtractor(SequentialProcessor, MultipleExtractor):
with HDF5File(extractor_file, 'w') as f:
groups = self.get_extractor_groups()
for i, (e, group) in enumerate(zip(self.processors, groups)):
if i == len(self.processors) - 1:
apply = False
else:
apply = True
apply = i != len(self.processors) - 1
f.create_group(group)
f.cd(group)
training_data = self.train_one(e, training_data, f,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment