Skip to content
Snippets Groups Projects
Commit 3d502f25 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Merge branch 'train_and_evaluate' into 'master'

Add a script to call tf.estimator.train_and_evaluate

See merge request !37
parents 20d5c20d fa3a14bf
Branches
Tags
1 merge request!37Add a script to call tf.estimator.train_and_evaluate
Pipeline #
#!/usr/bin/env python
"""Trains and evaluates a network using Tensorflow estimators.
This script calls the estimator.train_and_evaluate function. Please see:
https://www.tensorflow.org/api_docs/python/tf/estimator/train_and_evaluate
https://www.tensorflow.org/api_docs/python/tf/estimator/TrainSpec
https://www.tensorflow.org/api_docs/python/tf/estimator/EvalSpec
for more details.
Usage:
%(prog)s [-v...] [options] <config_files>...
%(prog)s --help
%(prog)s --version
Arguments:
<config_files> The configuration files. The
configuration files are loaded in order
and they need to have several objects
inside totally. See below for
explanation.
Options:
-h --help Show this help message and exit
--version Show version and exit
-v, --verbose Increases the output verbosity level
The configuration files should have the following objects totally:
## Required objects:
estimator
train_spec
eval_spec
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# import pkg_resources so that bob imports work properly:
import pkg_resources
import tensorflow as tf
from bob.extension.config import load as read_config_file
from bob.learn.tensorflow.utils.commandline import \
get_from_config_or_commandline
from bob.core.log import setup, set_verbosity_level
logger = setup(__name__)
def main(argv=None):
from docopt import docopt
import os
import sys
docs = __doc__ % {'prog': os.path.basename(sys.argv[0])}
version = pkg_resources.require('bob.learn.tensorflow')[0].version
defaults = docopt(docs, argv=[""])
args = docopt(docs, argv=argv, version=version)
config_files = args['<config_files>']
config = read_config_file(config_files)
# optional arguments
verbosity = get_from_config_or_commandline(
config, 'verbose', args, defaults)
# Sets-up logging
set_verbosity_level(logger, verbosity)
# required arguments
estimator = config.estimator
train_spec = config.train_spec
eval_spec = config.eval_spec
# Train and evaluate
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
if __name__ == '__main__':
main()
......@@ -7,6 +7,7 @@ from bob.io.base.test_utils import datafile
from bob.learn.tensorflow.script.db_to_tfrecords import main as tfrecords
from bob.learn.tensorflow.script.train_generic import main as train_generic
from bob.learn.tensorflow.script.eval_generic import main as eval_generic
from bob.learn.tensorflow.script.train_and_evaluate import main as train_and_evaluate
dummy_tfrecord_config = datafile('dummy_verify_config.py', __name__)
CONFIG = '''
......@@ -32,6 +33,10 @@ def eval_input_fn():
return batch_data_and_labels(tfrecord_filenames, data_shape, data_type,
batch_size, epochs=1)
# config for train_and_evaluate
train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=200)
eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn)
def architecture(images):
images = tf.cast(images, tf.float32)
logits = tf.reshape(images, [-1, 92 * 112])
......@@ -115,6 +120,15 @@ def _eval(tmpdir, model_dir, dummy_tfrecord):
eval_generic([config_path])
def _train_and_evaluate(tmpdir, model_dir, dummy_tfrecord):
config = CONFIG % {'model_dir': model_dir,
'tfrecord_filenames': dummy_tfrecord}
config_path = os.path.join(tmpdir, 'train_config.py')
with open(config_path, 'w') as f:
f.write(config)
train_and_evaluate([config_path])
def test_eval():
tmpdir = mkdtemp(prefix='bob_')
try:
......@@ -137,6 +151,10 @@ def test_eval():
assert '1' in doc, doc
assert '200' in doc, doc
print('Train and evaluate a dummy network')
_train_and_evaluate(tmpdir, model_dir, dummy_tfrecord)
finally:
try:
shutil.rmtree(tmpdir)
......
......@@ -52,6 +52,7 @@ setup(
'bob_tf_load_and_debug.py = bob.learn.tensorflow.script.load_and_debug:main',
'bob_tf_train_generic = bob.learn.tensorflow.script.train_generic:main',
'bob_tf_eval_generic = bob.learn.tensorflow.script.eval_generic:main',
'bob_tf_train_and_evaluate = bob.learn.tensorflow.script.train_and_evaluate:main',
'bob_tf_predict_generic = bob.learn.tensorflow.script.predict_generic:main',
'bob_tf_predict_bio = bob.learn.tensorflow.script.predict_bio:main',
],
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment