Skip to content
Snippets Groups Projects
Commit ac1c9afc authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

implement some tests too

parent 278917a4
No related branches found
No related tags found
1 merge request!112SequentialExtractor: Apply extractor on training data always when apply=True
Pipeline #
......@@ -33,7 +33,7 @@ class MultipleExtractor(Extractor):
training_data = [e(d) for d in training_data]
# if any of the extractors require splitting the data, the
# split_training_data_by_client is True.
if e.split_training_data_by_client:
elif e.split_training_data_by_client:
e.train(training_data, extractor_file)
if not apply:
return
......@@ -62,7 +62,8 @@ class MultipleExtractor(Extractor):
groups = self.get_extractor_groups()
for e, group in zip(self.processors, groups):
f.cd(group)
e.load(f)
if e.requires_training:
e.load(f)
f.cd('..')
......@@ -110,10 +111,15 @@ class SequentialExtractor(SequentialProcessor, MultipleExtractor):
def train(self, training_data, extractor_file):
with HDF5File(extractor_file, 'w') as f:
groups = self.get_extractor_groups()
for e, group in zip(self.processors, groups):
for i, (e, group) in enumerate(zip(self.processors, groups)):
if i == len(self.processors) - 1:
apply = False
else:
apply = True
f.create_group(group)
f.cd(group)
training_data = self.train_one(e, training_data, f, apply=True)
training_data = self.train_one(e, training_data, f,
apply=apply)
f.cd('..')
def read_feature(self, feature_file):
......
import numpy
import bob.io.base
import bob.bio.base
from bob.bio.base.extractor import Extractor
......@@ -12,10 +12,10 @@ class DummyExtractor (Extractor):
def train(self, train_data, extractor_file):
assert isinstance(train_data, list)
bob.io.base.save(_data, extractor_file)
bob.bio.base.save(_data, extractor_file)
def load(self, extractor_file):
data = bob.io.base.load(extractor_file)
data = bob.bio.base.load(extractor_file)
assert (_data == data).all()
self.model = True
......
from functools import partial
import numpy as np
import tempfile
from bob.bio.base.utils.processors import (
SequentialProcessor, ParallelProcessor)
from bob.bio.base.preprocessor import (
SequentialPreprocessor, ParallelPreprocessor, CallablePreprocessor)
from bob.bio.base.extractor import (
SequentialExtractor, ParallelExtractor, CallableExtractor)
from bob.bio.base.test.dummy.extractor import extractor as dummy_extractor
DATA = [0, 1, 2, 3, 4]
PROCESSORS = [partial(np.power, 2), np.mean]
......@@ -43,3 +45,13 @@ def test_extractors():
proc = ParallelExtractor(processors)
data = proc(DATA)
assert all(np.allclose(x1, x2) for x1, x2 in zip(data, PAR_DATA))
def test_trainable_extractors():
processors = [CallableExtractor(p) for p in PROCESSORS] + [dummy_extractor]
proc = SequentialExtractor(processors)
with tempfile.NamedTemporaryFile(suffix='.hdf5') as f:
proc.train(DATA, f.name)
proc.load(f.name)
data = proc(DATA)
assert np.allclose(data, SEQ_DATA)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment