SequentialExtractor: Apply extractor on training data always when apply=True

Merged Amir MOHAMMADI requested to merge processors into master
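What the change means in practice, as a minimal sketch: the two toy extractors, the training values and the file name below are made up for illustration, and the exact import paths are assumed; only Extractor, SequentialExtractor and the train call reflect the API this MR touches.

```python
# Hedged sketch only: Scale and MeanSubtract are hypothetical extractors;
# the import paths and constructor arguments are assumptions.
from bob.bio.base.extractor import Extractor, SequentialExtractor


class Scale(Extractor):
    """Stateless toy extractor; needs no training."""

    def __call__(self, data):
        return data * 2


class MeanSubtract(Extractor):
    """Trainable toy extractor; learns the mean of its training inputs."""

    def __init__(self):
        super(MeanSubtract, self).__init__(requires_training=True)
        self.mean = 0.0

    def train(self, training_data, extractor_file):
        self.mean = sum(training_data) / len(training_data)
        extractor_file.set('mean', self.mean)

    def load(self, extractor_file):
        self.mean = extractor_file.read('mean')

    def __call__(self, data):
        return data - self.mean


extractor = SequentialExtractor([Scale(), MeanSubtract()])
# With this MR, Scale is applied to the training data before MeanSubtract is
# trained (apply=True for every extractor except the last one), so the mean
# is learned on the scaled values rather than on the raw inputs.
extractor.train([1.0, 2.0, 3.0], 'extractor.hdf5')
```

Judging from the old code shown below, an extractor that did not require training previously made train_one return None even when apply=True, so the next extractor lost its training data; the MR forwards the transformed data instead.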
Files: 3
@@ -7,14 +7,15 @@ class MultipleExtractor(Extractor):
     """Base class for SequentialExtractor and ParallelExtractor. This class is
     not meant to be used directly."""
 
-    def get_attributes(self, processors):
+    @staticmethod
+    def get_attributes(processors):
         requires_training = any(p.requires_training for p in processors)
         split_training_data_by_client = any(p.split_training_data_by_client for
                                             p in processors)
         min_extractor_file_size = min(p.min_extractor_file_size for p in
                                       processors)
-        min_feature_file_size = min(
-            p.min_feature_file_size for p in processors)
+        min_feature_file_size = min(p.min_feature_file_size for p in
+                                    processors)
         return (requires_training, split_training_data_by_client,
                 min_extractor_file_size, min_feature_file_size)
 
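Since get_attributes is now a staticmethod, it can be called without instantiating the stack; a quick sketch of the aggregation it performs (the FakeExtractor class, its values, and the import path are assumptions for illustration):

```python
# Sketch only: FakeExtractor and its values are made up; the import path for
# MultipleExtractor is assumed and may need adjusting.
from bob.bio.base.extractor import MultipleExtractor


class FakeExtractor(object):
    def __init__(self, requires_training, split, size):
        self.requires_training = requires_training
        self.split_training_data_by_client = split
        self.min_extractor_file_size = size
        self.min_feature_file_size = size


processors = [FakeExtractor(False, False, 1000), FakeExtractor(True, True, 500)]
# any() over the two booleans, min() over the two file-size thresholds:
print(MultipleExtractor.get_attributes(processors))
# -> (True, True, 500, 500)
```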
@@ -23,32 +24,54 @@ class MultipleExtractor(Extractor):
         return groups
 
     def train_one(self, e, training_data, extractor_file, apply=False):
+        """Trains one extractor and optionally applies the extractor on the
+        training data after training.
+
+        Parameters
+        ----------
+        e : :any:`Extractor`
+            The extractor to train. The extractor should be able to save itself
+            in an opened hdf5 file.
+        training_data : [object] or [[object]]
+            The data to be used for training.
+        extractor_file : :any:`bob.io.base.HDF5File`
+            The opened hdf5 file to save the trained extractor inside.
+        apply : :obj:`bool`, optional
+            If ``True``, the extractor is applied to the training data after it
+            is trained and the data is returned.
+
+        Returns
+        -------
+        None or [object] or [[object]]
+            Returns ``None`` if ``apply`` is ``False``. Otherwise, returns the
+            transformed ``training_data``.
+        """
         if not e.requires_training:
-            return
+            # do nothing since e does not require training!
+            pass
         # if any of the extractors require splitting the data, the
         # split_training_data_by_client is True.
-        if e.split_training_data_by_client:
+        elif e.split_training_data_by_client:
             e.train(training_data, extractor_file)
-            if not apply:
-                return
-            training_data = [[e(d) for d in datalist]
-                             for datalist in training_data]
         # when no extractor needs splitting
         elif not self.split_training_data_by_client:
             e.train(training_data, extractor_file)
-            if not apply:
-                return
-            training_data = [e(d) for d in training_data]
         # when e here wants it flat but the data is split
         else:
             # make training_data flat
-            aligned_training_data = [d for datalist in training_data for d in
-                                     datalist]
-            e.train(aligned_training_data, extractor_file)
-            if not apply:
-                return
+            flat_training_data = [d for datalist in training_data for d in
+                                  datalist]
+            e.train(flat_training_data, extractor_file)
+
+        if not apply:
+            return
+
+        # prepare the training data for the next extractor
+        if self.split_training_data_by_client:
             training_data = [[e(d) for d in datalist]
                              for datalist in training_data]
+        else:
+            training_data = [e(d) for d in training_data]
+
         return training_data
 
     def load(self, extractor_file):
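The apply step that now sits at the end of train_one covers the two data layouts described in the docstring; a tiny sketch of what it produces, with e as a stand-in for a trained extractor's __call__ and made-up values:

```python
# Stand-in for a trained extractor's __call__; the values are made up.
def e(d):
    return d + 1


# split_training_data_by_client=True: one list of features per client.
training_data = [[1, 2], [3]]
assert [[e(d) for d in datalist] for datalist in training_data] == [[2, 3], [4]]

# split_training_data_by_client=False: one flat list of features.
training_data = [1, 2, 3]
assert [e(d) for d in training_data] == [2, 3, 4]
```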
@@ -104,10 +127,12 @@ class SequentialExtractor(SequentialProcessor, MultipleExtractor):
     def train(self, training_data, extractor_file):
         with HDF5File(extractor_file, 'w') as f:
             groups = self.get_extractor_groups()
-            for e, group in zip(self.processors, groups):
+            for i, (e, group) in enumerate(zip(self.processors, groups)):
+                apply = i != len(self.processors) - 1
                 f.create_group(group)
                 f.cd(group)
-                training_data = self.train_one(e, training_data, f, apply=True)
+                training_data = self.train_one(e, training_data, f,
+                                               apply=apply)
                 f.cd('..')
 
     def read_feature(self, feature_file):