Commit ac1c9afc authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

implement some tests too

parent 278917a4
Pipeline #13793 canceled with stages
in 8 minutes and 35 seconds
......@@ -33,7 +33,7 @@ class MultipleExtractor(Extractor):
training_data = [e(d) for d in training_data]
# if any of the extractors require splitting the data, the
# split_training_data_by_client is True.
if e.split_training_data_by_client:
elif e.split_training_data_by_client:
e.train(training_data, extractor_file)
if not apply:
return
......@@ -62,7 +62,8 @@ class MultipleExtractor(Extractor):
groups = self.get_extractor_groups()
for e, group in zip(self.processors, groups):
f.cd(group)
e.load(f)
if e.requires_training:
e.load(f)
f.cd('..')
......@@ -110,10 +111,15 @@ class SequentialExtractor(SequentialProcessor, MultipleExtractor):
def train(self, training_data, extractor_file):
with HDF5File(extractor_file, 'w') as f:
groups = self.get_extractor_groups()
for e, group in zip(self.processors, groups):
for i, (e, group) in enumerate(zip(self.processors, groups)):
if i == len(self.processors) - 1:
apply = False
else:
apply = True
f.create_group(group)
f.cd(group)
training_data = self.train_one(e, training_data, f, apply=True)
training_data = self.train_one(e, training_data, f,
apply=apply)
f.cd('..')
def read_feature(self, feature_file):
......
import numpy
import bob.io.base
import bob.bio.base
from bob.bio.base.extractor import Extractor
......@@ -12,10 +12,10 @@ class DummyExtractor (Extractor):
def train(self, train_data, extractor_file):
assert isinstance(train_data, list)
bob.io.base.save(_data, extractor_file)
bob.bio.base.save(_data, extractor_file)
def load(self, extractor_file):
data = bob.io.base.load(extractor_file)
data = bob.bio.base.load(extractor_file)
assert (_data == data).all()
self.model = True
......
from functools import partial
import numpy as np
import tempfile
from bob.bio.base.utils.processors import (
SequentialProcessor, ParallelProcessor)
from bob.bio.base.preprocessor import (
SequentialPreprocessor, ParallelPreprocessor, CallablePreprocessor)
from bob.bio.base.extractor import (
SequentialExtractor, ParallelExtractor, CallableExtractor)
from bob.bio.base.test.dummy.extractor import extractor as dummy_extractor
DATA = [0, 1, 2, 3, 4]
PROCESSORS = [partial(np.power, 2), np.mean]
......@@ -43,3 +45,13 @@ def test_extractors():
proc = ParallelExtractor(processors)
data = proc(DATA)
assert all(np.allclose(x1, x2) for x1, x2 in zip(data, PAR_DATA))
def test_trainable_extractors():
processors = [CallableExtractor(p) for p in PROCESSORS] + [dummy_extractor]
proc = SequentialExtractor(processors)
with tempfile.NamedTemporaryFile(suffix='.hdf5') as f:
proc.train(DATA, f.name)
proc.load(f.name)
data = proc(DATA)
assert np.allclose(data, SEQ_DATA)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment