Skip to content
Snippets Groups Projects
Commit 6a329c1b authored by Manuel Günther's avatar Manuel Günther
Browse files

Merge branch 'processors' into 'master'

SequentialExtractor: Apply extractor on training data always when apply=True

See merge request !112
parents 04aace9a 1a500e15
No related branches found
No related tags found
1 merge request!112SequentialExtractor: Apply extractor on training data always when apply=True
Pipeline #
...@@ -7,14 +7,15 @@ class MultipleExtractor(Extractor): ...@@ -7,14 +7,15 @@ class MultipleExtractor(Extractor):
"""Base class for SequentialExtractor and ParallelExtractor. This class is """Base class for SequentialExtractor and ParallelExtractor. This class is
not meant to be used directly.""" not meant to be used directly."""
@staticmethod
def get_attributes(processors):
    """Aggregate the training-related attributes of several processors.

    Parameters
    ----------
    processors : [object]
        A list of extractors/processors. Each must provide the attributes
        ``requires_training``, ``split_training_data_by_client``,
        ``min_extractor_file_size``, and ``min_feature_file_size``.

    Returns
    -------
    tuple
        ``(requires_training, split_training_data_by_client,
        min_extractor_file_size, min_feature_file_size)`` where the two
        booleans are True if *any* processor requires it, and the sizes
        are the *minimum* over all processors.
    """
    # Training is needed (and per-client splitting honored) if any single
    # processor asks for it.
    requires_training = any(p.requires_training for p in processors)
    split_training_data_by_client = any(p.split_training_data_by_client for
                                        p in processors)
    # File-size sanity thresholds: use the smallest acceptable size so no
    # processor's valid output is rejected.
    min_extractor_file_size = min(p.min_extractor_file_size for p in
                                  processors)
    min_feature_file_size = min(p.min_feature_file_size for p in
                                processors)
    return (requires_training, split_training_data_by_client,
            min_extractor_file_size, min_feature_file_size)
...@@ -23,32 +24,54 @@ class MultipleExtractor(Extractor): ...@@ -23,32 +24,54 @@ class MultipleExtractor(Extractor):
return groups return groups
def train_one(self, e, training_data, extractor_file, apply=False):
    """Trains one extractor and optionally applies the extractor on the
    training data after training.

    Parameters
    ----------
    e : :any:`Extractor`
        The extractor to train. The extractor should be able to save itself
        in an opened hdf5 file.
    training_data : [object] or [[object]]
        The data to be used for training.
    extractor_file : :any:`bob.io.base.HDF5File`
        The opened hdf5 file to save the trained extractor inside.
    apply : :obj:`bool`, optional
        If ``True``, the extractor is applied to the training data after it
        is trained and the data is returned.

    Returns
    -------
    None or [object] or [[object]]
        Returns ``None`` if ``apply`` is ``False``. Otherwise, returns the
        transformed ``training_data``.
    """
    if not e.requires_training:
        # do nothing since e does not require training!
        pass
    # if any of the extractors require splitting the data, the
    # split_training_data_by_client is True.
    elif e.split_training_data_by_client:
        e.train(training_data, extractor_file)
    # when no extractor needs splitting
    elif not self.split_training_data_by_client:
        e.train(training_data, extractor_file)
    # when e here wants it flat but the data is split
    else:
        # make training_data flat
        flat_training_data = [d for datalist in training_data for d in
                              datalist]
        e.train(flat_training_data, extractor_file)

    if not apply:
        return

    # prepare the training data for the next extractor
    if self.split_training_data_by_client:
        training_data = [[e(d) for d in datalist]
                         for datalist in training_data]
    else:
        training_data = [e(d) for d in training_data]

    return training_data
def load(self, extractor_file): def load(self, extractor_file):
...@@ -104,10 +127,12 @@ class SequentialExtractor(SequentialProcessor, MultipleExtractor): ...@@ -104,10 +127,12 @@ class SequentialExtractor(SequentialProcessor, MultipleExtractor):
def train(self, training_data, extractor_file):
    """Trains all extractors of the chain sequentially.

    Each trained extractor is saved into its own group of the given hdf5
    file. After an extractor is trained it is applied on the training data
    (``apply=True``) so that the *next* extractor trains on transformed
    data; the last extractor's output is not needed, hence ``apply`` is
    ``False`` for it.

    Parameters
    ----------
    training_data : [object] or [[object]]
        The data used to train the extractors.
    extractor_file : str
        Path of the hdf5 file the trained extractors are written to.
    """
    with HDF5File(extractor_file, 'w') as f:
        groups = self.get_extractor_groups()
        for i, (e, group) in enumerate(zip(self.processors, groups)):
            # apply the extractor on the data for every processor except
            # the last one -- its output would be discarded anyway.
            apply = i != len(self.processors) - 1
            f.create_group(group)
            f.cd(group)
            training_data = self.train_one(e, training_data, f,
                                           apply=apply)
            f.cd('..')
def read_feature(self, feature_file): def read_feature(self, feature_file):
......
import numpy import numpy
import bob.io.base import bob.bio.base
from bob.bio.base.extractor import Extractor from bob.bio.base.extractor import Extractor
...@@ -12,10 +12,10 @@ class DummyExtractor (Extractor): ...@@ -12,10 +12,10 @@ class DummyExtractor (Extractor):
def train(self, train_data, extractor_file):
    """Dummy training: checks the data is a list and saves the fixed
    module-level ``_data`` array into ``extractor_file``."""
    assert isinstance(train_data, list)
    bob.bio.base.save(_data, extractor_file)
def load(self, extractor_file):
    """Loads the saved array back and verifies it equals the fixed
    module-level ``_data``; marks the extractor as loaded."""
    data = bob.bio.base.load(extractor_file)
    assert (_data == data).all()
    # flag checked by the dummy tests to verify load() actually ran
    self.model = True
......
from functools import partial from functools import partial
import numpy as np import numpy as np
import tempfile
from bob.bio.base.utils.processors import ( from bob.bio.base.utils.processors import (
SequentialProcessor, ParallelProcessor) SequentialProcessor, ParallelProcessor)
from bob.bio.base.preprocessor import ( from bob.bio.base.preprocessor import (
SequentialPreprocessor, ParallelPreprocessor, CallablePreprocessor) SequentialPreprocessor, ParallelPreprocessor, CallablePreprocessor)
from bob.bio.base.extractor import ( from bob.bio.base.extractor import (
SequentialExtractor, ParallelExtractor, CallableExtractor) SequentialExtractor, ParallelExtractor, CallableExtractor)
from bob.bio.base.test.dummy.extractor import extractor as dummy_extractor
# Shared fixtures for the processor/extractor tests below.
DATA = [0, 1, 2, 3, 4]
# partial(np.power, 2) computes 2 ** x (base is bound, data is the exponent)
PROCESSORS = [partial(np.power, 2), np.mean]
...@@ -43,3 +45,23 @@ def test_extractors(): ...@@ -43,3 +45,23 @@ def test_extractors():
proc = ParallelExtractor(processors) proc = ParallelExtractor(processors)
data = proc(DATA) data = proc(DATA)
assert all(np.allclose(x1, x2) for x1, x2 in zip(data, PAR_DATA)) assert all(np.allclose(x1, x2) for x1, x2 in zip(data, PAR_DATA))
def test_sequential_trainable_extractors():
    """A SequentialExtractor chain ending in a trainable dummy extractor
    can be trained, saved to hdf5, reloaded, and applied."""
    processors = [CallableExtractor(p) for p in PROCESSORS] + [dummy_extractor]
    proc = SequentialExtractor(processors)
    with tempfile.NamedTemporaryFile(suffix='.hdf5') as f:
        proc.train(DATA, f.name)
        proc.load(f.name)
    data = proc(DATA)
    assert np.allclose(data, SEQ_DATA)
def test_parallel_trainable_extractors():
    """A ParallelExtractor including a trainable dummy extractor can be
    trained, saved to hdf5, reloaded, and applied."""
    processors = [CallableExtractor(p) for p in PROCESSORS] + [dummy_extractor]
    proc = ParallelExtractor(processors)
    with tempfile.NamedTemporaryFile(suffix='.hdf5') as f:
        proc.train(DATA, f.name)
        proc.load(f.name)
    data = proc(np.array(DATA))
    assert all(np.allclose(x1, x2) for x1, x2 in zip(data, PAR_DATA))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment