diff --git a/bob/bio/base/extractor/stacks.py b/bob/bio/base/extractor/stacks.py index 0021d4ad4b78d108a3cbe437d5bcec8e7b1523e6..780dafdb47b211ad0b8aa70af2b0c559b9f636e1 100644 --- a/bob/bio/base/extractor/stacks.py +++ b/bob/bio/base/extractor/stacks.py @@ -75,6 +75,8 @@ class MultipleExtractor(Extractor): return training_data def load(self, extractor_file): + if not self.requires_training: + return with HDF5File(extractor_file) as f: groups = self.get_extractor_groups() for e, group in zip(self.processors, groups): @@ -111,7 +113,7 @@ class SequentialExtractor(SequentialProcessor, MultipleExtractor): True """ - def __init__(self, processors): + def __init__(self, processors, **kwargs): (requires_training, split_training_data_by_client, min_extractor_file_size, min_feature_file_size) = \ @@ -122,7 +124,8 @@ class SequentialExtractor(SequentialProcessor, MultipleExtractor): requires_training=requires_training, split_training_data_by_client=split_training_data_by_client, min_extractor_file_size=min_extractor_file_size, - min_feature_file_size=min_feature_file_size) + min_feature_file_size=min_feature_file_size, + **kwargs) def train(self, training_data, extractor_file): with HDF5File(extractor_file, 'w') as f: @@ -179,7 +182,7 @@ class ParallelExtractor(ParallelProcessor, MultipleExtractor): [ 1. , 2. , 3. , 0.5, 1. , 1.5]]) """ - def __init__(self, processors): + def __init__(self, processors, **kwargs): (requires_training, split_training_data_by_client, min_extractor_file_size, min_feature_file_size) = self.get_attributes( @@ -190,7 +193,8 @@ class ParallelExtractor(ParallelProcessor, MultipleExtractor): requires_training=requires_training, split_training_data_by_client=split_training_data_by_client, min_extractor_file_size=min_extractor_file_size, - min_feature_file_size=min_feature_file_size) + min_feature_file_size=min_feature_file_size, + **kwargs) def train(self, training_data, extractor_file): with HDF5File(extractor_file, 'w') as f: diff --git a/bob/bio/base/test/test_stacks.py b/bob/bio/base/test/test_stacks.py index cd6e0f534fe2eca8c236685a44a2c458236c5185..a296a9a33ebed07b3176d5955c84a7e254f9ca71 100644 --- a/bob/bio/base/test/test_stacks.py +++ b/bob/bio/base/test/test_stacks.py @@ -39,10 +39,12 @@ def test_preprocessors(): def test_extractors(): processors = [CallableExtractor(p) for p in PROCESSORS] proc = SequentialExtractor(processors) + proc.load(None) data = proc(DATA) assert np.allclose(data, SEQ_DATA) proc = ParallelExtractor(processors) + proc.load(None) data = proc(DATA) assert all(np.allclose(x1, x2) for x1, x2 in zip(data, PAR_DATA)) diff --git a/bob/bio/base/utils/processors.py b/bob/bio/base/utils/processors.py index ca04d37cf3c2b34df5b188a49c9dd0e9a5f2908e..b01a953dc34e5962bb74e03801d60f9227ed2838 100644 --- a/bob/bio/base/utils/processors.py +++ b/bob/bio/base/utils/processors.py @@ -26,7 +26,7 @@ class SequentialProcessor(object): """ def __init__(self, processors, **kwargs): - super(SequentialProcessor, self).__init__() + super(SequentialProcessor, self).__init__(**kwargs) self.processors = processors def __call__(self, data, **kwargs): @@ -86,7 +86,7 @@ class ParallelProcessor(object): """ def __init__(self, processors, **kwargs): - super(ParallelProcessor, self).__init__() + super(ParallelProcessor, self).__init__(**kwargs) self.processors = processors def __call__(self, data, **kwargs):