Skip to content
Snippets Groups Projects
Commit c01b17ba authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Merge branch 'defaultrunconfig' into predict

parents 8c9944fc efa29019
Branches
Tags
2 merge requests!22Add a prediction script,!21Resolve "Adopt to the Estimators API"
Pipeline #
...@@ -17,9 +17,8 @@ from __future__ import absolute_import ...@@ -17,9 +17,8 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
# by default create reproducible nets: # create reproducible nets:
from bob.learn.tensorflow.utils.reproducible import run_config from bob.learn.tensorflow.utils.reproducible import run_config
# utils.reproducible import run_config
import tensorflow as tf import tensorflow as tf
from bob.db.mnist import Database from bob.db.mnist import Database
......
...@@ -63,7 +63,7 @@ def main(argv=None): ...@@ -63,7 +63,7 @@ def main(argv=None):
model_fn = config.model_fn model_fn = config.model_fn
eval_input_fn = config.eval_input_fn eval_input_fn = config.eval_input_fn
eval_interval_secs = getattr(config, 'eval_interval_secs', 300) eval_interval_secs = getattr(config, 'eval_interval_secs', 60)
run_once = getattr(config, 'run_once', False) run_once = getattr(config, 'run_once', False)
run_config = getattr(config, 'run_config', None) run_config = getattr(config, 'run_config', None)
model_params = getattr(config, 'model_params', None) model_params = getattr(config, 'model_params', None)
...@@ -75,7 +75,7 @@ def main(argv=None): ...@@ -75,7 +75,7 @@ def main(argv=None):
nn = tf.estimator.Estimator(model_fn=model_fn, model_dir=model_dir, nn = tf.estimator.Estimator(model_fn=model_fn, model_dir=model_dir,
params=model_params, config=run_config) params=model_params, config=run_config)
if name: if name:
real_name = name + '_eval' real_name = 'eval_' + name
else: else:
real_name = 'eval' real_name = 'eval'
evaluated_file = os.path.join(nn.model_dir, real_name, 'evaluated') evaluated_file = os.path.join(nn.model_dir, real_name, 'evaluated')
...@@ -91,7 +91,12 @@ def main(argv=None): ...@@ -91,7 +91,12 @@ def main(argv=None):
continue continue
for checkpoint_path in ckpt.all_model_checkpoint_paths: for checkpoint_path in ckpt.all_model_checkpoint_paths:
global_step = str(get_global_step(checkpoint_path)) try:
global_step = str(get_global_step(checkpoint_path))
except Exception:
print('Failed to find global_step for checkpoint_path {}, '
'skipping ...'.format(checkpoint_path))
continue
if global_step in evaluated_steps: if global_step in evaluated_steps:
continue continue
......
...@@ -66,9 +66,7 @@ def main(argv=None): ...@@ -66,9 +66,7 @@ def main(argv=None):
if run_config is None: if run_config is None:
# by default create reproducible nets: # by default create reproducible nets:
from bob.learn.tensorflow.utils.reproducible import session_conf from bob.learn.tensorflow.utils.reproducible import run_config
run_config = tf.estimator.RunConfig()
run_config.replace(session_config=session_conf)
# Instantiate Estimator # Instantiate Estimator
nn = tf.estimator.Estimator(model_fn=model_fn, model_dir=model_dir, nn = tf.estimator.Estimator(model_fn=model_fn, model_dir=model_dir,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment