Commit 59386bb9 authored by Amir MOHAMMADI

Add an example. Remove preprocessor

parent 348c36e5
Pipeline #13441 passed with stages in 16 minutes and 13 seconds
#!/usr/bin/env python
"""Saves predictions or embeddings of tf.estimators. This script works with
-bob.bio.base databases and preprocessors. To use it see the configuration
-details below.
+bob.bio.base databases. To use it see the configuration details below. This
+script works with tensorflow 1.4 and above.
Usage:
%(prog)s [-v...] [-k KEY]... [options] <config_files>...
@@ -53,26 +53,54 @@ The configuration files should have the following objects totally:
     An estimator instance that represents the neural network.
 database : :any:`bob.bio.base.database.BioDatabase`
     A bio database. Its original_directory must point to the correct path.
-preprocessor : :any:`bob.bio.base.preprocessor.Preprocessor`
-    A preprocessor which loads the data from the database and processes the
-    data.
 groups : [str]
     A list of groups to evaluate. Can be any permutation of
     ``('world', 'dev', 'eval')``.
-biofile_to_label : callable
-    A callable that takes a :any:`bob.bio.base.database.BioFile` and
-    returns its label as an integer ``label = biofile_to_label(biofile)``.
 bio_predict_input_fn : callable
     A callable with the signature of
-    ``input_fn = bio_predict_input_fn(generator,output_types, output_shapes)``
+    ``input_fn = bio_predict_input_fn(generator, output_types, output_shapes)``
     The inputs are documented in :any:`tf.data.Dataset.from_generator` and
     the output should be a function with no arguments and is passed to
     :any:`tf.estimator.Estimator.predict`.

 # Optional objects:

+read_original_data : callable
+    A callable with the signature of
+    ``data = read_original_data(biofile, directory, extension)``.
+    :any:`bob.bio.base.read_original_data` is used by default.
 hooks : [:any:`tf.train.SessionRunHook`]
     Optional hooks that you may want to attach to the predictions.
+An example configuration for a trained model and its evaluation could be::
+
+    import tensorflow as tf
+
+    # define the database:
+    from bob.bio.base.test.dummy.database import database
+
+    # load the estimator model
+    estimator = tf.estimator.Estimator(model_fn, model_dir)
+
+    groups = ['dev']
+
+    # the ``dataset = tf.data.Dataset.from_generator(generator, output_types,
+    # output_shapes)`` line is mandatory in the function below. You have to
+    # create it in your configuration file since you want it to be created in
+    # the same graph as your model.
+    def bio_predict_input_fn(generator, output_types, output_shapes):
+        def input_fn():
+            dataset = tf.data.Dataset.from_generator(generator, output_types,
+                                                     output_shapes)
+            # apply all kinds of transformations here, process the data even
+            # further if you want.
+            dataset = dataset.prefetch(1)
+            dataset = dataset.batch(10**3)
+            images, labels, keys = dataset.make_one_shot_iterator().get_next()
+            return {'data': images, 'keys': keys}, labels
+        return input_fn
"""
from __future__ import absolute_import
from __future__ import division
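
With the preprocessor gone, per-sample loading is customized through the optional ``read_original_data`` object documented above. A minimal sketch of such a callable, following the documented ``data = read_original_data(biofile, directory, extension)`` signature (the image assumption and the scaling step are illustrative, not part of this commit)::

    import numpy
    import bob.io.base

    def read_original_data(biofile, directory, extension):
        # make_path joins <directory>/<relative biofile path><extension>
        path = biofile.make_path(directory, extension)
        data = bob.io.base.load(path)
        # hypothetical normalization of uint8 images to [0, 1] floats;
        # the default bob.bio.base loader returns the data untouched
        return data.astype(numpy.float32) / 255.0
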
@@ -98,38 +126,38 @@ def make_output_path(output_dir, key):
     return os.path.join(output_dir, key + '.hdf5')


-def bio_generator(database, preprocessor, groups, number_of_parallel_jobs,
-                  biofile_to_label, output_dir, multiple_samples=False,
+def bio_generator(database, groups, number_of_parallel_jobs, output_dir,
+                  read_original_data=None, multiple_samples=False,
                   force=False):
+    if read_original_data is None:
+        from bob.bio.base import read_original_data
     biofiles = list(database.all_files(groups))
     if number_of_parallel_jobs > 1:
         start, end = indices(biofiles, number_of_parallel_jobs)
         biofiles = biofiles[start:end]
     keys = (str(f.make_path("", "")) for f in biofiles)
-    labels = (biofile_to_label(f) for f in biofiles)

-    def load_data(f, preprocessor, database):
-        data = preprocessor.read_original_data(
+    def load_data(f, read_original_data, database):
+        data = read_original_data(
             f,
             database.original_directory,
             database.original_extension)
-        data = preprocessor(data, database.annotations(f))
         return data

     def generator():
-        for f, label, key in six.moves.zip(biofiles, labels, keys):
+        for f, key in six.moves.zip(biofiles, keys):
             outpath = make_output_path(output_dir, key)
             if not force and os.path.isfile(outpath):
                 continue
-            data = load_data(f, preprocessor, database)
+            data = load_data(f, read_original_data, database)
             if multiple_samples:
                 for d in data:
-                    yield (d, label, key)
+                    yield (d, -1, key)
             else:
-                yield (data, label, key)
+                yield (data, -1, key)

     # load one data to get its type and shape
-    data = load_data(biofiles[0], preprocessor, database)
+    data = load_data(biofiles[0], read_original_data, database)
     if multiple_samples:
         try:
             data = data[0]
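
Note that the generator now yields a constant ``-1`` placeholder instead of a real label, since ``biofile_to_label`` was removed and prediction does not need labels. To see how these ``(data, -1, key)`` tuples line up with the ``output_types`` and ``output_shapes`` passed to ``tf.data.Dataset.from_generator``, here is a toy stand-in (the generator name, keys, and the 4x4 shape are invented for illustration)::

    import numpy as np
    import tensorflow as tf

    def toy_generator():
        # same tuple layout as bio_generator: (data, placeholder label, key)
        for key in (b'subject1/sample1', b'subject1/sample2'):
            yield (np.zeros((4, 4), np.float32), -1, key)

    dataset = tf.data.Dataset.from_generator(
        toy_generator,
        output_types=(tf.float32, tf.int64, tf.string),
        output_shapes=(tf.TensorShape([4, 4]),
                       tf.TensorShape([]),
                       tf.TensorShape([])))
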
@@ -173,6 +201,7 @@ def main(argv=None):
     force = get_from_config_or_commandline(
         config, 'force', args, defaults)
     hooks = getattr(config, 'hooks', None)
+    read_original_data = getattr(config, 'read_original_data', None)

     # Sets-up logging
     set_verbosity_level(logger, verbosity)
@@ -180,16 +209,14 @@
     # required arguments
     estimator = config.estimator
     database = config.database
-    preprocessor = config.preprocessor
     groups = config.groups
-    biofile_to_label = config.biofile_to_label
     bio_predict_input_fn = config.bio_predict_input_fn
     output_dir = get_from_config_or_commandline(
         config, 'output_dir', args, defaults, False)

     generator, output_types, output_shapes = bio_generator(
-        database, preprocessor, groups, number_of_parallel_jobs,
-        biofile_to_label, output_dir, multiple_samples, force)
+        database, groups, number_of_parallel_jobs, output_dir,
+        read_original_data, multiple_samples, force)

     predict_input_fn = bio_predict_input_fn(generator,
                                             output_types, output_shapes)
......
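
After the script runs, each sample's output lands at the path built by ``make_output_path``, i.e. ``<output_dir>/<key>.hdf5``. A short sketch of reading one prediction back (the directory, the key, and the use of ``bob.io.base.load`` here are assumptions for illustration)::

    import os
    import bob.io.base

    output_dir = '/path/to/predictions'  # hypothetical output directory
    key = 'subject1/sample1'             # str(biofile.make_path("", ""))
    prediction = bob.io.base.load(os.path.join(output_dir, key + '.hdf5'))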