Commit c851d096 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Stacks of extractors: Fix when no extractor requires training

parent 24f53216
Pipeline #14019 passed with stages
in 10 minutes and 43 seconds
......@@ -75,6 +75,8 @@ class MultipleExtractor(Extractor):
return training_data
def load(self, extractor_file):
if not self.requires_training:
return
with HDF5File(extractor_file) as f:
groups = self.get_extractor_groups()
for e, group in zip(self.processors, groups):
......@@ -111,7 +113,7 @@ class SequentialExtractor(SequentialProcessor, MultipleExtractor):
True
"""
def __init__(self, processors):
def __init__(self, processors, **kwargs):
(requires_training, split_training_data_by_client,
min_extractor_file_size, min_feature_file_size) = \
......@@ -122,7 +124,8 @@ class SequentialExtractor(SequentialProcessor, MultipleExtractor):
requires_training=requires_training,
split_training_data_by_client=split_training_data_by_client,
min_extractor_file_size=min_extractor_file_size,
min_feature_file_size=min_feature_file_size)
min_feature_file_size=min_feature_file_size,
**kwargs)
def train(self, training_data, extractor_file):
with HDF5File(extractor_file, 'w') as f:
......@@ -179,7 +182,7 @@ class ParallelExtractor(ParallelProcessor, MultipleExtractor):
[ 1. , 2. , 3. , 0.5, 1. , 1.5]])
"""
def __init__(self, processors):
def __init__(self, processors, **kwargs):
(requires_training, split_training_data_by_client,
min_extractor_file_size, min_feature_file_size) = self.get_attributes(
......@@ -190,7 +193,8 @@ class ParallelExtractor(ParallelProcessor, MultipleExtractor):
requires_training=requires_training,
split_training_data_by_client=split_training_data_by_client,
min_extractor_file_size=min_extractor_file_size,
min_feature_file_size=min_feature_file_size)
min_feature_file_size=min_feature_file_size,
**kwargs)
def train(self, training_data, extractor_file):
with HDF5File(extractor_file, 'w') as f:
......
......@@ -39,10 +39,12 @@ def test_preprocessors():
def test_extractors():
processors = [CallableExtractor(p) for p in PROCESSORS]
proc = SequentialExtractor(processors)
proc.load(None)
data = proc(DATA)
assert np.allclose(data, SEQ_DATA)
proc = ParallelExtractor(processors)
proc.load(None)
data = proc(DATA)
assert all(np.allclose(x1, x2) for x1, x2 in zip(data, PAR_DATA))
......
......@@ -26,7 +26,7 @@ class SequentialProcessor(object):
"""
def __init__(self, processors, **kwargs):
super(SequentialProcessor, self).__init__()
super(SequentialProcessor, self).__init__(**kwargs)
self.processors = processors
def __call__(self, data, **kwargs):
......@@ -86,7 +86,7 @@ class ParallelProcessor(object):
"""
def __init__(self, processors, **kwargs):
super(ParallelProcessor, self).__init__()
super(ParallelProcessor, self).__init__(**kwargs)
self.processors = processors
def __call__(self, data, **kwargs):
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment