Commit 6fd872d2 authored by Amir MOHAMMADI

Make the code more WET

parent 31bdc2e9
@@ -7,14 +7,15 @@ class MultipleExtractor(Extractor):
     """Base class for SequentialExtractor and ParallelExtractor. This class is
     not meant to be used directly."""
 
-    def get_attributes(self, processors):
+    @staticmethod
+    def get_attributes(processors):
         requires_training = any(p.requires_training for p in processors)
         split_training_data_by_client = any(p.split_training_data_by_client for
                                             p in processors)
         min_extractor_file_size = min(p.min_extractor_file_size for p in
                                       processors)
-        min_feature_file_size = min(
-            p.min_feature_file_size for p in processors)
+        min_feature_file_size = min(p.min_feature_file_size for p in
+                                    processors)
         return (requires_training, split_training_data_by_client,
                 min_extractor_file_size, min_feature_file_size)
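For context, get_attributes only aggregates bookkeeping attributes across the wrapped extractors: the boolean flags are combined with any() and the file sizes with min(), which is why it can become a staticmethod. A minimal sketch of that rule, using a hypothetical namedtuple stand-in instead of real bob.bio.base extractors:

from collections import namedtuple

# Hypothetical stand-in for an extractor; only the attributes that
# get_attributes() reads are modelled here.
Fake = namedtuple('Fake', ['requires_training', 'split_training_data_by_client',
                           'min_extractor_file_size', 'min_feature_file_size'])

processors = [Fake(False, False, 1000, 1000), Fake(True, False, 500, 2000)]

requires_training = any(p.requires_training for p in processors)              # True
split_by_client = any(p.split_training_data_by_client for p in processors)    # False
min_extractor_file_size = min(p.min_extractor_file_size for p in processors)  # 500
min_feature_file_size = min(p.min_feature_file_size for p in processors)      # 1000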
@@ -23,38 +24,54 @@ class MultipleExtractor(Extractor):
         return groups
 
     def train_one(self, e, training_data, extractor_file, apply=False):
+        """Trains one extractor and optionally applies the extractor on the
+        training data after training.
+
+        Parameters
+        ----------
+        e : :any:`Extractor`
+            The extractor to train. The extractor should be able to save itself
+            in an opened hdf5 file.
+        training_data : [object] or [[object]]
+            The data to be used for training.
+        extractor_file : :any:`bob.io.base.HDF5File`
+            The opened hdf5 file to save the trained extractor inside.
+        apply : :obj:`bool`, optional
+            If ``True``, the extractor is applied to the training data after it
+            is trained and the data is returned.
+
+        Returns
+        -------
+        None or [object] or [[object]]
+            Returns ``None`` if ``apply`` is ``False``. Otherwise, returns the
+            transformed ``training_data``.
+        """
         if not e.requires_training:
-            if not apply:
-                return
-            if self.split_training_data_by_client:
-                training_data = [[e(d) for d in datalist]
-                                 for datalist in training_data]
-            else:
-                training_data = [e(d) for d in training_data]
+            # do nothing since e does not require training!
+            pass
         # if any of the extractors require splitting the data, the
         # split_training_data_by_client is True.
         elif e.split_training_data_by_client:
             e.train(training_data, extractor_file)
-            if not apply:
-                return
-            training_data = [[e(d) for d in datalist]
-                             for datalist in training_data]
         # when no extractor needs splitting
         elif not self.split_training_data_by_client:
             e.train(training_data, extractor_file)
-            if not apply:
-                return
-            training_data = [e(d) for d in training_data]
         # when e here wants it flat but the data is split
         else:
             # make training_data flat
-            aligned_training_data = [d for datalist in training_data for d in
-                                     datalist]
-            e.train(aligned_training_data, extractor_file)
-            if not apply:
-                return
-            training_data = [[e(d) for d in datalist]
-                             for datalist in training_data]
+            flat_training_data = [d for datalist in training_data for d in
+                                  datalist]
+            e.train(flat_training_data, extractor_file)
+
+        if not apply:
+            return
+
+        # prepare the training data for the next extractor
+        if self.split_training_data_by_client:
+            training_data = [[e(d) for d in datalist]
+                             for datalist in training_data]
+        else:
+            training_data = [e(d) for d in training_data]
+
         return training_data
 
     def load(self, extractor_file):
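The refactoring moves the "apply the extractor to the training data" step out of every branch and into a single block at the end of train_one, so the if/elif chain is left with training only. A simplified, self-contained sketch of the new control flow (DummyExtractor and the free-standing train_one below are made up for illustration; the real method is bound to the class and writes into an opened HDF5 file):

# Simplified sketch of the refactored train_one() control flow.
class DummyExtractor:
    requires_training = True
    split_training_data_by_client = False

    def train(self, data, f):
        self.offset = sum(data) / len(data)  # pretend "training"

    def __call__(self, d):
        return d - self.offset


def train_one(e, training_data, f, split_by_client=False, apply=False):
    if e.requires_training:
        e.train(training_data, f)  # one branch per training mode in the real code
    if not apply:
        return  # nothing more to do for the last extractor in the chain
    # the apply step is now written only once, after training
    if split_by_client:
        return [[e(d) for d in datalist] for datalist in training_data]
    return [e(d) for d in training_data]


e = DummyExtractor()
print(train_one(e, [1.0, 2.0, 3.0], None, apply=True))  # [-1.0, 0.0, 1.0]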
@@ -62,8 +79,7 @@ class MultipleExtractor(Extractor):
         groups = self.get_extractor_groups()
         for e, group in zip(self.processors, groups):
             f.cd(group)
-            if e.requires_training:
-                e.load(f)
+            e.load(f)
             f.cd('..')
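Dropping the requires_training guard around e.load(f) rests on the assumption that loading is harmless for extractors that were never trained; a tiny sketch of that assumption with made-up classes (not the actual bob.bio.base base class):

# Illustrative assumption behind the change: a base extractor's load() can be a
# no-op, so calling it unconditionally is safe even for untrained extractors.
class BaseExtractor:
    requires_training = False

    def load(self, f):
        pass  # nothing to restore for an untrained extractor


class TrainedExtractor(BaseExtractor):
    requires_training = True

    def load(self, f):
        self.model = f["model"]  # restore state from the opened file/group


store = {"model": [0.1, 0.2]}
for e in [BaseExtractor(), TrainedExtractor()]:
    e.load(store)  # no requires_training check needed any more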
@@ -112,10 +128,7 @@ class SequentialExtractor(SequentialProcessor, MultipleExtractor):
         with HDF5File(extractor_file, 'w') as f:
             groups = self.get_extractor_groups()
             for i, (e, group) in enumerate(zip(self.processors, groups)):
-                if i == len(self.processors) - 1:
-                    apply = False
-                else:
-                    apply = True
+                apply = i != len(self.processors) - 1
                 f.create_group(group)
                 f.cd(group)
                 training_data = self.train_one(e, training_data, f,
...
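The one-line replacement is equivalent to the removed if/else: apply stays True for every extractor except the last one in the chain, so intermediate features are computed for the next stage but the final extractor's output is not needlessly produced during training. A quick standalone check with hypothetical processor names:

# apply is True for every extractor except the last one
processors = ['ext_a', 'ext_b', 'ext_c']  # hypothetical names
for i, e in enumerate(processors):
    apply = i != len(processors) - 1
    print(e, apply)
# ext_a True
# ext_b True
# ext_c False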