Commit 9f3b752b authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

Add bob tf predict command

parent 5ad23cfd
......@@ -55,7 +55,7 @@ def non_existing_files(paths, force=False):
yield i
def save_predictions(pool, output_dir, key, pred_buffer, video_container):
def save_predictions(output_dir, key, pred_buffer, video_container):
outpath = make_output_path(output_dir, key)
create_directories_safe(os.path.dirname(outpath))
logger.debug("Saving predictions for %s", key)
......@@ -66,7 +66,7 @@ def save_predictions(pool, output_dir, key, pred_buffer, video_container):
data = fc
else:
data = np.mean(pred_buffer[key], axis=0)
pool.apply_async(save, (data, outpath))
save(data, outpath)
@click.command(
......@@ -247,6 +247,68 @@ def predict_bio(estimator, database, biofiles, bio_predict_input_fn,
video_container=video_container)
@click.command(
entry_point_group='bob.learn.tensorflow.config', cls=ConfigCommand)
@click.option(
'--estimator',
'-e',
required=True,
cls=ResourceOption,
entry_point_group='bob.learn.tensorflow.estimator',
help='The estimator that will be evaluated.')
@click.option(
'--predict-input-fn',
required=True,
cls=ResourceOption,
entry_point_group='bob.learn.tensorflow.input_fn',
help='A callable with no arguments which will be used in estimator.predict.')
@click.option(
'--output-dir',
'-o',
required=True,
cls=ResourceOption,
help='The directory to save the predictions.')
@click.option(
'--predict-keys',
'-k',
multiple=True,
default=None,
cls=ResourceOption,
help='List of `str`, name of the keys to predict. It is used if the '
'`EstimatorSpec.predictions` is a `dict`. If `predict_keys` is used '
'then rest of the predictions will be filtered from the dictionary. '
'If `None`, returns all.')
@click.option(
'--checkpoint-path',
'-c',
cls=ResourceOption,
help='Path of a specific checkpoint to predict. If `None`, the '
'latest checkpoint in `model_dir` is used. This can also '
'be a folder which contains a "checkpoint" file where the '
'latest checkpoint from inside this file will be used as '
'checkpoint_path.')
@click.option(
'--hooks',
cls=ResourceOption,
multiple=True,
entry_point_group='bob.learn.tensorflow.hook',
help='List of SessionRunHook subclass instances.')
@click.option(
'--video-container',
'-vc',
is_flag=True,
cls=ResourceOption,
help='If provided, the predictions will be written in FrameContainers from'
' bob.bio.video. You need to install bob.bio.video as well.')
@verbosity_option(cls=ResourceOption)
def predict(estimator, predict_input_fn, output_dir, predict_keys,
checkpoint_path, hooks, video_container, **kwargs):
generic_predict(
estimator, predict_input_fn, output_dir, predict_keys,
checkpoint_path, hooks, video_container)
def generic_predict(estimator, predict_input_fn, output_dir, predict_keys=None,
checkpoint_path=None, hooks=None, video_container=False):
# if the checkpoint_path is a directory, pick the latest checkpoint from
......@@ -273,28 +335,28 @@ def generic_predict(estimator, predict_input_fn, output_dir, predict_keys=None,
raise click.ClickException(
'Could not import bob.bio.video. Have you installed it?')
pool = Pool()
pred_buffer = defaultdict(list)
for i, pred in enumerate(predictions):
key = pred['key']
# key is in bytes format in Python 3
if sys.version_info >= (3, ):
key = key.decode(errors='replace')
prob = pred.get('probabilities', pred.get(
'embeddings', pred.get('predictions')))
assert prob is not None
pred_buffer[key].append(prob)
if i == 0:
last_key = key
if last_key == key:
continue
else:
save_predictions(
output_dir, last_key, pred_buffer, video_container)
last_key = key
try:
pred_buffer = defaultdict(list)
for i, pred in enumerate(predictions):
key = pred['key']
# key is in bytes format in Python 3
if sys.version_info >= (3, ):
key = key.decode(errors='replace')
prob = pred.get('probabilities', pred.get(
'embeddings', pred.get('predictions')))
assert prob is not None
pred_buffer[key].append(prob)
if i == 0:
last_key = key
if last_key == key:
continue
else:
save_predictions(
pool, output_dir, last_key, pred_buffer, video_container)
last_key = key
key
# save the final returned key as well:
save_predictions(pool, output_dir, key, pred_buffer, video_container)
finally:
pool.close()
pool.join()
save_predictions(output_dir, key, pred_buffer, video_container)
except UnboundLocalError:
# if the input_fn was empty and hence key is not defined
pass
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment