Skip to content
Snippets Groups Projects
Commit ac1c9afc authored by Amir MOHAMMADI
Browse files

implement some tests too

parent 278917a4
No related branches found
No related tags found
1 merge request: !112 SequentialExtractor: Apply extractor on training data always when apply=True
Pipeline #
...@@ -33,7 +33,7 @@ class MultipleExtractor(Extractor): ...@@ -33,7 +33,7 @@ class MultipleExtractor(Extractor):
training_data = [e(d) for d in training_data] training_data = [e(d) for d in training_data]
# if any of the extractors require splitting the data, the # if any of the extractors require splitting the data, the
# split_training_data_by_client is True. # split_training_data_by_client is True.
if e.split_training_data_by_client: elif e.split_training_data_by_client:
e.train(training_data, extractor_file) e.train(training_data, extractor_file)
if not apply: if not apply:
return return
...@@ -62,7 +62,8 @@ class MultipleExtractor(Extractor): ...@@ -62,7 +62,8 @@ class MultipleExtractor(Extractor):
groups = self.get_extractor_groups() groups = self.get_extractor_groups()
for e, group in zip(self.processors, groups): for e, group in zip(self.processors, groups):
f.cd(group) f.cd(group)
e.load(f) if e.requires_training:
e.load(f)
f.cd('..') f.cd('..')
...@@ -110,10 +111,15 @@ class SequentialExtractor(SequentialProcessor, MultipleExtractor): ...@@ -110,10 +111,15 @@ class SequentialExtractor(SequentialProcessor, MultipleExtractor):
def train(self, training_data, extractor_file): def train(self, training_data, extractor_file):
with HDF5File(extractor_file, 'w') as f: with HDF5File(extractor_file, 'w') as f:
groups = self.get_extractor_groups() groups = self.get_extractor_groups()
for e, group in zip(self.processors, groups): for i, (e, group) in enumerate(zip(self.processors, groups)):
if i == len(self.processors) - 1:
apply = False
else:
apply = True
f.create_group(group) f.create_group(group)
f.cd(group) f.cd(group)
training_data = self.train_one(e, training_data, f, apply=True) training_data = self.train_one(e, training_data, f,
apply=apply)
f.cd('..') f.cd('..')
def read_feature(self, feature_file): def read_feature(self, feature_file):
......
import numpy import numpy
import bob.io.base import bob.bio.base
from bob.bio.base.extractor import Extractor from bob.bio.base.extractor import Extractor
...@@ -12,10 +12,10 @@ class DummyExtractor (Extractor): ...@@ -12,10 +12,10 @@ class DummyExtractor (Extractor):
def train(self, train_data, extractor_file): def train(self, train_data, extractor_file):
assert isinstance(train_data, list) assert isinstance(train_data, list)
bob.io.base.save(_data, extractor_file) bob.bio.base.save(_data, extractor_file)
def load(self, extractor_file): def load(self, extractor_file):
data = bob.io.base.load(extractor_file) data = bob.bio.base.load(extractor_file)
assert (_data == data).all() assert (_data == data).all()
self.model = True self.model = True
......
from functools import partial from functools import partial
import numpy as np import numpy as np
import tempfile
from bob.bio.base.utils.processors import ( from bob.bio.base.utils.processors import (
SequentialProcessor, ParallelProcessor) SequentialProcessor, ParallelProcessor)
from bob.bio.base.preprocessor import ( from bob.bio.base.preprocessor import (
SequentialPreprocessor, ParallelPreprocessor, CallablePreprocessor) SequentialPreprocessor, ParallelPreprocessor, CallablePreprocessor)
from bob.bio.base.extractor import ( from bob.bio.base.extractor import (
SequentialExtractor, ParallelExtractor, CallableExtractor) SequentialExtractor, ParallelExtractor, CallableExtractor)
from bob.bio.base.test.dummy.extractor import extractor as dummy_extractor
DATA = [0, 1, 2, 3, 4] DATA = [0, 1, 2, 3, 4]
PROCESSORS = [partial(np.power, 2), np.mean] PROCESSORS = [partial(np.power, 2), np.mean]
...@@ -43,3 +45,13 @@ def test_extractors(): ...@@ -43,3 +45,13 @@ def test_extractors():
proc = ParallelExtractor(processors) proc = ParallelExtractor(processors)
data = proc(DATA) data = proc(DATA)
assert all(np.allclose(x1, x2) for x1, x2 in zip(data, PAR_DATA)) assert all(np.allclose(x1, x2) for x1, x2 in zip(data, PAR_DATA))
def test_trainable_extractors():
    """Train, load and apply a SequentialExtractor ending in a trainable step.

    The chain consists of one CallableExtractor per entry of PROCESSORS,
    followed by ``dummy_extractor``. The chain is trained into a temporary
    ``.hdf5`` file, loaded back from that same file, applied to DATA, and
    the result is compared against SEQ_DATA.
    """
    processors = [CallableExtractor(p) for p in PROCESSORS] + [dummy_extractor]
    proc = SequentialExtractor(processors)
    # NOTE(review): the extent of the ``with`` body is reconstructed. The
    # extractor is applied while the temporary file still exists, which is
    # safe whether or not loading keeps a reference to the file.
    with tempfile.NamedTemporaryFile(suffix='.hdf5') as f:
        proc.train(DATA, f.name)
        proc.load(f.name)
        data = proc(DATA)
    assert np.allclose(data, SEQ_DATA)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment