Skip to content
Snippets Groups Projects
Commit b83ec5ad authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Add a script to cache datasets

parent 1c422bc5
No related branches found
No related tags found
1 merge request!68Several changes
#!/usr/bin/env python
"""Trains networks using Tensorflow estimators.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import click
import tensorflow as tf
from bob.extension.scripts.click_helper import (
verbosity_option, ConfigCommand, ResourceOption, log_parameters)
logger = logging.getLogger(__name__)
@click.command(
entry_point_group='bob.learn.tensorflow.config', cls=ConfigCommand)
@click.option(
'--input-fn',
'-i',
required=True,
cls=ResourceOption,
entry_point_group='bob.learn.tensorflow.input_fn',
help='The ``input_fn`` that will return the features and labels. '
'You should call the dataset.cache(...) yourself in the input '
'function.')
@click.option(
'--mode',
cls=ResourceOption,
default='train',
show_default=True,
help='One of the tf.estimator.ModeKeys.{TRAIN,EVAL,PREDICT} values to be '
'given to the input_fn.')
@verbosity_option(cls=ResourceOption)
def cache_dataset(input_fn, mode, **kwargs):
"""Trains networks using Tensorflow estimators."""
log_parameters(logger)
# call the input function manually
with tf.Session() as sess:
data = input_fn(mode)
while True:
sess.run(data)
......@@ -50,15 +50,16 @@ setup(
# bob tf scripts
'bob.learn.tensorflow.cli': [
'cache_dataset = bob.learn.tensorflow.script.cache_dataset:cache_dataset',
'compute_statistics = bob.learn.tensorflow.script.compute_statistics:compute_statistics',
'db_to_tfrecords = bob.learn.tensorflow.script.db_to_tfrecords:db_to_tfrecords',
'describe_tfrecord = bob.learn.tensorflow.script.db_to_tfrecords:describe_tfrecord',
'eval = bob.learn.tensorflow.script.eval:eval',
'trim = bob.learn.tensorflow.script.trim:trim',
'predict_bio = bob.learn.tensorflow.script.predict_bio:predict_bio',
'style_transfer = bob.learn.tensorflow.script.style_transfer:style_transfer',
'train = bob.learn.tensorflow.script.train:train',
'train_and_evaluate = bob.learn.tensorflow.script.train_and_evaluate:train_and_evaluate',
'style_transfer = bob.learn.tensorflow.script.style_transfer:style_transfer',
'trim = bob.learn.tensorflow.script.trim:trim',
],
},
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment