Commit c851d096 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Stacks of extractors: Fix when no extractor requires training

parent 24f53216
Pipeline #14019 passed with stages
in 10 minutes and 43 seconds
...@@ -75,6 +75,8 @@ class MultipleExtractor(Extractor): ...@@ -75,6 +75,8 @@ class MultipleExtractor(Extractor):
return training_data return training_data
def load(self, extractor_file): def load(self, extractor_file):
if not self.requires_training:
return
with HDF5File(extractor_file) as f: with HDF5File(extractor_file) as f:
groups = self.get_extractor_groups() groups = self.get_extractor_groups()
for e, group in zip(self.processors, groups): for e, group in zip(self.processors, groups):
...@@ -111,7 +113,7 @@ class SequentialExtractor(SequentialProcessor, MultipleExtractor): ...@@ -111,7 +113,7 @@ class SequentialExtractor(SequentialProcessor, MultipleExtractor):
True True
""" """
def __init__(self, processors): def __init__(self, processors, **kwargs):
(requires_training, split_training_data_by_client, (requires_training, split_training_data_by_client,
min_extractor_file_size, min_feature_file_size) = \ min_extractor_file_size, min_feature_file_size) = \
...@@ -122,7 +124,8 @@ class SequentialExtractor(SequentialProcessor, MultipleExtractor): ...@@ -122,7 +124,8 @@ class SequentialExtractor(SequentialProcessor, MultipleExtractor):
requires_training=requires_training, requires_training=requires_training,
split_training_data_by_client=split_training_data_by_client, split_training_data_by_client=split_training_data_by_client,
min_extractor_file_size=min_extractor_file_size, min_extractor_file_size=min_extractor_file_size,
min_feature_file_size=min_feature_file_size) min_feature_file_size=min_feature_file_size,
**kwargs)
def train(self, training_data, extractor_file): def train(self, training_data, extractor_file):
with HDF5File(extractor_file, 'w') as f: with HDF5File(extractor_file, 'w') as f:
...@@ -179,7 +182,7 @@ class ParallelExtractor(ParallelProcessor, MultipleExtractor): ...@@ -179,7 +182,7 @@ class ParallelExtractor(ParallelProcessor, MultipleExtractor):
[ 1. , 2. , 3. , 0.5, 1. , 1.5]]) [ 1. , 2. , 3. , 0.5, 1. , 1.5]])
""" """
def __init__(self, processors): def __init__(self, processors, **kwargs):
(requires_training, split_training_data_by_client, (requires_training, split_training_data_by_client,
min_extractor_file_size, min_feature_file_size) = self.get_attributes( min_extractor_file_size, min_feature_file_size) = self.get_attributes(
...@@ -190,7 +193,8 @@ class ParallelExtractor(ParallelProcessor, MultipleExtractor): ...@@ -190,7 +193,8 @@ class ParallelExtractor(ParallelProcessor, MultipleExtractor):
requires_training=requires_training, requires_training=requires_training,
split_training_data_by_client=split_training_data_by_client, split_training_data_by_client=split_training_data_by_client,
min_extractor_file_size=min_extractor_file_size, min_extractor_file_size=min_extractor_file_size,
min_feature_file_size=min_feature_file_size) min_feature_file_size=min_feature_file_size,
**kwargs)
def train(self, training_data, extractor_file): def train(self, training_data, extractor_file):
with HDF5File(extractor_file, 'w') as f: with HDF5File(extractor_file, 'w') as f:
......
...@@ -39,10 +39,12 @@ def test_preprocessors(): ...@@ -39,10 +39,12 @@ def test_preprocessors():
def test_extractors(): def test_extractors():
processors = [CallableExtractor(p) for p in PROCESSORS] processors = [CallableExtractor(p) for p in PROCESSORS]
proc = SequentialExtractor(processors) proc = SequentialExtractor(processors)
proc.load(None)
data = proc(DATA) data = proc(DATA)
assert np.allclose(data, SEQ_DATA) assert np.allclose(data, SEQ_DATA)
proc = ParallelExtractor(processors) proc = ParallelExtractor(processors)
proc.load(None)
data = proc(DATA) data = proc(DATA)
assert all(np.allclose(x1, x2) for x1, x2 in zip(data, PAR_DATA)) assert all(np.allclose(x1, x2) for x1, x2 in zip(data, PAR_DATA))
......
...@@ -26,7 +26,7 @@ class SequentialProcessor(object): ...@@ -26,7 +26,7 @@ class SequentialProcessor(object):
""" """
def __init__(self, processors, **kwargs): def __init__(self, processors, **kwargs):
super(SequentialProcessor, self).__init__() super(SequentialProcessor, self).__init__(**kwargs)
self.processors = processors self.processors = processors
def __call__(self, data, **kwargs): def __call__(self, data, **kwargs):
...@@ -86,7 +86,7 @@ class ParallelProcessor(object): ...@@ -86,7 +86,7 @@ class ParallelProcessor(object):
""" """
def __init__(self, processors, **kwargs): def __init__(self, processors, **kwargs):
super(ParallelProcessor, self).__init__() super(ParallelProcessor, self).__init__(**kwargs)
self.processors = processors self.processors = processors
def __call__(self, data, **kwargs): def __call__(self, data, **kwargs):
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment