Commit 39a5713b authored by Tiago de Freitas Pereira

Merge branch 'master' into 'bob-extension-config-file'

# Conflicts:
#   bob/bio/base/tools/command_line.py
parents 49a62900 75ebdf5b
1 merge request: !119 Integrate the new bob.extension loading config mechanism
Showing with 139 additions and 57 deletions
......@@ -9,7 +9,7 @@ from numpy.testing.decorators import setastest
import bob.db.base
class BioDatabase(six.with_metaclass(abc.ABCMeta, bob.db.base.Database)):
class BioDatabase(six.with_metaclass(abc.ABCMeta, bob.db.base.FileDatabase)):
"""This class represents the basic API for database access.
Please use this class as a base class for your database access classes.
Do not forget to call the constructor of this base class in your derived class.
......@@ -90,7 +90,8 @@ class BioDatabase(six.with_metaclass(abc.ABCMeta, bob.db.base.Database)):
super(BioDatabase, self).__init__(
original_directory=original_directory,
original_extension=original_extension)
original_extension=original_extension,
**kwargs)
self.name = name
......@@ -100,7 +101,7 @@ class BioDatabase(six.with_metaclass(abc.ABCMeta, bob.db.base.Database)):
self.enroller_training_options = enroller_training_options
self.check_existence = check_original_files_for_existence
self._kwargs = kwargs
self._kwargs = {}
self.annotation_directory = annotation_directory
self.annotation_extension = annotation_extension
......@@ -196,18 +197,9 @@ class BioDatabase(six.with_metaclass(abc.ABCMeta, bob.db.base.Database)):
if self.original_directory in replacements:
self.original_directory = replacements[self.original_directory]
try:
self._db.original_directory = self.original_directory
except AttributeError:
pass
try:
if self.annotation_directory in replacements:
self.annotation_directory = replacements[self.annotation_directory]
try:
self._db.annotation_directory = self.annotation_directory
except AttributeError:
pass
except AttributeError:
pass
......
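As an aside (not part of the diff), a minimal sketch of a derived database class, with hypothetical names, that forwards extra keyword arguments to the updated BioDatabase constructor could look like this:

# Sketch only -- the class and argument names below are illustrative, not from the diff.
from bob.bio.base.database import BioDatabase

class MyBioDatabase(BioDatabase):
    """Hypothetical database that passes extra options through to the base class."""

    def __init__(self, original_directory=None, original_extension='.png', **kwargs):
        # The base constructor now accepts **kwargs, so options such as
        # annotation_directory can simply be forwarded.
        super(MyBioDatabase, self).__init__(
            name='my-database',
            original_directory=original_directory,
            original_extension=original_extension,
            **kwargs)

    def model_ids_with_protocol(self, groups=None, protocol=None, **kwargs):
        raise NotImplementedError("sketch only")

    def objects(self, groups=None, protocol=None, purposes=None,
                model_ids=None, **kwargs):
        raise NotImplementedError("sketch only")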
......@@ -87,6 +87,8 @@ class ListReader(object):
raise RuntimeError('File %s does not exist.' % (list_file,))
try:
for line in fileinput.input(list_file):
if line.strip().startswith('#'):
continue
parsed_line = re.findall('[\w/(-.)]+', line)
if len(parsed_line):
# perform some sanity checks
......
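A stand-alone sketch of the comment handling introduced above (illustrative, not part of the diff): lines whose first non-whitespace character is '#' are skipped before the usual tokenization.

import re

def parse_list_lines(lines):
    """Yield the tokens of every non-comment line of a file list."""
    for line in lines:
        if line.strip().startswith('#'):
            continue  # comment line, ignore it as the patched ListReader does
        parsed_line = re.findall(r'[\w/(-.)]+', line)
        if parsed_line:
            yield parsed_line

# The commented entries are skipped, the real entry is tokenized.
example = ["# Model samples\n",
           "#data/model3_session1_sample1 3 3\n",
           "data/model3_session1_sample2 3 3\n"]
assert list(parse_list_lines(example)) == [
    ['data/model3_session1_sample2', '3', '3']]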
......@@ -10,6 +10,9 @@ from .. import BioFile
from .models import ListReader
import logging
logger = logging.getLogger('bob.bio.base')
class FileListBioDatabase(ZTBioDatabase):
"""This class provides a user-friendly interface to databases that are given as file lists.
......@@ -127,7 +130,6 @@ class FileListBioDatabase(ZTBioDatabase):
and the given sub-directories and file names (which default to useful values if not given)."""
super(FileListBioDatabase, self).__init__(
filelists_directory=filelists_directory,
name=name,
protocol=protocol,
original_directory=original_directory,
......@@ -135,7 +137,10 @@ class FileListBioDatabase(ZTBioDatabase):
annotation_directory=annotation_directory,
annotation_extension=annotation_extension,
annotation_type=annotation_type,
# extra args for pretty printing
**kwargs)
# extra args for pretty printing
self._kwargs.update(dict(
filelists_directory=filelists_directory,
dev_sub_directory=dev_sub_directory,
eval_sub_directory=eval_sub_directory,
world_filename=world_filename,
......@@ -147,9 +152,10 @@ class FileListBioDatabase(ZTBioDatabase):
tnorm_filename=tnorm_filename,
znorm_filename=znorm_filename,
use_dense_probe_file_list=use_dense_probe_file_list,
# if both probe_filename and scores_filename are given, what kind of list should be used?
# if both probe_filename and scores_filename are given, what kind
# of list should be used?
keep_read_lists_in_memory=keep_read_lists_in_memory,
**kwargs)
))
# self.original_directory = original_directory
# self.original_extension = original_extension
self.bio_file_class = bio_file_class
......@@ -226,10 +232,12 @@ class FileListBioDatabase(ZTBioDatabase):
if group == 'world':
continue
if add_zt_files:
if not self.implements_zt(self.protocol, group):
raise ValueError("ZT score files are requested, but no such files are defined in group %s for protocol %s", group, self.protocol)
files += self.tobjects(group, self.protocol)
files += self.zobjects(group, self.protocol, **self.z_probe_options)
if self.implements_zt(self.protocol, group):
files += self.tobjects(group, self.protocol)
files += self.zobjects(group, self.protocol, **self.z_probe_options)
else:
logger.warn("ZT score files are requested, but no such files are defined in group %s for protocol %s", group, self.protocol)
return self.sort(self._make_bio(files))
......
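For orientation (illustrative only, paths and names are placeholders, and the exact all_files signature is assumed from the hunk above), a FileListBioDatabase is constructed from a directory of list files plus a database name, much like the tests further down do:

from bob.bio.base.database import FileListBioDatabase

# The directory layout is the one described in the bob.bio.base documentation;
# the concrete paths here are placeholders.
db = FileListBioDatabase('/path/to/my_filelists', 'my-filelist-db',
                         original_directory='/path/to/raw/data',
                         original_extension='.hdf5')

# With this change, requesting ZT-norm files on a protocol that does not
# define them only logs a warning instead of raising a ValueError.
files = db.all_files(add_zt_files=True)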
......@@ -7,14 +7,15 @@ class MultipleExtractor(Extractor):
"""Base class for SequentialExtractor and ParallelExtractor. This class is
not meant to be used directly."""
def get_attributes(self, processors):
@staticmethod
def get_attributes(processors):
requires_training = any(p.requires_training for p in processors)
split_training_data_by_client = any(p.split_training_data_by_client for
p in processors)
min_extractor_file_size = min(p.min_extractor_file_size for p in
processors)
min_feature_file_size = min(
p.min_feature_file_size for p in processors)
min_feature_file_size = min(p.min_feature_file_size for p in
processors)
return (requires_training, split_training_data_by_client,
min_extractor_file_size, min_feature_file_size)
......@@ -23,35 +24,59 @@ class MultipleExtractor(Extractor):
return groups
def train_one(self, e, training_data, extractor_file, apply=False):
"""Trains one extractor and optionally applies the extractor on the
training data after training.
Parameters
----------
e : :any:`Extractor`
The extractor to train. The extractor should be able to save itself
in an opened hdf5 file.
training_data : [object] or [[object]]
The data to be used for training.
extractor_file : :any:`bob.io.base.HDF5File`
The opened hdf5 file to save the trained extractor inside.
apply : :obj:`bool`, optional
If ``True``, the extractor is applied to the training data after it
is trained and the data is returned.
Returns
-------
None or [object] or [[object]]
Returns ``None`` if ``apply`` is ``False``. Otherwise, returns the
transformed ``training_data``.
"""
if not e.requires_training:
return
# do nothing since e does not require training!
pass
# if any of the extractors require splitting the data, the
# split_training_data_by_client is True.
if e.split_training_data_by_client:
elif e.split_training_data_by_client:
e.train(training_data, extractor_file)
if not apply:
return
training_data = [[e(d) for d in datalist]
for datalist in training_data]
# when no extractor needs splitting
elif not self.split_training_data_by_client:
e.train(training_data, extractor_file)
if not apply:
return
training_data = [e(d) for d in training_data]
# when e here wants it flat but the data is split
else:
# make training_data flat
aligned_training_data = [d for datalist in training_data for d in
datalist]
e.train(aligned_training_data, extractor_file)
if not apply:
return
flat_training_data = [d for datalist in training_data for d in
datalist]
e.train(flat_training_data, extractor_file)
if not apply:
return
# prepare the training data for the next extractor
if self.split_training_data_by_client:
training_data = [[e(d) for d in datalist]
for datalist in training_data]
else:
training_data = [e(d) for d in training_data]
return training_data
def load(self, extractor_file):
if not self.requires_training:
return
with HDF5File(extractor_file) as f:
groups = self.get_extractor_groups()
for e, group in zip(self.processors, groups):
......@@ -88,7 +113,7 @@ class SequentialExtractor(SequentialProcessor, MultipleExtractor):
True
"""
def __init__(self, processors):
def __init__(self, processors, **kwargs):
(requires_training, split_training_data_by_client,
min_extractor_file_size, min_feature_file_size) = \
......@@ -99,15 +124,18 @@ class SequentialExtractor(SequentialProcessor, MultipleExtractor):
requires_training=requires_training,
split_training_data_by_client=split_training_data_by_client,
min_extractor_file_size=min_extractor_file_size,
min_feature_file_size=min_feature_file_size)
min_feature_file_size=min_feature_file_size,
**kwargs)
def train(self, training_data, extractor_file):
with HDF5File(extractor_file, 'w') as f:
groups = self.get_extractor_groups()
for e, group in zip(self.processors, groups):
for i, (e, group) in enumerate(zip(self.processors, groups)):
apply = i != len(self.processors) - 1
f.create_group(group)
f.cd(group)
training_data = self.train_one(e, training_data, f, apply=True)
training_data = self.train_one(e, training_data, f,
apply=apply)
f.cd('..')
def read_feature(self, feature_file):
......@@ -154,7 +182,7 @@ class ParallelExtractor(ParallelProcessor, MultipleExtractor):
[ 1. , 2. , 3. , 0.5, 1. , 1.5]])
"""
def __init__(self, processors):
def __init__(self, processors, **kwargs):
(requires_training, split_training_data_by_client,
min_extractor_file_size, min_feature_file_size) = self.get_attributes(
......@@ -165,7 +193,8 @@ class ParallelExtractor(ParallelProcessor, MultipleExtractor):
requires_training=requires_training,
split_training_data_by_client=split_training_data_by_client,
min_extractor_file_size=min_extractor_file_size,
min_feature_file_size=min_feature_file_size)
min_feature_file_size=min_feature_file_size,
**kwargs)
def train(self, training_data, extractor_file):
with HDF5File(extractor_file, 'w') as f:
......
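To make the new apply handling easier to follow, here is a stripped-down sketch (not the library code; it omits the HDF5 groups, the requires_training checks and the split-by-client branches) of how sequential training now chains the processors:

def train_chain(stages, training_data):
    """Train each stage on the output of the previous stages."""
    for i, stage in enumerate(stages):
        is_last = i == len(stages) - 1
        stage.train(training_data)
        if not is_last:
            # Only transform the data when a later stage still needs it;
            # this is what apply = i != len(self.processors) - 1 achieves above.
            training_data = [stage(d) for d in training_data]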
# Model samples
#Modelsamples
#data/model3_session1_sample1 3 3
data/model3_session1_sample1 3 3
data/model3_session1_sample2 3 3
data/model3_session1_sample3 3 3
......
import numpy
import bob.io.base
import bob.bio.base
from bob.bio.base.extractor import Extractor
......@@ -12,10 +12,10 @@ class DummyExtractor (Extractor):
def train(self, train_data, extractor_file):
assert isinstance(train_data, list)
bob.io.base.save(_data, extractor_file)
bob.bio.base.save(_data, extractor_file)
def load(self, extractor_file):
data = bob.io.base.load(extractor_file)
data = bob.bio.base.load(extractor_file)
assert (_data == data).all()
self.model = True
......
......@@ -136,6 +136,12 @@ def test_query_protocol():
assert len(db.objects(protocol=prot, groups='dev', purposes='probe')) == 9
def test_noztnorm():
db = FileListBioDatabase(os.path.join(os.path.dirname(example_dir),
'example_filelist2'), 'test')
assert len(db.all_files())
def test_query_dense():
db = FileListBioDatabase(example_dir, 'test', use_dense_probe_file_list=True)
......
from functools import partial
import numpy as np
import tempfile
from bob.bio.base.utils.processors import (
SequentialProcessor, ParallelProcessor)
from bob.bio.base.preprocessor import (
SequentialPreprocessor, ParallelPreprocessor, CallablePreprocessor)
from bob.bio.base.extractor import (
SequentialExtractor, ParallelExtractor, CallableExtractor)
from bob.bio.base.test.dummy.extractor import extractor as dummy_extractor
DATA = [0, 1, 2, 3, 4]
PROCESSORS = [partial(np.power, 2), np.mean]
......@@ -37,9 +39,31 @@ def test_preprocessors():
def test_extractors():
processors = [CallableExtractor(p) for p in PROCESSORS]
proc = SequentialExtractor(processors)
proc.load(None)
data = proc(DATA)
assert np.allclose(data, SEQ_DATA)
proc = ParallelExtractor(processors)
proc.load(None)
data = proc(DATA)
assert all(np.allclose(x1, x2) for x1, x2 in zip(data, PAR_DATA))
def test_sequential_trainable_extractors():
processors = [CallableExtractor(p) for p in PROCESSORS] + [dummy_extractor]
proc = SequentialExtractor(processors)
with tempfile.NamedTemporaryFile(suffix='.hdf5') as f:
proc.train(DATA, f.name)
proc.load(f.name)
data = proc(DATA)
assert np.allclose(data, SEQ_DATA)
def test_parallel_trainable_extractors():
processors = [CallableExtractor(p) for p in PROCESSORS] + [dummy_extractor]
proc = ParallelExtractor(processors)
with tempfile.NamedTemporaryFile(suffix='.hdf5') as f:
proc.train(DATA, f.name)
proc.load(f.name)
data = proc(np.array(DATA))
assert all(np.allclose(x1, x2) for x1, x2 in zip(data, PAR_DATA))
......@@ -123,7 +123,8 @@ def project(algorithm, extractor, groups = None, indices = None, allow_missing_f
if not utils.check_file(projected_file, force,
algorithm.min_projected_file_size):
logger.debug("... Projecting features for file '%s'", feature_file)
logger.debug("... Projecting features for file '%s' (%d/%d)",
feature_file, index_range.index(i)+1, len(index_range))
# create output directory before reading the data file (is sometimes required, when relative directories are specified, especially, including a .. somewhere)
bob.io.base.create_directories_safe(os.path.dirname(projected_file))
# load feature
......@@ -256,7 +257,7 @@ def enroll(algorithm, extractor, compute_zt_norm, indices = None, groups = ['dev
logger.info("- Enrollment: splitting of index range %s", str(indices))
logger.info("- Enrollment: enrolling models of group '%s'", group)
for model_id in model_ids:
for pos, model_id in enumerate(model_ids):
# Path to the model
model_file = fs.model_file(model_id, group)
......@@ -271,7 +272,9 @@ def enroll(algorithm, extractor, compute_zt_norm, indices = None, groups = ['dev
logger.debug("... Skipping model file %s since no feature file could be found", model_file)
continue
logger.debug("... Enrolling model from %d features to file '%s'", len(enroll_files), model_file)
logger.debug("... Enrolling model '%d' from %d feature(s) to "
"file '%s' (%d/%d)", model_id, len(enroll_files), model_file,
pos+1, len(model_ids))
bob.io.base.create_directories_safe(os.path.dirname(model_file))
# load all files into memory
......
......@@ -295,6 +295,9 @@ def parse_config_file(parsers, args, args_dictionary, keywords, skips):
take_from_config_or_command_line(args, config, "sub_directory",
parser.get_default("sub_directory"), is_resource=False)
take_from_config_or_command_line(args, config, "env",
parser.get_default("env"), is_resource=False)
skip_keywords = tuple(['skip_' + k.replace('-', '_') for k in skips])
for keyword in keywords + skip_keywords + ('execute_only',):
......
......@@ -112,7 +112,8 @@ def extract(extractor, preprocessor, groups=None, indices = None, allow_missing_
if not utils.check_file(feature_file, force,
extractor.min_feature_file_size):
logger.debug("... Extracting features for data file '%s'", data_file)
logger.debug("... Extracting features for data file '%s' (%d/%d)",
data_file, index_range.index(i)+1, len(index_range))
# create output directory before reading the data file (is sometimes required, when relative directories are specified, especially, including a .. somewhere)
bob.io.base.create_directories_safe(os.path.dirname(feature_file))
# load data
......
......@@ -67,7 +67,8 @@ def preprocess(preprocessor, groups = None, indices = None, allow_missing_files
# check for existence
if not utils.check_file(preprocessed_data_file, force,
preprocessor.min_preprocessed_file_size):
logger.debug("... Processing original data file '%s'", file_name)
logger.debug("... Processing original data file '%s' (%d/%d)", file_name,
index_range.index(i)+1, len(index_range))
data = preprocessor.read_original_data(file_object, original_directory, original_extension)
# create output directory before reading the data file (is sometimes required, when relative directories are specified, especially, including a .. somewhere)
......
......@@ -131,9 +131,11 @@ def _scores_a(algorithm, reader, model_ids, group, compute_zt_norm, force, write
logger.info("- Scoring: computing scores for group '%s'", group)
# Computes the raw scores for each model
for model_id in model_ids:
for pos, model_id in enumerate(model_ids):
# test if the file is already there
score_file = fs.a_file(model_id, group) if compute_zt_norm else fs.no_norm_file(model_id, group)
logger.debug("... Scoring model '%s' at '%s' (%d/%d)", model_id, score_file,
pos+1, len(model_ids))
if utils.check_file(score_file, force):
logger.warn("Score file '%s' already exists.", score_file)
else:
......@@ -166,9 +168,11 @@ def _scores_b(algorithm, reader, model_ids, group, force, allow_missing_files):
logger.info("- Scoring: computing score matrix B for group '%s'", group)
# Loads the models
for model_id in model_ids:
for pos, model_id in enumerate(model_ids):
# test if the file is already there
score_file = fs.b_file(model_id, group)
logger.debug("... Scoring model '%s' at '%s' (%d/%d)", model_id,
score_file, pos+1, len(model_ids))
if utils.check_file(score_file, force):
logger.warn("Score file '%s' already exists.", score_file)
else:
......@@ -191,9 +195,11 @@ def _scores_c(algorithm, reader, t_model_ids, group, force, allow_missing_files)
logger.info("- Scoring: computing score matrix C for group '%s'", group)
# Computes the raw scores for the T-Norm model
for t_model_id in t_model_ids:
for pos, t_model_id in enumerate(t_model_ids):
# test if the file is already there
score_file = fs.c_file(t_model_id, group)
logger.debug("... Scoring model '%s' at '%s' (%d/%d)", t_model_id,
score_file, pos+1, len(t_model_ids))
if utils.check_file(score_file, force):
logger.warn("Score file '%s' already exists.", score_file)
else:
......@@ -219,9 +225,11 @@ def _scores_d(algorithm, reader, t_model_ids, group, force, allow_missing_files)
z_probe_ids = [z_probe_object.client_id for z_probe_object in z_probe_objects]
# Loads the T-Norm models
for t_model_id in t_model_ids:
for pos, t_model_id in enumerate(t_model_ids):
# test if the file is already there
score_file = fs.d_file(t_model_id, group)
logger.debug("... Scoring model '%s' at '%s' (%d/%d)", t_model_id,
score_file, pos+1, len(t_model_ids))
same_score_file = fs.d_same_value_file(t_model_id, group)
if utils.check_file(score_file, force) and utils.check_file(same_score_file, force):
logger.warn("score files '%s' and '%s' already exist.", score_file, same_score_file)
......
......@@ -26,7 +26,7 @@ class SequentialProcessor(object):
"""
def __init__(self, processors, **kwargs):
super(SequentialProcessor, self).__init__()
super(SequentialProcessor, self).__init__(**kwargs)
self.processors = processors
def __call__(self, data, **kwargs):
......@@ -86,7 +86,7 @@ class ParallelProcessor(object):
"""
def __init__(self, processors, **kwargs):
super(ParallelProcessor, self).__init__()
super(ParallelProcessor, self).__init__(**kwargs)
self.processors = processors
def __call__(self, data, **kwargs):
......
......@@ -124,6 +124,8 @@ The following list files need to be created:
filename client_id
Please note that in all files, lines starting with any amount of whitespace followed by ``#`` will be ignored.
Protocols and File Lists
......
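For illustration (file name and entries are placeholders), a list file in the ``filename client_id`` format described in the documentation hunk above could look like this, with the comment line being ignored:

# this is a comment and will be skipped
data/client1_sample1  client1
data/client1_sample2  client1
data/client2_sample1  client2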