diff --git a/bob/bio/base/extractor/stacks.py b/bob/bio/base/extractor/stacks.py index aae378385d0b4ebeacbc080f1462e73056fae4d7..0021d4ad4b78d108a3cbe437d5bcec8e7b1523e6 100644 --- a/bob/bio/base/extractor/stacks.py +++ b/bob/bio/base/extractor/stacks.py @@ -7,14 +7,15 @@ class MultipleExtractor(Extractor): """Base class for SequentialExtractor and ParallelExtractor. This class is not meant to be used directly.""" - def get_attributes(self, processors): + @staticmethod + def get_attributes(processors): requires_training = any(p.requires_training for p in processors) split_training_data_by_client = any(p.split_training_data_by_client for p in processors) min_extractor_file_size = min(p.min_extractor_file_size for p in processors) - min_feature_file_size = min( - p.min_feature_file_size for p in processors) + min_feature_file_size = min(p.min_feature_file_size for p in + processors) return (requires_training, split_training_data_by_client, min_extractor_file_size, min_feature_file_size) @@ -23,32 +24,54 @@ class MultipleExtractor(Extractor): return groups def train_one(self, e, training_data, extractor_file, apply=False): + """Trains one extractor and optionally applies the extractor on the + training data after training. + + Parameters + ---------- + e : :any:`Extractor` + The extractor to train. The extractor should be able to save itself + in an opened hdf5 file. + training_data : [object] or [[object]] + The data to be used for training. + extractor_file : :any:`bob.io.base.HDF5File` + The opened hdf5 file to save the trained extractor inside. + apply : :obj:`bool`, optional + If ``True``, the extractor is applied to the training data after it + is trained and the data is returned. + + Returns + ------- + None or [object] or [[object]] + Returns ``None`` if ``apply`` is ``False``. Otherwise, returns the + transformed ``training_data``. + """ if not e.requires_training: - return + # do nothing since e does not require training! + pass # if any of the extractors require splitting the data, the # split_training_data_by_client is True. - if e.split_training_data_by_client: + elif e.split_training_data_by_client: e.train(training_data, extractor_file) - if not apply: - return - training_data = [[e(d) for d in datalist] - for datalist in training_data] # when no extractor needs splitting elif not self.split_training_data_by_client: e.train(training_data, extractor_file) - if not apply: - return - training_data = [e(d) for d in training_data] # when e here wants it flat but the data is split else: # make training_data flat - aligned_training_data = [d for datalist in training_data for d in - datalist] - e.train(aligned_training_data, extractor_file) - if not apply: - return + flat_training_data = [d for datalist in training_data for d in + datalist] + e.train(flat_training_data, extractor_file) + + if not apply: + return + + # prepare the training data for the next extractor + if self.split_training_data_by_client: training_data = [[e(d) for d in datalist] for datalist in training_data] + else: + training_data = [e(d) for d in training_data] return training_data def load(self, extractor_file): @@ -104,10 +127,12 @@ class SequentialExtractor(SequentialProcessor, MultipleExtractor): def train(self, training_data, extractor_file): with HDF5File(extractor_file, 'w') as f: groups = self.get_extractor_groups() - for e, group in zip(self.processors, groups): + for i, (e, group) in enumerate(zip(self.processors, groups)): + apply = i != len(self.processors) - 1 f.create_group(group) f.cd(group) - training_data = self.train_one(e, training_data, f, apply=True) + training_data = self.train_one(e, training_data, f, + apply=apply) f.cd('..') def read_feature(self, feature_file): diff --git a/bob/bio/base/test/dummy/extractor.py b/bob/bio/base/test/dummy/extractor.py index a3aaf6f7ea04347393db8cbc8efd1dac95e98000..eca7517ca8a46676074f93410d95b98d71757721 100644 --- a/bob/bio/base/test/dummy/extractor.py +++ b/bob/bio/base/test/dummy/extractor.py @@ -1,5 +1,5 @@ import numpy -import bob.io.base +import bob.bio.base from bob.bio.base.extractor import Extractor @@ -12,10 +12,10 @@ class DummyExtractor (Extractor): def train(self, train_data, extractor_file): assert isinstance(train_data, list) - bob.io.base.save(_data, extractor_file) + bob.bio.base.save(_data, extractor_file) def load(self, extractor_file): - data = bob.io.base.load(extractor_file) + data = bob.bio.base.load(extractor_file) assert (_data == data).all() self.model = True diff --git a/bob/bio/base/test/test_stacks.py b/bob/bio/base/test/test_stacks.py index 926901382af5455ee292ace46abc885582baaaa6..cd6e0f534fe2eca8c236685a44a2c458236c5185 100644 --- a/bob/bio/base/test/test_stacks.py +++ b/bob/bio/base/test/test_stacks.py @@ -1,11 +1,13 @@ from functools import partial import numpy as np +import tempfile from bob.bio.base.utils.processors import ( SequentialProcessor, ParallelProcessor) from bob.bio.base.preprocessor import ( SequentialPreprocessor, ParallelPreprocessor, CallablePreprocessor) from bob.bio.base.extractor import ( SequentialExtractor, ParallelExtractor, CallableExtractor) +from bob.bio.base.test.dummy.extractor import extractor as dummy_extractor DATA = [0, 1, 2, 3, 4] PROCESSORS = [partial(np.power, 2), np.mean] @@ -43,3 +45,23 @@ def test_extractors(): proc = ParallelExtractor(processors) data = proc(DATA) assert all(np.allclose(x1, x2) for x1, x2 in zip(data, PAR_DATA)) + + +def test_sequential_trainable_extractors(): + processors = [CallableExtractor(p) for p in PROCESSORS] + [dummy_extractor] + proc = SequentialExtractor(processors) + with tempfile.NamedTemporaryFile(suffix='.hdf5') as f: + proc.train(DATA, f.name) + proc.load(f.name) + data = proc(DATA) + assert np.allclose(data, SEQ_DATA) + + +def test_parallel_trainable_extractors(): + processors = [CallableExtractor(p) for p in PROCESSORS] + [dummy_extractor] + proc = ParallelExtractor(processors) + with tempfile.NamedTemporaryFile(suffix='.hdf5') as f: + proc.train(DATA, f.name) + proc.load(f.name) + data = proc(np.array(DATA)) + assert all(np.allclose(x1, x2) for x1, x2 in zip(data, PAR_DATA))