Skip to content
Snippets Groups Projects
Commit 9f3b752b authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Add bob tf predict command

parent 5ad23cfd
No related branches found
No related tags found
1 merge request!75A lot of new features
...@@ -55,7 +55,7 @@ def non_existing_files(paths, force=False): ...@@ -55,7 +55,7 @@ def non_existing_files(paths, force=False):
yield i yield i
def save_predictions(pool, output_dir, key, pred_buffer, video_container): def save_predictions(output_dir, key, pred_buffer, video_container):
outpath = make_output_path(output_dir, key) outpath = make_output_path(output_dir, key)
create_directories_safe(os.path.dirname(outpath)) create_directories_safe(os.path.dirname(outpath))
logger.debug("Saving predictions for %s", key) logger.debug("Saving predictions for %s", key)
...@@ -66,7 +66,7 @@ def save_predictions(pool, output_dir, key, pred_buffer, video_container): ...@@ -66,7 +66,7 @@ def save_predictions(pool, output_dir, key, pred_buffer, video_container):
data = fc data = fc
else: else:
data = np.mean(pred_buffer[key], axis=0) data = np.mean(pred_buffer[key], axis=0)
pool.apply_async(save, (data, outpath)) save(data, outpath)
@click.command( @click.command(
...@@ -247,6 +247,68 @@ def predict_bio(estimator, database, biofiles, bio_predict_input_fn, ...@@ -247,6 +247,68 @@ def predict_bio(estimator, database, biofiles, bio_predict_input_fn,
video_container=video_container) video_container=video_container)
@click.command(
entry_point_group='bob.learn.tensorflow.config', cls=ConfigCommand)
@click.option(
'--estimator',
'-e',
required=True,
cls=ResourceOption,
entry_point_group='bob.learn.tensorflow.estimator',
help='The estimator that will be evaluated.')
@click.option(
'--predict-input-fn',
required=True,
cls=ResourceOption,
entry_point_group='bob.learn.tensorflow.input_fn',
help='A callable with no arguments which will be used in estimator.predict.')
@click.option(
'--output-dir',
'-o',
required=True,
cls=ResourceOption,
help='The directory to save the predictions.')
@click.option(
'--predict-keys',
'-k',
multiple=True,
default=None,
cls=ResourceOption,
help='List of `str`, name of the keys to predict. It is used if the '
'`EstimatorSpec.predictions` is a `dict`. If `predict_keys` is used '
'then rest of the predictions will be filtered from the dictionary. '
'If `None`, returns all.')
@click.option(
'--checkpoint-path',
'-c',
cls=ResourceOption,
help='Path of a specific checkpoint to predict. If `None`, the '
'latest checkpoint in `model_dir` is used. This can also '
'be a folder which contains a "checkpoint" file where the '
'latest checkpoint from inside this file will be used as '
'checkpoint_path.')
@click.option(
'--hooks',
cls=ResourceOption,
multiple=True,
entry_point_group='bob.learn.tensorflow.hook',
help='List of SessionRunHook subclass instances.')
@click.option(
'--video-container',
'-vc',
is_flag=True,
cls=ResourceOption,
help='If provided, the predictions will be written in FrameContainers from'
' bob.bio.video. You need to install bob.bio.video as well.')
@verbosity_option(cls=ResourceOption)
def predict(estimator, predict_input_fn, output_dir, predict_keys,
checkpoint_path, hooks, video_container, **kwargs):
generic_predict(
estimator, predict_input_fn, output_dir, predict_keys,
checkpoint_path, hooks, video_container)
def generic_predict(estimator, predict_input_fn, output_dir, predict_keys=None, def generic_predict(estimator, predict_input_fn, output_dir, predict_keys=None,
checkpoint_path=None, hooks=None, video_container=False): checkpoint_path=None, hooks=None, video_container=False):
# if the checkpoint_path is a directory, pick the latest checkpoint from # if the checkpoint_path is a directory, pick the latest checkpoint from
...@@ -273,28 +335,28 @@ def generic_predict(estimator, predict_input_fn, output_dir, predict_keys=None, ...@@ -273,28 +335,28 @@ def generic_predict(estimator, predict_input_fn, output_dir, predict_keys=None,
raise click.ClickException( raise click.ClickException(
'Could not import bob.bio.video. Have you installed it?') 'Could not import bob.bio.video. Have you installed it?')
pool = Pool() pred_buffer = defaultdict(list)
for i, pred in enumerate(predictions):
key = pred['key']
# key is in bytes format in Python 3
if sys.version_info >= (3, ):
key = key.decode(errors='replace')
prob = pred.get('probabilities', pred.get(
'embeddings', pred.get('predictions')))
assert prob is not None
pred_buffer[key].append(prob)
if i == 0:
last_key = key
if last_key == key:
continue
else:
save_predictions(
output_dir, last_key, pred_buffer, video_container)
last_key = key
try: try:
pred_buffer = defaultdict(list) key
for i, pred in enumerate(predictions):
key = pred['key']
# key is in bytes format in Python 3
if sys.version_info >= (3, ):
key = key.decode(errors='replace')
prob = pred.get('probabilities', pred.get(
'embeddings', pred.get('predictions')))
assert prob is not None
pred_buffer[key].append(prob)
if i == 0:
last_key = key
if last_key == key:
continue
else:
save_predictions(
pool, output_dir, last_key, pred_buffer, video_container)
last_key = key
# save the final returned key as well: # save the final returned key as well:
save_predictions(pool, output_dir, key, pred_buffer, video_container) save_predictions(output_dir, key, pred_buffer, video_container)
finally: except UnboundLocalError:
pool.close() # if the input_fn was empty and hence key is not defined
pool.join() pass
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment