From c3c9e9a1561f74e6b13db81f525fefbaa3cc7337 Mon Sep 17 00:00:00 2001
From: Amir MOHAMMADI <amir.mohammadi@idiap.ch>
Date: Wed, 17 Apr 2019 16:29:24 +0200
Subject: [PATCH] Make bob tf cache command useful

---
 bob/learn/tensorflow/script/cache_dataset.py | 31 +++++++++++++++-----
 1 file changed, 24 insertions(+), 7 deletions(-)

diff --git a/bob/learn/tensorflow/script/cache_dataset.py b/bob/learn/tensorflow/script/cache_dataset.py
index 06533e9d..7bb443af 100644
--- a/bob/learn/tensorflow/script/cache_dataset.py
+++ b/bob/learn/tensorflow/script/cache_dataset.py
@@ -9,6 +9,7 @@ import click
 import tensorflow as tf
 from bob.extension.scripts.click_helper import (
     verbosity_option, ConfigCommand, ResourceOption, log_parameters)
+from bob.bio.base import is_argument_available
 
 logger = logging.getLogger(__name__)
 
@@ -23,21 +24,37 @@ logger = logging.getLogger(__name__)
     entry_point_group='bob.learn.tensorflow.input_fn',
     help='The ``input_fn`` that will return the features and labels. '
          'You should call the dataset.cache(...) yourself in the input '
-         'function.')
+         'function. If the ``input_fn`` accepts a ``cache_only`` argument, '
+         'it will be given as True.')
 @click.option(
     '--mode',
     cls=ResourceOption,
-    default='train',
+    default=tf.estimator.ModeKeys.TRAIN,
     show_default=True,
-    help='One of the tf.estimator.ModeKeys.{TRAIN,EVAL,PREDICT} values to be '
-    'given to the input_fn.')
+    type=click.Choice((tf.estimator.ModeKeys.TRAIN,
+                       tf.estimator.ModeKeys.EVAL,
+                       tf.estimator.ModeKeys.PREDICT)),
+    help='mode value to be given to the input_fn.')
 @verbosity_option(cls=ResourceOption)
 def cache_dataset(input_fn, mode, **kwargs):
     """Trains networks using Tensorflow estimators."""
     log_parameters(logger)
 
+    kwargs = {}
+    if is_argument_available('cache_only', input_fn):
+        kwargs['cache_only'] = True
+
     # call the input function manually
     with tf.Session() as sess:
-        data = input_fn(mode)
-        while True:
-            sess.run(data)
+        data = input_fn(mode, **kwargs)
+        if isinstance(data, tf.data.Dataset):
+            iterator = data.make_initializable_iterator()
+            data = iterator.get_next()
+            sess.run(iterator.initializer)
+        sess.run(tf.initializers.global_variables())
+        try:
+            while True:
+                sess.run(data)
+        except tf.errors.OutOfRangeError:
+            click.echo("Finished reading the dataset.")
+            return
-- 
GitLab