Skip to content
Snippets Groups Projects

Porting to TF2

Merged Tiago de Freitas Pereira requested to merge tf2 into master
1 unresolved thread
9 files
+ 0
691
Compare changes
  • Side-by-side
  • Inline
Files
9
#!/usr/bin/env python
"""Trains networks using Tensorflow estimators.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import click
import tensorflow as tf
from bob.extension.scripts.click_helper import ConfigCommand
from bob.extension.scripts.click_helper import ResourceOption
from bob.extension.scripts.click_helper import log_parameters
from bob.extension.scripts.click_helper import verbosity_option
from ..utils import is_argument_available
logger = logging.getLogger(__name__)
@click.command(entry_point_group="bob.learn.tensorflow.config", cls=ConfigCommand)
@click.option(
"--input-fn",
"-i",
required=True,
cls=ResourceOption,
entry_point_group="bob.learn.tensorflow.input_fn",
help="The ``input_fn`` that will return the features and labels. "
"You should call the dataset.cache(...) yourself in the input "
"function. If the ``input_fn`` accepts a ``cache_only`` argument, "
"it will be given as True.",
)
@click.option(
"--mode",
cls=ResourceOption,
default=tf.estimator.ModeKeys.TRAIN,
show_default=True,
type=click.Choice(
(
tf.estimator.ModeKeys.TRAIN,
tf.estimator.ModeKeys.EVAL,
tf.estimator.ModeKeys.PREDICT,
)
),
help="mode value to be given to the input_fn.",
)
@verbosity_option(cls=ResourceOption)
def cache_dataset(input_fn, mode, **kwargs):
"""Trains networks using Tensorflow estimators."""
log_parameters(logger)
kwargs = {}
if is_argument_available("cache_only", input_fn):
kwargs["cache_only"] = True
logger.info("cache_only as True will be passed to input_fn.")
# call the input function manually
with tf.compat.v1.Session() as sess:
data = input_fn(mode, **kwargs)
if isinstance(data, tf.data.Dataset):
iterator = tf.compat.v1.data.make_initializable_iterator(data)
data = iterator.get_next()
sess.run(iterator.initializer)
sess.run(tf.compat.v1.initializers.global_variables())
try:
while True:
sess.run(data)
except tf.errors.OutOfRangeError:
click.echo("Finished reading the dataset.")
return
Loading