Skip to content
Snippets Groups Projects
Commit 7c2bc523 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Implement hdf5 io for extractors

Remove try and catches
Accept kwargs in init signature of processors
parent b8efca23
No related branches found
No related tags found
1 merge request: !102 "Add sequential and parallel processors, pre-processors, and extractors"
Pipeline #
from ..utils.processors import SequentialProcessor, ParallelProcessor
from .Extractor import Extractor
from bob.io.base import HDF5File
class MultipleExtractor(Extractor):
......@@ -8,34 +9,35 @@ class MultipleExtractor(Extractor):
def get_attributes(self, processors):
    """Aggregate the training-related attributes of several extractors.

    Parameters
    ----------
    processors : list
        A list of extractor objects; each must expose ``requires_training``,
        ``split_training_data_by_client``, ``min_extractor_file_size`` and
        ``min_feature_file_size`` attributes.

    Returns
    -------
    tuple
        ``(requires_training, split_training_data_by_client,
        min_extractor_file_size, min_feature_file_size)`` where the booleans
        are True if *any* processor requires them and the sizes are the
        *minimum* over all processors.
    """
    # Training is needed (and per-client splitting honored) as soon as a
    # single wrapped extractor asks for it.
    requires_training = any(p.requires_training for p in processors)
    split_training_data_by_client = any(
        p.split_training_data_by_client for p in processors)
    # Use the smallest size thresholds so no wrapped extractor's files are
    # rejected as too small.
    min_extractor_file_size = min(
        p.min_extractor_file_size for p in processors)
    min_feature_file_size = min(
        p.min_feature_file_size for p in processors)
    return (requires_training, split_training_data_by_client,
            min_extractor_file_size, min_feature_file_size)
def get_extractor_files(self, extractor_file):
    """Return one storage path per wrapped extractor.

    The first extractor uses ``extractor_file`` unchanged; every further
    extractor gets a derived ``<extractor_file>_<i>.hdf5`` path.

    Parameters
    ----------
    extractor_file : str
        The base extractor file path.

    Returns
    -------
    list of str
        ``len(self.processors)`` paths, in processor order.
    """
    derived = ['{}_{}.hdf5'.format(extractor_file, index)
               for index in range(1, len(self.processors))]
    return [extractor_file] + derived
def get_extractor_groups(self):
    """Return one HDF5 group name per wrapped extractor: E_1, E_2, ...

    Returns
    -------
    list of str
        ``len(self.processors)`` group names, 1-based, in processor order.
    """
    names = []
    for number, _ in enumerate(self.processors, start=1):
        names.append('E_{}'.format(number))
    return names
# NOTE(review): diff-view fragment — the scrape stripped all indentation and
# the trailing `else` branch continues past this hunk, so this block is shown
# incomplete. Only comments are added here; code is untouched.
# Trains a single wrapped extractor `e` and, when `apply` is True, returns
# `training_data` transformed by `e` so the next extractor in the chain can
# train on it.
def train_one(self, e, training_data, extractor_file, apply=False):
# Extractors that need no training are skipped entirely.
if not e.requires_training:
return
# if any of the extractors require splitting the data, the
# split_training_data_by_client is True.
if e.split_training_data_by_client:
# `training_data` is a list of per-client lists here; train, then (if
# requested) map `e` over every sample while keeping the nesting.
e.train(training_data, extractor_file)
if not apply:
return
training_data = [[e(d) for d in datalist] for datalist in training_data]
# when no extractor needs splitting
elif not self.split_training_data_by_client:
# Flat data in, flat data out.
e.train(training_data, extractor_file)
if not apply:
return
training_data = [e(d) for d in training_data]
# when e here wants it flat but the data is split
else:
# make training_data flat
# Remember each client's list length so the split structure can be
# restored after training on the flattened data (continues below).
training_data_len = [len(datalist) for datalist in training_data]
......@@ -55,9 +57,12 @@ class MultipleExtractor(Extractor):
return training_data
def load(self, extractor_file):
    """Load all wrapped extractors from a single HDF5 file.

    Each extractor reads its state from its own group (``E_1``, ``E_2``,
    ...) inside ``extractor_file``, instead of from one file per extractor.

    Parameters
    ----------
    extractor_file : str
        Path to the combined HDF5 file written by :meth:`train`.
    """
    with HDF5File(extractor_file) as f:
        groups = self.get_extractor_groups()
        for e, group in zip(self.processors, groups):
            f.cd(group)
            e.load(f)
            # return to the file root before visiting the next group
            f.cd('..')
class SequentialExtractor(SequentialProcessor, MultipleExtractor):
......@@ -66,21 +71,24 @@ class SequentialExtractor(SequentialProcessor, MultipleExtractor):
def __init__(self, processors):
    """Build a sequential extractor from a list of extractors.

    Aggregates the training attributes of all ``processors`` (see
    ``MultipleExtractor.get_attributes``) and initializes the whole
    cooperative base-class chain with a single ``super()`` call.

    Parameters
    ----------
    processors : list
        The extractors to apply in sequence.
    """
    (requires_training, split_training_data_by_client,
     min_extractor_file_size, min_feature_file_size) = \
        self.get_attributes(processors)

    # One cooperative super() call initializes SequentialProcessor,
    # MultipleExtractor and Extractor along the MRO.
    super(SequentialExtractor, self).__init__(
        processors=processors,
        requires_training=requires_training,
        split_training_data_by_client=split_training_data_by_client,
        min_extractor_file_size=min_extractor_file_size,
        min_feature_file_size=min_feature_file_size)
def train(self, training_data, extractor_file):
    """Train all wrapped extractors, writing them into one HDF5 file.

    Each extractor gets its own group (``E_1``, ``E_2``, ...) inside
    ``extractor_file``. Because the extractors run in sequence,
    ``apply=True`` transforms the training data with each trained
    extractor before passing it to the next one.

    Parameters
    ----------
    training_data : list
        The training samples (possibly split by client).
    extractor_file : str
        Path of the HDF5 file to create.
    """
    with HDF5File(extractor_file, 'w') as f:
        groups = self.get_extractor_groups()
        for e, group in zip(self.processors, groups):
            f.create_group(group)
            f.cd(group)
            # apply=True: feed this extractor's output to the next one
            training_data = self.train_one(e, training_data, f, apply=True)
            f.cd('..')
class ParallelExtractor(ParallelProcessor, MultipleExtractor):
......@@ -92,18 +100,21 @@ class ParallelExtractor(ParallelProcessor, MultipleExtractor):
min_extractor_file_size, min_feature_file_size) = self.get_attributes(
processors)
ParallelProcessor.__init__(self, processors)
MultipleExtractor.__init__(
self,
super(ParallelExtractor, self).__init__(
processors=processors,
requires_training=requires_training,
split_training_data_by_client=split_training_data_by_client,
min_extractor_file_size=min_extractor_file_size,
min_feature_file_size=min_feature_file_size)
def train(self, training_data, extractor_file):
    """Train all wrapped extractors, writing them into one HDF5 file.

    Each extractor gets its own group (``E_1``, ``E_2``, ...) inside
    ``extractor_file``. Because the extractors run in parallel, every
    one is trained on the *original* training data (``apply=False``).

    Parameters
    ----------
    training_data : list
        The training samples (possibly split by client).
    extractor_file : str
        Path of the HDF5 file to create.
    """
    with HDF5File(extractor_file, 'w') as f:
        groups = self.get_extractor_groups()
        for e, group in zip(self.processors, groups):
            f.create_group(group)
            f.cd(group)
            # apply=False: parallel extractors all see the same input data
            self.train_one(e, training_data, f, apply=False)
            f.cd('..')
class CallableExtractor(Extractor):
......
......@@ -26,12 +26,8 @@ class ParallelPreprocessor(ParallelProcessor, Preprocessor):
__doc__ = ParallelProcessor.__doc__
def __init__(self, processors, **kwargs):
min_preprocessed_file_size = 1000
try:
min_preprocessed_file_size = min(
(p.min_preprocessed_file_size for p in processors))
except AttributeError:
pass
min_preprocessed_file_size = min((p.min_preprocessed_file_size for p in
processors))
ParallelProcessor.__init__(self, processors)
Preprocessor.__init__(
......@@ -48,14 +44,21 @@ class CallablePreprocessor(Preprocessor):
Attributes
----------
accepts_annotations : bool
If False, annotations are not passed to the callable.
callable : object
Anything that is callable. It will be used as a preprocessor in
bob.bio.base.
"""
def __init__(self, callable, accepts_annotations=True, **kwargs):
    """Wrap a plain callable as a preprocessor.

    Parameters
    ----------
    callable : object
        Anything that is callable; used as the preprocessing function.
        (The name shadows the builtin but is kept for API compatibility.)
    accepts_annotations : bool
        If False, annotations are not passed to the callable.
    **kwargs
        Forwarded to the base class.
    """
    # Forward the arguments to the base class as well so they end up in
    # the processor's recorded configuration.
    super(CallablePreprocessor, self).__init__(
        callable=callable, accepts_annotations=accepts_annotations,
        **kwargs)
    self.callable = callable
    self.accepts_annotations = accepts_annotations
def __call__(self, data, annotations):
    """Apply the wrapped callable to ``data``.

    Parameters
    ----------
    data : object
        The data to preprocess.
    annotations : object
        Passed on to the callable only when ``accepts_annotations`` is
        True; otherwise ignored.

    Returns
    -------
    object
        Whatever the wrapped callable returns.
    """
    if self.accepts_annotations:
        return self.callable(data, annotations)
    return self.callable(data)
......@@ -24,7 +24,7 @@ def test_processors():
def test_preprocessors():
    """Sequentially chained CallablePreprocessors reproduce SEQ_DATA."""
    # accepts_annotations=False: the raw callables in PROCESSORS take only
    # the data argument, so annotations must not be forwarded to them.
    processors = [CallablePreprocessor(p, False) for p in PROCESSORS]
    proc = SequentialPreprocessor(processors)
    data = proc(DATA, None)
    assert np.allclose(data, SEQ_DATA)
......
......@@ -11,7 +11,7 @@ class SequentialProcessor(object):
A list of processors to apply.
"""
def __init__(self, processors, **kwargs):
    """Initialize the sequential processor.

    Parameters
    ----------
    processors : list
        A list of processors to apply, in order.
    **kwargs
        Forwarded up the cooperative ``super()`` chain; subclasses such as
        the sequential extractor pass extractor attributes (e.g.
        ``requires_training``) through this class to the next base class.
    """
    # Forward kwargs instead of swallowing them: a single super() call in
    # the subclasses must reach every class in the MRO.
    super(SequentialProcessor, self).__init__(**kwargs)
    self.processors = processors
......@@ -32,10 +32,7 @@ class SequentialProcessor(object):
The processed data.
"""
for processor in self.processors:
try:
data = processor(data, **kwargs)
except ValueError:
data = processor(data)
data = processor(data, **kwargs)
return data
......@@ -51,7 +48,7 @@ class ParallelProcessor(object):
If True (default), :any:`numpy.hstack` is called on the list of outputs.
"""
def __init__(self, processors, stack=True, **kwargs):
    """Initialize the parallel processor.

    Parameters
    ----------
    processors : list
        A list of processors to apply to the same input.
    stack : bool
        If True (default), :any:`numpy.hstack` is called on the list of
        outputs.
    **kwargs
        Forwarded up the cooperative ``super()`` chain; subclasses such as
        the parallel extractor pass extractor attributes through this
        class to the next base class.
    """
    # Forward kwargs instead of swallowing them: a single super() call in
    # the subclasses must reach every class in the MRO.
    super(ParallelProcessor, self).__init__(**kwargs)
    self.processors = processors
    self.stack = stack
......@@ -74,10 +71,7 @@ class ParallelProcessor(object):
"""
output = []
for processor in self.processors:
try:
out = processor(data, **kwargs)
except ValueError:
out = processor(data)
out = processor(data, **kwargs)
output.append(out)
if self.stack:
output = numpy.hstack(output)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment