Skip to content
Snippets Groups Projects
Commit 31bdc2e9 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

implement tests for parallel extractor training too

parent ac1c9afc
No related branches found
No related tags found
1 merge request!112SequentialExtractor: Apply extractor on training data always when apply=True
Pipeline #
...@@ -47,7 +47,7 @@ def test_extractors(): ...@@ -47,7 +47,7 @@ def test_extractors():
assert all(np.allclose(x1, x2) for x1, x2 in zip(data, PAR_DATA)) assert all(np.allclose(x1, x2) for x1, x2 in zip(data, PAR_DATA))
def test_trainable_extractors(): def test_sequential_trainable_extractors():
processors = [CallableExtractor(p) for p in PROCESSORS] + [dummy_extractor] processors = [CallableExtractor(p) for p in PROCESSORS] + [dummy_extractor]
proc = SequentialExtractor(processors) proc = SequentialExtractor(processors)
with tempfile.NamedTemporaryFile(suffix='.hdf5') as f: with tempfile.NamedTemporaryFile(suffix='.hdf5') as f:
...@@ -55,3 +55,13 @@ def test_trainable_extractors(): ...@@ -55,3 +55,13 @@ def test_trainable_extractors():
proc.load(f.name) proc.load(f.name)
data = proc(DATA) data = proc(DATA)
assert np.allclose(data, SEQ_DATA) assert np.allclose(data, SEQ_DATA)
def test_parallel_trainable_extractors():
processors = [CallableExtractor(p) for p in PROCESSORS] + [dummy_extractor]
proc = ParallelExtractor(processors)
with tempfile.NamedTemporaryFile(suffix='.hdf5') as f:
proc.train(DATA, f.name)
proc.load(f.name)
data = proc(np.array(DATA))
assert all(np.allclose(x1, x2) for x1, x2 in zip(data, PAR_DATA))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment