Commit 103b120e authored by Amir MOHAMMADI
Browse files

improve multisample handling

parent dc19fef7
#!/usr/bin/env python
"""Returns predictions of networks trained with
tf.train.MonitoredTrainingSession
"""Saves predictions or embeddings of tf.estimators. This script works with
bob.bio.base databases and preprocessors. To use it see the configuration
details below.
Usage:
%(prog)s [-v...] [-k KEY]... [options] <config_files>...
......@@ -9,53 +10,69 @@ Usage:
%(prog)s --version
Arguments:
<config_files> The configuration files. The configuration
files are loaded in order and they need to
have several objects inside totally. See
below for explanation.
<config_files> The configuration files. The
configuration files are loaded in order
and, taken together, they must define
several objects. See below for
explanation.
Options:
-h --help Show this help message and exit
--version Show version and exit
-o PATH, --output-dir PATH Name of the output file.
-k KEY, --predict-keys KEY List of `str`, name of the keys to predict.
It is used if the
`EstimatorSpec.predictions` is a `dict`. If
`predict_keys` is used then rest of the
predictions will be filtered from the
dictionary. If `None`, returns all.
--checkpoint-path=<path> Path of a specific checkpoint to predict.
If `None`, the latest checkpoint in
`model_dir` is used.
--multiple-samples If provided, it assumes that the db
interface returns several samples from a
biofile. This option can be used when you
are working with sequences.
-h --help Show this help message and exit
--version Show version and exit
-o PATH, --output-dir PATH Path of the output directory.
-k KEY, --predict-keys KEY List of `str`, name of the keys to
predict. It is used if the
`EstimatorSpec.predictions` is a `dict`.
If `predict_keys` is used then rest of
the predictions will be filtered from
the dictionary. If `None`, returns all.
--checkpoint-path=<path> Path of a specific checkpoint to
predict. If `None`, the latest
checkpoint in `model_dir` is used.
--multiple-samples If provided, it assumes that the db
interface returns several samples from a
biofile. This option can be used when
you are working with sequences.
-p N, --number-of-parallel-jobs N The number of parallel jobs in which
this script is run on the SGE grid. You
should use this option with
``jman submit -t N``.
-f, --force If provided, it will overwrite the existing
predictions.
-v, --verbose Increases the output verbosity level
should use this option with ``jman
submit -t N``.
-f, --force If provided, it will overwrite the
existing predictions.
-v, --verbose Increases the output verbosity level
The -- options above can also be supplied through configuration files. You just
need to create a variable with a name that replaces ``-`` with ``_``. For
example, use ``multiple_samples`` instead of ``--multiple-samples``.
Taken together, the configuration files should define the following objects:
# Required objects:
estimator
database
preprocessor
groups
biofile_to_label
bio_predict_input_fn
estimator : :any:`tf.estimator.Estimator`
An estimator instance that represents the neural network.
database : :any:`bob.bio.base.database.BioDatabase`
A bio database. Its original_directory must point to the correct path.
preprocessor : :any:`bob.bio.base.preprocessor.Preprocessor`
A preprocessor which loads the data from the database and processes the
data.
groups : [str]
A list of groups to evaluate. Can be any combination of
``('world', 'dev', 'eval')``.
biofile_to_label : callable
A callable that takes a :any:`bob.bio.base.database.BioFile` and
returns its label as an integer ``label = biofile_to_label(biofile)``.
bio_predict_input_fn : callable
A callable with the signature of
``input_fn = bio_predict_input_fn(generator, output_types, output_shapes)``
The inputs are documented in :any:`tf.data.Dataset.from_generator` and
the output should be a function that takes no arguments; it is passed to
:any:`tf.estimator.Estimator.predict`.
# Optional objects:
hooks
For an example configuration, please see:
bob.learn.tensorflow/bob/learn/tensorflow/examples/mnist/mnist_config.py
hooks : [:any:`tf.train.SessionRunHook`]
Optional hooks that you may want to attach to the predictions.
"""
from __future__ import absolute_import
from __future__ import division
......@@ -71,7 +88,6 @@ import tensorflow as tf
from bob.io.base import create_directories_safe
from bob.bio.base.utils import read_config_file, save
from bob.bio.base.tools.grid import indices
from bob.learn.tensorflow.dataset import tf_repeat
from bob.learn.tensorflow.utils.commandline import \
get_from_config_or_commandline
from bob.core.log import setup, set_verbosity_level
......@@ -107,23 +123,22 @@ def bio_generator(database, preprocessor, groups, number_of_parallel_jobs,
continue
data = load_data(f, preprocessor, database)
if multiple_samples:
label = [label for _ in range(len(data))]
key = [key for _ in range(len(data))]
yield (data, label, key)
for d in data:
yield (d, label, key)
else:
yield (data, label, key)
# load one data to get its type and shape
data = load_data(biofiles[0], preprocessor, database)
if multiple_samples:
try:
data = data[0]
except TypeError:
# if the data is a generator
data = six.next(data)
data = tf.convert_to_tensor(data)
output_types = (data.dtype, tf.int64, tf.string)
data_shape = list(data.shape)
label_shape = tf.TensorShape([])
key_shape = tf.TensorShape([])
if multiple_samples:
data_shape[0] = None
label_shape = tf.TensorShape([None])
key_shape = tf.TensorShape([None])
output_shapes = (tf.TensorShape(data_shape),
label_shape, key_shape)
output_shapes = (data.shape, tf.TensorShape([]), tf.TensorShape([]))
return (generator, output_types, output_shapes)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment