Skip to content
Snippets Groups Projects
Commit 6a329c1b authored by Manuel Günther
Browse files

Merge branch 'processors' into 'master'

SequentialExtractor: Apply extractor on training data always when apply=True

See merge request !112
parents 04aace9a 1a500e15
No related branches found
No related tags found
1 merge request: !112 SequentialExtractor: Apply extractor on training data always when apply=True
Pipeline #
......@@ -7,14 +7,15 @@ class MultipleExtractor(Extractor):
"""Base class for SequentialExtractor and ParallelExtractor. This class is
not meant to be used directly."""
@staticmethod
def get_attributes(processors):
    """Combine the training-related attributes of several extractors.

    The method does not read any instance state, so it is a static method;
    calling it through an instance (``self.get_attributes(...)``) keeps
    working.

    Parameters
    ----------
    processors : list
        The extractors to inspect. Each one must expose
        ``requires_training``, ``split_training_data_by_client``,
        ``min_extractor_file_size`` and ``min_feature_file_size``.

    Returns
    -------
    tuple
        ``(requires_training, split_training_data_by_client,
        min_extractor_file_size, min_feature_file_size)``. The two flags
        are ``True`` as soon as any extractor sets them; the two sizes are
        the minimum over all extractors.

    Raises
    ------
    ValueError
        If ``processors`` is empty (``min`` of an empty sequence).
    """
    requires_training = any(p.requires_training for p in processors)
    split_training_data_by_client = any(
        p.split_training_data_by_client for p in processors)
    min_extractor_file_size = min(
        p.min_extractor_file_size for p in processors)
    min_feature_file_size = min(
        p.min_feature_file_size for p in processors)
    return (requires_training, split_training_data_by_client,
            min_extractor_file_size, min_feature_file_size)
......@@ -23,32 +24,54 @@ class MultipleExtractor(Extractor):
return groups
def train_one(self, e, training_data, extractor_file, apply=False):
    """Trains one extractor and optionally applies the extractor on the
    training data after training.

    Training and applying are two separate steps: an extractor that does
    not require training is still applied to the data when ``apply`` is
    ``True``, so that the next extractor in a chain receives transformed
    data.

    Parameters
    ----------
    e : :any:`Extractor`
        The extractor to train. The extractor should be able to save itself
        in an opened hdf5 file.
    training_data : [object] or [[object]]
        The data to be used for training.
    extractor_file : :any:`bob.io.base.HDF5File`
        The opened hdf5 file to save the trained extractor inside.
    apply : :obj:`bool`, optional
        If ``True``, the extractor is applied to the training data after it
        is trained and the data is returned.

    Returns
    -------
    None or [object] or [[object]]
        Returns ``None`` if ``apply`` is ``False``. Otherwise, returns the
        transformed ``training_data``.
    """
    # Step 1: train, handing the data over in the layout ``e`` expects.
    if not e.requires_training:
        # do nothing since e does not require training!
        pass
    # if any of the extractors require splitting the data, the
    # split_training_data_by_client is True.
    elif e.split_training_data_by_client:
        e.train(training_data, extractor_file)
    # when no extractor needs splitting
    elif not self.split_training_data_by_client:
        e.train(training_data, extractor_file)
    # when e here wants it flat but the data is split
    else:
        # make training_data flat
        flat_training_data = [d for datalist in training_data
                              for d in datalist]
        e.train(flat_training_data, extractor_file)

    if not apply:
        return None

    # Step 2: prepare the training data for the next extractor, keeping
    # the layout (split by client or flat) that the data came in with.
    if self.split_training_data_by_client:
        training_data = [[e(d) for d in datalist]
                         for datalist in training_data]
    else:
        training_data = [e(d) for d in training_data]
    return training_data
def load(self, extractor_file):
......@@ -104,10 +127,12 @@ class SequentialExtractor(SequentialProcessor, MultipleExtractor):
def train(self, training_data, extractor_file):
    """Train every extractor of the sequence and store them in one file.

    Each extractor is trained inside its own group of the HDF5 file. The
    output of one extractor on the training data is the training data of
    the next one.

    Parameters
    ----------
    training_data : [object] or [[object]]
        The data to be used for training.
    extractor_file : str
        Name of the HDF5 file to create; it is opened in ``'w'`` mode.
    """
    with HDF5File(extractor_file, 'w') as f:
        groups = self.get_extractor_groups()
        last_index = len(self.processors) - 1
        for i, (e, group) in enumerate(zip(self.processors, groups)):
            # Nothing consumes the output of the last extractor, so do
            # not spend time applying it to the training data.
            apply = i != last_index
            f.create_group(group)
            f.cd(group)
            training_data = self.train_one(e, training_data, f,
                                           apply=apply)
            f.cd('..')
def read_feature(self, feature_file):
......
import numpy
import bob.io.base
import bob.bio.base
from bob.bio.base.extractor import Extractor
......@@ -12,10 +12,10 @@ class DummyExtractor (Extractor):
def train(self, train_data, extractor_file):
    """Pretend to train: check the input type and store the dummy model.

    Parameters
    ----------
    train_data : list
        The training data; only its type is checked.
    extractor_file
        Destination for the module-level ``_data``.
        NOTE(review): presumably a file name or an already opened HDF5
        file, which ``bob.bio.base.save`` is expected to accept -- confirm.
    """
    assert isinstance(train_data, list)
    bob.bio.base.save(_data, extractor_file)
def load(self, extractor_file):
    """Load the dummy model and check that it matches what was saved.

    Parameters
    ----------
    extractor_file
        Source to read from.
        NOTE(review): presumably a file name or an already opened HDF5
        file, which ``bob.bio.base.load`` is expected to accept -- confirm.
    """
    data = bob.bio.base.load(extractor_file)
    # The stored content must equal the module-level ``_data`` element-wise.
    assert (_data == data).all()
    self.model = True
......
from functools import partial
import numpy as np
import tempfile
from bob.bio.base.utils.processors import (
SequentialProcessor, ParallelProcessor)
from bob.bio.base.preprocessor import (
SequentialPreprocessor, ParallelPreprocessor, CallablePreprocessor)
from bob.bio.base.extractor import (
SequentialExtractor, ParallelExtractor, CallableExtractor)
from bob.bio.base.test.dummy.extractor import extractor as dummy_extractor
# Sample input fed to the processors/extractors in the tests of this module.
DATA = [0, 1, 2, 3, 4]
# ``partial(np.power, 2)`` fixes the base, so it maps x to 2 ** x;
# ``np.mean`` averages its input.
PROCESSORS = [partial(np.power, 2), np.mean]
......@@ -43,3 +45,23 @@ def test_extractors():
proc = ParallelExtractor(processors)
data = proc(DATA)
assert all(np.allclose(x1, x2) for x1, x2 in zip(data, PAR_DATA))
def test_sequential_trainable_extractors():
    """Train, reload and run a sequential chain ending in a trainable
    extractor, then compare the output with ``SEQ_DATA``."""
    chain = [CallableExtractor(fn) for fn in PROCESSORS]
    chain.append(dummy_extractor)
    proc = SequentialExtractor(chain)
    with tempfile.NamedTemporaryFile(suffix='.hdf5') as tmp:
        model_path = tmp.name
        proc.train(DATA, model_path)
        proc.load(model_path)
        result = proc(DATA)
        assert np.allclose(result, SEQ_DATA)
def test_parallel_trainable_extractors():
    """Train, reload and run a parallel set of extractors including a
    trainable one, then compare each output with ``PAR_DATA``."""
    members = [CallableExtractor(fn) for fn in PROCESSORS]
    members.append(dummy_extractor)
    proc = ParallelExtractor(members)
    with tempfile.NamedTemporaryFile(suffix='.hdf5') as tmp:
        model_path = tmp.name
        proc.train(DATA, model_path)
        proc.load(model_path)
        outputs = proc(np.array(DATA))
        assert all(np.allclose(got, expected)
                   for got, expected in zip(outputs, PAR_DATA))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment