From 9f3b752be42e67885970ac5ec2814da6808a1b01 Mon Sep 17 00:00:00 2001
From: Amir MOHAMMADI <amir.mohammadi@idiap.ch>
Date: Wed, 17 Apr 2019 16:24:01 +0200
Subject: [PATCH] Add bob tf predict command

---
 bob/learn/tensorflow/script/predict_bio.py | 112 ++++++++++++++++-----
 1 file changed, 87 insertions(+), 25 deletions(-)

diff --git a/bob/learn/tensorflow/script/predict_bio.py b/bob/learn/tensorflow/script/predict_bio.py
index f13081ca..29f8343a 100644
--- a/bob/learn/tensorflow/script/predict_bio.py
+++ b/bob/learn/tensorflow/script/predict_bio.py
@@ -55,7 +55,7 @@ def non_existing_files(paths, force=False):
             yield i
 
 
-def save_predictions(pool, output_dir, key, pred_buffer, video_container):
+def save_predictions(output_dir, key, pred_buffer, video_container):
     outpath = make_output_path(output_dir, key)
     create_directories_safe(os.path.dirname(outpath))
     logger.debug("Saving predictions for %s", key)
@@ -66,7 +66,7 @@ def save_predictions(pool, output_dir, key, pred_buffer, video_container):
         data = fc
     else:
         data = np.mean(pred_buffer[key], axis=0)
-    pool.apply_async(save, (data, outpath))
+    save(data, outpath)
 
 
 @click.command(
@@ -247,6 +247,68 @@ def predict_bio(estimator, database, biofiles, bio_predict_input_fn,
         video_container=video_container)
 
 
+@click.command(
+    entry_point_group='bob.learn.tensorflow.config', cls=ConfigCommand)
+@click.option(
+    '--estimator',
+    '-e',
+    required=True,
+    cls=ResourceOption,
+    entry_point_group='bob.learn.tensorflow.estimator',
+    help='The estimator that will be evaluated.')
+@click.option(
+    '--predict-input-fn',
+    required=True,
+    cls=ResourceOption,
+    entry_point_group='bob.learn.tensorflow.input_fn',
+    help='A callable with no arguments which will be used in estimator.predict.')
+@click.option(
+    '--output-dir',
+    '-o',
+    required=True,
+    cls=ResourceOption,
+    help='The directory to save the predictions.')
+@click.option(
+    '--predict-keys',
+    '-k',
+    multiple=True,
+    default=None,
+    cls=ResourceOption,
+    help='List of `str`, name of the keys to predict. It is used if the '
+    '`EstimatorSpec.predictions` is a `dict`. If `predict_keys` is used '
+    'then rest of the predictions will be filtered from the dictionary. '
+    'If `None`, returns all.')
+@click.option(
+    '--checkpoint-path',
+    '-c',
+    cls=ResourceOption,
+    help='Path of a specific checkpoint to predict. If `None`, the '
+    'latest checkpoint in `model_dir` is used. This can also '
+    'be a folder which contains a "checkpoint" file where the '
+    'latest checkpoint from inside this file will be used as '
+    'checkpoint_path.')
+@click.option(
+    '--hooks',
+    cls=ResourceOption,
+    multiple=True,
+    entry_point_group='bob.learn.tensorflow.hook',
+    help='List of SessionRunHook subclass instances.')
+@click.option(
+    '--video-container',
+    '-vc',
+    is_flag=True,
+    cls=ResourceOption,
+    help='If provided, the predictions will be written in FrameContainers from'
+    ' bob.bio.video. You need to install bob.bio.video as well.')
+@verbosity_option(cls=ResourceOption)
+def predict(estimator, predict_input_fn, output_dir, predict_keys,
+            checkpoint_path, hooks, video_container, **kwargs):
+    # CLI wrapper: forwards the options to generic_predict; kwargs is unused.
+    generic_predict(
+        estimator, predict_input_fn, output_dir, predict_keys,
+        checkpoint_path, hooks, video_container)
+
+
 def generic_predict(estimator, predict_input_fn, output_dir, predict_keys=None,
                     checkpoint_path=None, hooks=None, video_container=False):
     # if the checkpoint_path is a directory, pick the latest checkpoint from
@@ -273,28 +335,28 @@ def generic_predict(estimator, predict_input_fn, output_dir, predict_keys=None,
             raise click.ClickException(
                 'Could not import bob.bio.video. Have you installed it?')
 
-    pool = Pool()
+    pred_buffer = defaultdict(list)
+    for i, pred in enumerate(predictions):
+        key = pred['key']
+        # key is in bytes format in Python 3
+        if sys.version_info >= (3, ):
+            key = key.decode(errors='replace')
+        prob = pred.get('probabilities', pred.get(
+            'embeddings', pred.get('predictions')))
+        assert prob is not None
+        pred_buffer[key].append(prob)
+        if i == 0:
+            last_key = key
+        if last_key == key:
+            continue
+        else:
+            save_predictions(
+                output_dir, last_key, pred_buffer, video_container)
+            last_key = key
     try:
-        pred_buffer = defaultdict(list)
-        for i, pred in enumerate(predictions):
-            key = pred['key']
-            # key is in bytes format in Python 3
-            if sys.version_info >= (3, ):
-                key = key.decode(errors='replace')
-            prob = pred.get('probabilities', pred.get(
-                'embeddings', pred.get('predictions')))
-            assert prob is not None
-            pred_buffer[key].append(prob)
-            if i == 0:
-                last_key = key
-            if last_key == key:
-                continue
-            else:
-                save_predictions(
-                    pool, output_dir, last_key, pred_buffer, video_container)
-                last_key = key
+        key
         # save the final returned key as well:
-        save_predictions(pool, output_dir, key, pred_buffer, video_container)
-    finally:
-        pool.close()
-        pool.join()
+        save_predictions(output_dir, key, pred_buffer, video_container)
+    except UnboundLocalError:
+        # if the input_fn was empty and hence key is not defined
+        pass
-- 
GitLab