Skip to content
Snippets Groups Projects
Commit c01b17ba authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Merge branch 'defaultrunconfig' into predict

parents 8c9944fc efa29019
No related branches found
No related tags found
2 merge requests!22Add a prediction script,!21Resolve "Adopt to the Estimators API"
Pipeline #
This commit is part of merge request !22. Comments created here will be created in the context of that merge request.
......@@ -17,9 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# by default create reproducible nets:
# create reproducible nets:
from bob.learn.tensorflow.utils.reproducible import run_config
# utils.reproducible import run_config
import tensorflow as tf
from bob.db.mnist import Database
......
......@@ -63,7 +63,7 @@ def main(argv=None):
model_fn = config.model_fn
eval_input_fn = config.eval_input_fn
eval_interval_secs = getattr(config, 'eval_interval_secs', 300)
eval_interval_secs = getattr(config, 'eval_interval_secs', 60)
run_once = getattr(config, 'run_once', False)
run_config = getattr(config, 'run_config', None)
model_params = getattr(config, 'model_params', None)
......@@ -75,7 +75,7 @@ def main(argv=None):
nn = tf.estimator.Estimator(model_fn=model_fn, model_dir=model_dir,
params=model_params, config=run_config)
if name:
real_name = name + '_eval'
real_name = 'eval_' + name
else:
real_name = 'eval'
evaluated_file = os.path.join(nn.model_dir, real_name, 'evaluated')
......@@ -91,7 +91,12 @@ def main(argv=None):
continue
for checkpoint_path in ckpt.all_model_checkpoint_paths:
global_step = str(get_global_step(checkpoint_path))
try:
global_step = str(get_global_step(checkpoint_path))
except Exception:
print('Failed to find global_step for checkpoint_path {}, '
'skipping ...'.format(checkpoint_path))
continue
if global_step in evaluated_steps:
continue
......
......@@ -66,9 +66,7 @@ def main(argv=None):
if run_config is None:
# by default create reproducible nets:
from bob.learn.tensorflow.utils.reproducible import session_conf
run_config = tf.estimator.RunConfig()
run_config.replace(session_config=session_conf)
from bob.learn.tensorflow.utils.reproducible import run_config
# Instantiate Estimator
nn = tf.estimator.Estimator(model_fn=model_fn, model_dir=model_dir,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment