Commit c3c9e9a1 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

Make bob tf cache command useful

parent 9921e122
......@@ -9,6 +9,7 @@ import click
import tensorflow as tf
from bob.extension.scripts.click_helper import (
verbosity_option, ConfigCommand, ResourceOption, log_parameters)
from bob.bio.base import is_argument_available
logger = logging.getLogger(__name__)
......@@ -23,21 +24,37 @@ logger = logging.getLogger(__name__)
entry_point_group='bob.learn.tensorflow.input_fn',
help='The ``input_fn`` that will return the features and labels. '
'You should call the dataset.cache(...) yourself in the input '
'function.')
'function. If the ``input_fn`` accepts a ``cache_only`` argument, '
'it will be given as True.')
@click.option(
'--mode',
cls=ResourceOption,
default='train',
default=tf.estimator.ModeKeys.TRAIN,
show_default=True,
help='One of the tf.estimator.ModeKeys.{TRAIN,EVAL,PREDICT} values to be '
'given to the input_fn.')
type=click.Choice((tf.estimator.ModeKeys.TRAIN,
tf.estimator.ModeKeys.EVAL,
tf.estimator.ModeKeys.PREDICT)),
help='mode value to be given to the input_fn.')
@verbosity_option(cls=ResourceOption)
def cache_dataset(input_fn, mode, **kwargs):
"""Trains networks using Tensorflow estimators."""
log_parameters(logger)
kwargs = {}
if is_argument_available('cache_only', input_fn):
kwargs['cache_only'] = True
# call the input function manually
with tf.Session() as sess:
data = input_fn(mode)
while True:
sess.run(data)
data = input_fn(mode, **kwargs)
if isinstance(data, tf.data.Dataset):
iterator = data.make_initializable_iterator()
data = iterator.get_next()
sess.run(iterator.initializer)
sess.run(tf.initializers.global_variables())
try:
while True:
sess.run(data)
except tf.errors.OutOfRangeError:
click.echo("Finished reading the dataset.")
return
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment