From 9f3b752be42e67885970ac5ec2814da6808a1b01 Mon Sep 17 00:00:00 2001 From: Amir MOHAMMADI <amir.mohammadi@idiap.ch> Date: Wed, 17 Apr 2019 16:24:01 +0200 Subject: [PATCH] Add bob tf predict command --- bob/learn/tensorflow/script/predict_bio.py | 112 ++++++++++++++++----- 1 file changed, 87 insertions(+), 25 deletions(-) diff --git a/bob/learn/tensorflow/script/predict_bio.py b/bob/learn/tensorflow/script/predict_bio.py index f13081ca..29f8343a 100644 --- a/bob/learn/tensorflow/script/predict_bio.py +++ b/bob/learn/tensorflow/script/predict_bio.py @@ -55,7 +55,7 @@ def non_existing_files(paths, force=False): yield i -def save_predictions(pool, output_dir, key, pred_buffer, video_container): +def save_predictions(output_dir, key, pred_buffer, video_container): outpath = make_output_path(output_dir, key) create_directories_safe(os.path.dirname(outpath)) logger.debug("Saving predictions for %s", key) @@ -66,7 +66,7 @@ def save_predictions(pool, output_dir, key, pred_buffer, video_container): data = fc else: data = np.mean(pred_buffer[key], axis=0) - pool.apply_async(save, (data, outpath)) + save(data, outpath) @click.command( @@ -247,6 +247,68 @@ def predict_bio(estimator, database, biofiles, bio_predict_input_fn, video_container=video_container) +@click.command( + entry_point_group='bob.learn.tensorflow.config', cls=ConfigCommand) +@click.option( + '--estimator', + '-e', + required=True, + cls=ResourceOption, + entry_point_group='bob.learn.tensorflow.estimator', + help='The estimator that will be evaluated.') +@click.option( + '--predict-input-fn', + required=True, + cls=ResourceOption, + entry_point_group='bob.learn.tensorflow.input_fn', + help='A callable with no arguments which will be used in estimator.predict.') +@click.option( + '--output-dir', + '-o', + required=True, + cls=ResourceOption, + help='The directory to save the predictions.') +@click.option( + '--predict-keys', + '-k', + multiple=True, + default=None, + cls=ResourceOption, + help='List of `str`, name of the keys to predict. It is used if the ' + '`EstimatorSpec.predictions` is a `dict`. If `predict_keys` is used ' + 'then rest of the predictions will be filtered from the dictionary. ' + 'If `None`, returns all.') +@click.option( + '--checkpoint-path', + '-c', + cls=ResourceOption, + help='Path of a specific checkpoint to predict. If `None`, the ' + 'latest checkpoint in `model_dir` is used. This can also ' + 'be a folder which contains a "checkpoint" file where the ' + 'latest checkpoint from inside this file will be used as ' + 'checkpoint_path.') +@click.option( + '--hooks', + cls=ResourceOption, + multiple=True, + entry_point_group='bob.learn.tensorflow.hook', + help='List of SessionRunHook subclass instances.') +@click.option( + '--video-container', + '-vc', + is_flag=True, + cls=ResourceOption, + help='If provided, the predictions will be written in FrameContainers from' + ' bob.bio.video. You need to install bob.bio.video as well.') +@verbosity_option(cls=ResourceOption) +def predict(estimator, predict_input_fn, output_dir, predict_keys, + checkpoint_path, hooks, video_container, **kwargs): + + generic_predict( + estimator, predict_input_fn, output_dir, predict_keys, + checkpoint_path, hooks, video_container) + + def generic_predict(estimator, predict_input_fn, output_dir, predict_keys=None, checkpoint_path=None, hooks=None, video_container=False): # if the checkpoint_path is a directory, pick the latest checkpoint from @@ -273,28 +335,28 @@ def generic_predict(estimator, predict_input_fn, output_dir, predict_keys=None, raise click.ClickException( 'Could not import bob.bio.video. Have you installed it?') - pool = Pool() + pred_buffer = defaultdict(list) + for i, pred in enumerate(predictions): + key = pred['key'] + # key is in bytes format in Python 3 + if sys.version_info >= (3, ): + key = key.decode(errors='replace') + prob = pred.get('probabilities', pred.get( + 'embeddings', pred.get('predictions'))) + assert prob is not None + pred_buffer[key].append(prob) + if i == 0: + last_key = key + if last_key == key: + continue + else: + save_predictions( + output_dir, last_key, pred_buffer, video_container) + last_key = key try: - pred_buffer = defaultdict(list) - for i, pred in enumerate(predictions): - key = pred['key'] - # key is in bytes format in Python 3 - if sys.version_info >= (3, ): - key = key.decode(errors='replace') - prob = pred.get('probabilities', pred.get( - 'embeddings', pred.get('predictions'))) - assert prob is not None - pred_buffer[key].append(prob) - if i == 0: - last_key = key - if last_key == key: - continue - else: - save_predictions( - pool, output_dir, last_key, pred_buffer, video_container) - last_key = key + key # save the final returned key as well: - save_predictions(pool, output_dir, key, pred_buffer, video_container) - finally: - pool.close() - pool.join() + save_predictions(output_dir, key, pred_buffer, video_container) + except UnboundLocalError: + # if the input_fn was empty and hence key is not defined + pass -- GitLab