From c3c9e9a1561f74e6b13db81f525fefbaa3cc7337 Mon Sep 17 00:00:00 2001 From: Amir MOHAMMADI <amir.mohammadi@idiap.ch> Date: Wed, 17 Apr 2019 16:29:24 +0200 Subject: [PATCH] Make bob tf cache command useful --- bob/learn/tensorflow/script/cache_dataset.py | 31 +++++++++++++++----- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/bob/learn/tensorflow/script/cache_dataset.py b/bob/learn/tensorflow/script/cache_dataset.py index 06533e9d..7bb443af 100644 --- a/bob/learn/tensorflow/script/cache_dataset.py +++ b/bob/learn/tensorflow/script/cache_dataset.py @@ -9,6 +9,7 @@ import click import tensorflow as tf from bob.extension.scripts.click_helper import ( verbosity_option, ConfigCommand, ResourceOption, log_parameters) +from bob.bio.base import is_argument_available logger = logging.getLogger(__name__) @@ -23,21 +24,37 @@ logger = logging.getLogger(__name__) entry_point_group='bob.learn.tensorflow.input_fn', help='The ``input_fn`` that will return the features and labels. ' 'You should call the dataset.cache(...) yourself in the input ' - 'function.') + 'function. If the ``input_fn`` accepts a ``cache_only`` argument, ' + 'it will be given as True.') @click.option( '--mode', cls=ResourceOption, - default='train', + default=tf.estimator.ModeKeys.TRAIN, show_default=True, - help='One of the tf.estimator.ModeKeys.{TRAIN,EVAL,PREDICT} values to be ' - 'given to the input_fn.') + type=click.Choice((tf.estimator.ModeKeys.TRAIN, + tf.estimator.ModeKeys.EVAL, + tf.estimator.ModeKeys.PREDICT)), + help='mode value to be given to the input_fn.') @verbosity_option(cls=ResourceOption) def cache_dataset(input_fn, mode, **kwargs): """Trains networks using Tensorflow estimators.""" log_parameters(logger) + kwargs = {} + if is_argument_available('cache_only', input_fn): + kwargs['cache_only'] = True + # call the input function manually with tf.Session() as sess: - data = input_fn(mode) - while True: - sess.run(data) + data = input_fn(mode, **kwargs) + if isinstance(data, tf.data.Dataset): + iterator = data.make_initializable_iterator() + data = iterator.get_next() + sess.run(iterator.initializer) + sess.run(tf.initializers.global_variables()) + try: + while True: + sess.run(data) + except tf.errors.OutOfRangeError: + click.echo("Finished reading the dataset.") + return -- GitLab