Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
bob
bob.learn.tensorflow
Commits
c3c9e9a1
Commit
c3c9e9a1
authored
Apr 17, 2019
by
Amir MOHAMMADI
Browse files
Make bob tf cache command useful
parent
9921e122
Changes
1
Hide whitespace changes
Inline
Side-by-side
bob/learn/tensorflow/script/cache_dataset.py
View file @
c3c9e9a1
...
...
@@ -9,6 +9,7 @@ import click
import
tensorflow
as
tf
from
bob.extension.scripts.click_helper
import
(
verbosity_option
,
ConfigCommand
,
ResourceOption
,
log_parameters
)
from
bob.bio.base
import
is_argument_available
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -23,21 +24,37 @@ logger = logging.getLogger(__name__)
entry_point_group
=
'bob.learn.tensorflow.input_fn'
,
help
=
'The ``input_fn`` that will return the features and labels. '
'You should call the dataset.cache(...) yourself in the input '
'function.'
)
'function. If the ``input_fn`` accepts a ``cache_only`` argument, '
'it will be given as True.'
)
@
click
.
option
(
'--mode'
,
cls
=
ResourceOption
,
default
=
'train'
,
default
=
tf
.
estimator
.
ModeKeys
.
TRAIN
,
show_default
=
True
,
help
=
'One of the tf.estimator.ModeKeys.{TRAIN,EVAL,PREDICT} values to be '
'given to the input_fn.'
)
type
=
click
.
Choice
((
tf
.
estimator
.
ModeKeys
.
TRAIN
,
tf
.
estimator
.
ModeKeys
.
EVAL
,
tf
.
estimator
.
ModeKeys
.
PREDICT
)),
help
=
'mode value to be given to the input_fn.'
)
@
verbosity_option
(
cls
=
ResourceOption
)
def
cache_dataset
(
input_fn
,
mode
,
**
kwargs
):
"""Trains networks using Tensorflow estimators."""
log_parameters
(
logger
)
kwargs
=
{}
if
is_argument_available
(
'cache_only'
,
input_fn
):
kwargs
[
'cache_only'
]
=
True
# call the input function manually
with
tf
.
Session
()
as
sess
:
data
=
input_fn
(
mode
)
while
True
:
sess
.
run
(
data
)
data
=
input_fn
(
mode
,
**
kwargs
)
if
isinstance
(
data
,
tf
.
data
.
Dataset
):
iterator
=
data
.
make_initializable_iterator
()
data
=
iterator
.
get_next
()
sess
.
run
(
iterator
.
initializer
)
sess
.
run
(
tf
.
initializers
.
global_variables
())
try
:
while
True
:
sess
.
run
(
data
)
except
tf
.
errors
.
OutOfRangeError
:
click
.
echo
(
"Finished reading the dataset."
)
return
Tiago de Freitas Pereira
@tiago.pereira
·
Apr 23, 2019
Owner
Looks good for me. thanks
Looks good for me. thanks
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment