Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
bob
bob.learn.tensorflow
Commits
59386bb9
Commit
59386bb9
authored
Oct 24, 2017
by
Amir MOHAMMADI
Browse files
Add an example. Remove preprocessor
parent
348c36e5
Pipeline
#13441
passed with stages
in 16 minutes and 13 seconds
Changes
1
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
bob/learn/tensorflow/script/predict_bio.py
View file @
59386bb9
#!/usr/bin/env python
"""Saves predictions or embeddings of tf.estimators. This script works with
bob.bio.base databases
and preprocessors
. To use it see the configuration
details below
.
bob.bio.base databases. To use it see the configuration
details below. This
script works with tensorflow 1.4 and above
.
Usage:
%(prog)s [-v...] [-k KEY]... [options] <config_files>...
...
...
@@ -53,26 +53,54 @@ The configuration files should have the following objects totally:
An estimator instance that represents the neural network.
database : :any:`bob.bio.base.database.BioDatabase`
A bio database. Its original_directory must point to the correct path.
preprocessor : :any:`bob.bio.base.preprocessor.Preprocessor`
A preprocessor which loads the data from the database and processes the
data.
groups : [str]
A list of groups to evaluate. Can be any permutation of
``('world', 'dev', 'eval')``.
biofile_to_label : callable
A callable that takes a :any:`bob.bio.base.database.BioFile` and
returns its label as an integer ``label = biofile_to_label(biofile)``.
bio_predict_input_fn : callable
A callable with the signature of
``input_fn = bio_predict_input_fn(generator,output_types, output_shapes)``
``input_fn = bio_predict_input_fn(generator,
output_types, output_shapes)``
The inputs are documented in :any:`tf.data.Dataset.from_generator` and
the output should be a function with no arguments and is passed to
:any:`tf.estimator.Estimator.predict`.
# Optional objects:
read_original_data : callable
A callable with the signature of
``data = read_original_data(biofile, directory, extension)``.
:any:`bob.bio.base.read_original_data` is used by default.
hooks : [:any:`tf.train.SessionRunHook`]
Optional hooks that you may want to attach to the predictions.
An example configuration for a trained model and its evaluation could be::
import tensorflow as tf
# define the database:
from bob.bio.base.test.dummy.database import database
# load the estimator model
estimator = tf.estimator.Estimator(model_fn, model_dir)
groups = ['dev']
# the ``dataset = tf.data.Dataset.from_generator(generator, output_types,
# output_shapes)`` line is mandatory in the function below. You have to
# create it in your configuration file since you want it to be created in
# the same graph as your model.
def bio_predict_input_fn(generator,output_types, output_shapes):
def input_fn():
dataset = tf.data.Dataset.from_generator(generator, output_types,
output_shapes)
# apply all kinds of transformations here, process the data even
# further if you want.
dataset = dataset.prefetch(1)
dataset = dataset.batch(10**3)
images, labels, keys = dataset.make_one_shot_iterator().get_next()
return {'data': images, 'keys': keys}, labels
return input_fn
"""
from
__future__
import
absolute_import
from
__future__
import
division
...
...
@@ -98,38 +126,38 @@ def make_output_path(output_dir, key):
return
os
.
path
.
join
(
output_dir
,
key
+
'.hdf5'
)
def
bio_generator
(
database
,
preprocessor
,
groups
,
number_of_parallel_jobs
,
biofile_to_label
,
output_dir
,
multiple_samples
=
False
,
def
bio_generator
(
database
,
groups
,
number_of_parallel_jobs
,
output_dir
,
read_original_data
=
None
,
multiple_samples
=
False
,
force
=
False
):
if
read_original_data
is
None
:
from
bob.bio.base
import
read_original_data
biofiles
=
list
(
database
.
all_files
(
groups
))
if
number_of_parallel_jobs
>
1
:
start
,
end
=
indices
(
biofiles
,
number_of_parallel_jobs
)
biofiles
=
biofiles
[
start
:
end
]
keys
=
(
str
(
f
.
make_path
(
""
,
""
))
for
f
in
biofiles
)
labels
=
(
biofile_to_label
(
f
)
for
f
in
biofiles
)
def
load_data
(
f
,
p
re
processor
,
database
):
data
=
preprocessor
.
read_original_data
(
def
load_data
(
f
,
re
ad_original_data
,
database
):
data
=
read_original_data
(
f
,
database
.
original_directory
,
database
.
original_extension
)
data
=
preprocessor
(
data
,
database
.
annotations
(
f
))
return
data
def
generator
():
for
f
,
label
,
key
in
six
.
moves
.
zip
(
biofiles
,
labels
,
keys
):
for
f
,
key
in
six
.
moves
.
zip
(
biofiles
,
keys
):
outpath
=
make_output_path
(
output_dir
,
key
)
if
not
force
and
os
.
path
.
isfile
(
outpath
):
continue
data
=
load_data
(
f
,
p
re
processor
,
database
)
data
=
load_data
(
f
,
re
ad_original_data
,
database
)
if
multiple_samples
:
for
d
in
data
:
yield
(
d
,
label
,
key
)
yield
(
d
,
-
1
,
key
)
else
:
yield
(
data
,
label
,
key
)
yield
(
data
,
-
1
,
key
)
# load one data to get its type and shape
data
=
load_data
(
biofiles
[
0
],
p
re
processor
,
database
)
data
=
load_data
(
biofiles
[
0
],
re
ad_original_data
,
database
)
if
multiple_samples
:
try
:
data
=
data
[
0
]
...
...
@@ -173,6 +201,7 @@ def main(argv=None):
force
=
get_from_config_or_commandline
(
config
,
'force'
,
args
,
defaults
)
hooks
=
getattr
(
config
,
'hooks'
,
None
)
read_original_data
=
getattr
(
config
,
'read_original_data'
,
None
)
# Sets-up logging
set_verbosity_level
(
logger
,
verbosity
)
...
...
@@ -180,16 +209,14 @@ def main(argv=None):
# required arguments
estimator
=
config
.
estimator
database
=
config
.
database
preprocessor
=
config
.
preprocessor
groups
=
config
.
groups
biofile_to_label
=
config
.
biofile_to_label
bio_predict_input_fn
=
config
.
bio_predict_input_fn
output_dir
=
get_from_config_or_commandline
(
config
,
'output_dir'
,
args
,
defaults
,
False
)
generator
,
output_types
,
output_shapes
=
bio_generator
(
database
,
preprocessor
,
groups
,
number_of_parallel_jobs
,
biofile_to_label
,
output_dir
,
multiple_samples
,
force
)
database
,
groups
,
number_of_parallel_jobs
,
output_dir
,
read_original_data
,
multiple_samples
,
force
)
predict_input_fn
=
bio_predict_input_fn
(
generator
,
output_types
,
output_shapes
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment