diff --git a/bob/learn/tensorflow/examples/mnist/mnist_config.py b/bob/learn/tensorflow/examples/mnist/mnist_config.py index e28209cda2e5012c0890860d785abd1e438e137e..9991e218e4bf454d28fadc2e1c3d6eaa52c89067 100644 --- a/bob/learn/tensorflow/examples/mnist/mnist_config.py +++ b/bob/learn/tensorflow/examples/mnist/mnist_config.py @@ -17,9 +17,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# by default create reproducible nets: +# create reproducible nets: from bob.learn.tensorflow.utils.reproducible import run_config -# utils.reproducible import run_config import tensorflow as tf from bob.db.mnist import Database diff --git a/bob/learn/tensorflow/script/eval_generic.py b/bob/learn/tensorflow/script/eval_generic.py index f29f756707c3c643711fbb6de9062dd3adb60aba..e8432aa3cb946e29a460dbace75b25908032784e 100644 --- a/bob/learn/tensorflow/script/eval_generic.py +++ b/bob/learn/tensorflow/script/eval_generic.py @@ -63,7 +63,7 @@ def main(argv=None): model_fn = config.model_fn eval_input_fn = config.eval_input_fn - eval_interval_secs = getattr(config, 'eval_interval_secs', 300) + eval_interval_secs = getattr(config, 'eval_interval_secs', 60) run_once = getattr(config, 'run_once', False) run_config = getattr(config, 'run_config', None) model_params = getattr(config, 'model_params', None) @@ -75,7 +75,7 @@ def main(argv=None): nn = tf.estimator.Estimator(model_fn=model_fn, model_dir=model_dir, params=model_params, config=run_config) if name: - real_name = name + '_eval' + real_name = 'eval_' + name else: real_name = 'eval' evaluated_file = os.path.join(nn.model_dir, real_name, 'evaluated') @@ -91,7 +91,12 @@ def main(argv=None): continue for checkpoint_path in ckpt.all_model_checkpoint_paths: - global_step = str(get_global_step(checkpoint_path)) + try: + global_step = str(get_global_step(checkpoint_path)) + except Exception: + print('Failed to find global_step for checkpoint_path {}, ' + 'skipping ...'.format(checkpoint_path)) + continue if global_step in evaluated_steps: continue diff --git a/bob/learn/tensorflow/script/train_generic.py b/bob/learn/tensorflow/script/train_generic.py index 11f7d18a421b8c5ef48196ca254116722f8c5138..e60e4917baf349422a83360220faf2722aaed66f 100644 --- a/bob/learn/tensorflow/script/train_generic.py +++ b/bob/learn/tensorflow/script/train_generic.py @@ -66,9 +66,7 @@ def main(argv=None): if run_config is None: # by default create reproducible nets: - from bob.learn.tensorflow.utils.reproducible import session_conf - run_config = tf.estimator.RunConfig() - run_config.replace(session_config=session_conf) + from bob.learn.tensorflow.utils.reproducible import run_config # Instantiate Estimator nn = tf.estimator.Estimator(model_fn=model_fn, model_dir=model_dir,