Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
bob
bob.learn.tensorflow
Commits
9f3b752b
Commit
9f3b752b
authored
Apr 17, 2019
by
Amir MOHAMMADI
Browse files
Add bob tf predict command
parent
5ad23cfd
Changes
1
Hide whitespace changes
Inline
Side-by-side
bob/learn/tensorflow/script/predict_bio.py
View file @
9f3b752b
...
...
@@ -55,7 +55,7 @@ def non_existing_files(paths, force=False):
yield
i
def
save_predictions
(
pool
,
output_dir
,
key
,
pred_buffer
,
video_container
):
def
save_predictions
(
output_dir
,
key
,
pred_buffer
,
video_container
):
outpath
=
make_output_path
(
output_dir
,
key
)
create_directories_safe
(
os
.
path
.
dirname
(
outpath
))
logger
.
debug
(
"Saving predictions for %s"
,
key
)
...
...
@@ -66,7 +66,7 @@ def save_predictions(pool, output_dir, key, pred_buffer, video_container):
data
=
fc
else
:
data
=
np
.
mean
(
pred_buffer
[
key
],
axis
=
0
)
pool
.
apply_async
(
save
,
(
data
,
outpath
)
)
save
(
data
,
outpath
)
@
click
.
command
(
...
...
@@ -247,6 +247,68 @@ def predict_bio(estimator, database, biofiles, bio_predict_input_fn,
video_container
=
video_container
)
@
click
.
command
(
entry_point_group
=
'bob.learn.tensorflow.config'
,
cls
=
ConfigCommand
)
@
click
.
option
(
'--estimator'
,
'-e'
,
required
=
True
,
cls
=
ResourceOption
,
entry_point_group
=
'bob.learn.tensorflow.estimator'
,
help
=
'The estimator that will be evaluated.'
)
@
click
.
option
(
'--predict-input-fn'
,
required
=
True
,
cls
=
ResourceOption
,
entry_point_group
=
'bob.learn.tensorflow.input_fn'
,
help
=
'A callable with no arguments which will be used in estimator.predict.'
)
@
click
.
option
(
'--output-dir'
,
'-o'
,
required
=
True
,
cls
=
ResourceOption
,
help
=
'The directory to save the predictions.'
)
@
click
.
option
(
'--predict-keys'
,
'-k'
,
multiple
=
True
,
default
=
None
,
cls
=
ResourceOption
,
help
=
'List of `str`, name of the keys to predict. It is used if the '
'`EstimatorSpec.predictions` is a `dict`. If `predict_keys` is used '
'then rest of the predictions will be filtered from the dictionary. '
'If `None`, returns all.'
)
@
click
.
option
(
'--checkpoint-path'
,
'-c'
,
cls
=
ResourceOption
,
help
=
'Path of a specific checkpoint to predict. If `None`, the '
'latest checkpoint in `model_dir` is used. This can also '
'be a folder which contains a "checkpoint" file where the '
'latest checkpoint from inside this file will be used as '
'checkpoint_path.'
)
@
click
.
option
(
'--hooks'
,
cls
=
ResourceOption
,
multiple
=
True
,
entry_point_group
=
'bob.learn.tensorflow.hook'
,
help
=
'List of SessionRunHook subclass instances.'
)
@
click
.
option
(
'--video-container'
,
'-vc'
,
is_flag
=
True
,
cls
=
ResourceOption
,
help
=
'If provided, the predictions will be written in FrameContainers from'
' bob.bio.video. You need to install bob.bio.video as well.'
)
@
verbosity_option
(
cls
=
ResourceOption
)
def
predict
(
estimator
,
predict_input_fn
,
output_dir
,
predict_keys
,
checkpoint_path
,
hooks
,
video_container
,
**
kwargs
):
generic_predict
(
estimator
,
predict_input_fn
,
output_dir
,
predict_keys
,
checkpoint_path
,
hooks
,
video_container
)
def
generic_predict
(
estimator
,
predict_input_fn
,
output_dir
,
predict_keys
=
None
,
checkpoint_path
=
None
,
hooks
=
None
,
video_container
=
False
):
# if the checkpoint_path is a directory, pick the latest checkpoint from
...
...
@@ -273,28 +335,28 @@ def generic_predict(estimator, predict_input_fn, output_dir, predict_keys=None,
raise
click
.
ClickException
(
'Could not import bob.bio.video. Have you installed it?'
)
pool
=
Pool
()
pred_buffer
=
defaultdict
(
list
)
for
i
,
pred
in
enumerate
(
predictions
):
key
=
pred
[
'key'
]
# key is in bytes format in Python 3
if
sys
.
version_info
>=
(
3
,
):
key
=
key
.
decode
(
errors
=
'replace'
)
prob
=
pred
.
get
(
'probabilities'
,
pred
.
get
(
'embeddings'
,
pred
.
get
(
'predictions'
)))
assert
prob
is
not
None
pred_buffer
[
key
].
append
(
prob
)
if
i
==
0
:
last_key
=
key
if
last_key
==
key
:
continue
else
:
save_predictions
(
output_dir
,
last_key
,
pred_buffer
,
video_container
)
last_key
=
key
try
:
pred_buffer
=
defaultdict
(
list
)
for
i
,
pred
in
enumerate
(
predictions
):
key
=
pred
[
'key'
]
# key is in bytes format in Python 3
if
sys
.
version_info
>=
(
3
,
):
key
=
key
.
decode
(
errors
=
'replace'
)
prob
=
pred
.
get
(
'probabilities'
,
pred
.
get
(
'embeddings'
,
pred
.
get
(
'predictions'
)))
assert
prob
is
not
None
pred_buffer
[
key
].
append
(
prob
)
if
i
==
0
:
last_key
=
key
if
last_key
==
key
:
continue
else
:
save_predictions
(
pool
,
output_dir
,
last_key
,
pred_buffer
,
video_container
)
last_key
=
key
key
# save the final returned key as well:
save_predictions
(
pool
,
output_dir
,
key
,
pred_buffer
,
video_container
)
finally
:
pool
.
close
()
p
ool
.
join
()
save_predictions
(
output_dir
,
key
,
pred_buffer
,
video_container
)
except
UnboundLocalError
:
# if the input_fn was empty and hence key is not defined
p
ass
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment