Commit 93086e0e authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

update train generic script

parent 5c4ecf64
...@@ -3,18 +3,25 @@ ...@@ -3,18 +3,25 @@
"""Trains networks using Tensorflow estimators. """Trains networks using Tensorflow estimators.
Usage: Usage:
%(prog)s [options] <config_files>... %(prog)s [-v...] [options] <config_files>...
%(prog)s --help %(prog)s --help
%(prog)s --version %(prog)s --version
Arguments: Arguments:
<config_files> The configuration files. The configuration files are loaded <config_files> The configuration files. The
in order and they need to have several objects inside configuration files are loaded in order
totally. See below for explanation. and they need to have several objects
inside totally. See below for
explanation.
Options: Options:
-h --help show this help message and exit -h --help Show this help message and exit
--version show version and exit --version Show version and exit
-v, --verbose Increases the output verbosity level
-s N, --steps N The number of steps to train.
-m N, --max-steps N The maximum number of steps to train.
This is a limit for global step which
continues in separate runs.
The configuration files should have the following objects totally: The configuration files should have the following objects totally:
...@@ -26,11 +33,6 @@ The configuration files should have the following objects totally: ...@@ -26,11 +33,6 @@ The configuration files should have the following objects totally:
## Optional objects: ## Optional objects:
hooks hooks
steps
max_steps
For an example configuration, please see:
bob.learn.tensorflow/bob/learn/tensorflow/examples/mnist/mnist_config.py
""" """
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
...@@ -38,6 +40,10 @@ from __future__ import print_function ...@@ -38,6 +40,10 @@ from __future__ import print_function
# import pkg_resources so that bob imports work properly: # import pkg_resources so that bob imports work properly:
import pkg_resources import pkg_resources
from bob.bio.base.utils import read_config_file from bob.bio.base.utils import read_config_file
from bob.learn.tensorflow.utils.commandline import \
get_from_config_or_commandline
from bob.core.log import setup, set_verbosity_level
logger = setup(__name__)
def main(argv=None): def main(argv=None):
...@@ -46,17 +52,27 @@ def main(argv=None): ...@@ -46,17 +52,27 @@ def main(argv=None):
import sys import sys
docs = __doc__ % {'prog': os.path.basename(sys.argv[0])} docs = __doc__ % {'prog': os.path.basename(sys.argv[0])}
version = pkg_resources.require('bob.learn.tensorflow')[0].version version = pkg_resources.require('bob.learn.tensorflow')[0].version
defaults = docopt(docs, argv=[""])
args = docopt(docs, argv=argv, version=version) args = docopt(docs, argv=argv, version=version)
config_files = args['<config_files>'] config_files = args['<config_files>']
config = read_config_file(config_files) config = read_config_file(config_files)
# optional arguments
verbosity = get_from_config_or_commandline(
config, 'verbose', args, defaults)
max_steps = get_from_config_or_commandline(
config, 'max_steps', args, defaults)
steps = get_from_config_or_commandline(
config, 'steps', args, defaults)
hooks = getattr(config, 'hooks', None)
# Sets-up logging
set_verbosity_level(logger, verbosity)
# required arguments
estimator = config.estimator estimator = config.estimator
train_input_fn = config.train_input_fn train_input_fn = config.train_input_fn
hooks = getattr(config, 'hooks', None)
steps = getattr(config, 'steps', None)
max_steps = getattr(config, 'max_steps', None)
# Train # Train
estimator.train(input_fn=train_input_fn, hooks=hooks, steps=steps, estimator.train(input_fn=train_input_fn, hooks=hooks, steps=steps,
max_steps=max_steps) max_steps=max_steps)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment