Commit 134b4531 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Don't initialze the estimators

parent e7ca90b0
Pipeline #13531 failed with stages
in 7 minutes and 30 seconds
......@@ -20,16 +20,13 @@ The configuration files should have the following objects totally:
## Required objects:
model_dir
model_fn
estimator
eval_input_fn
## Optional objects:
eval_interval_secs
run_once
run_config
model_params
steps
hooks
name
......@@ -59,33 +56,27 @@ def main(argv=None):
config_files = args['<config_files>']
config = read_config_file(config_files)
model_dir = config.model_dir
model_fn = config.model_fn
estimator = config.estimator
eval_input_fn = config.eval_input_fn
eval_interval_secs = getattr(config, 'eval_interval_secs', 60)
run_once = getattr(config, 'run_once', False)
run_config = getattr(config, 'run_config', None)
model_params = getattr(config, 'model_params', None)
steps = getattr(config, 'steps', None)
hooks = getattr(config, 'hooks', None)
name = getattr(config, 'eval_name', None)
# Instantiate Estimator
nn = tf.estimator.Estimator(model_fn=model_fn, model_dir=model_dir,
params=model_params, config=run_config)
if name:
real_name = 'eval_' + name
else:
real_name = 'eval'
evaluated_file = os.path.join(nn.model_dir, real_name, 'evaluated')
evaluated_file = os.path.join(estimator.model_dir, real_name, 'evaluated')
while True:
evaluated_steps = []
if os.path.exists(evaluated_file):
with open(evaluated_file) as f:
evaluated_steps = f.read().split()
evaluated_steps = [line.split()[0] for line in f]
ckpt = tf.train.get_checkpoint_state(nn.model_dir)
ckpt = tf.train.get_checkpoint_state(estimator.model_dir)
if (not ckpt) or (not ckpt.model_checkpoint_path):
time.sleep(eval_interval_secs)
continue
......@@ -101,7 +92,7 @@ def main(argv=None):
continue
# Evaluate
evaluations = nn.evaluate(
evaluations = estimator.evaluate(
input_fn=eval_input_fn,
steps=steps,
hooks=hooks,
......@@ -109,11 +100,14 @@ def main(argv=None):
name=name,
)
print(', '.join('%s = %s' % (k, v)
for k, v in sorted(six.iteritems(evaluations))))
str_evaluations = ', '.join(
'%s = %s' % (k, v)
for k, v in sorted(six.iteritems(evaluations)))
print(str_evaluations)
sys.stdout.flush()
with open(evaluated_file, 'a') as f:
f.write('{}\n'.format(evaluations['global_step']))
f.write('{} {}\n'.format(
evaluations['global_step'], str_evaluations))
if run_once:
break
time.sleep(eval_interval_secs)
......
#!/usr/bin/env python
"""Trains networks using tf.train.MonitoredTrainingSession
"""Trains networks using Tensorflow estimators.
Usage:
%(prog)s [options] <config_files>...
......@@ -20,14 +20,11 @@ The configuration files should have the following objects totally:
## Required objects:
model_fn
estimator
train_input_fn
## Optional objects:
model_dir
run_config
model_params
hooks
steps
max_steps
......@@ -40,7 +37,6 @@ from __future__ import division
from __future__ import print_function
# import pkg_resources so that bob imports work properly:
import pkg_resources
import tensorflow as tf
from bob.bio.base.utils import read_config_file
......@@ -54,27 +50,16 @@ def main(argv=None):
config_files = args['<config_files>']
config = read_config_file(config_files)
model_fn = config.model_fn
estimator = config.estimator
train_input_fn = config.train_input_fn
model_dir = getattr(config, 'model_dir', None)
run_config = getattr(config, 'run_config', None)
model_params = getattr(config, 'model_params', None)
hooks = getattr(config, 'hooks', None)
steps = getattr(config, 'steps', None)
max_steps = getattr(config, 'max_steps', None)
if run_config is None:
# by default create reproducible nets:
from bob.learn.tensorflow.utils.reproducible import run_config
# Instantiate Estimator
nn = tf.estimator.Estimator(model_fn=model_fn, model_dir=model_dir,
params=model_params, config=run_config)
# Train
nn.train(input_fn=train_input_fn, hooks=hooks, steps=steps,
max_steps=max_steps)
estimator.train(input_fn=train_input_fn, hooks=hooks, steps=steps,
max_steps=max_steps)
if __name__ == '__main__':
......
......@@ -13,6 +13,7 @@ from bob.learn.tensorflow.script.eval_generic import main as eval_generic
dummy_tfrecord_config = datafile('dummy_verify_config.py', __name__)
CONFIG = '''
import tensorflow as tf
from bob.learn.tensorflow.utils.reproducible import run_config
from bob.learn.tensorflow.dataset.tfrecords import shuffle_data_and_labels, \
batch_data_and_labels
......@@ -78,6 +79,9 @@ def model_fn(features, labels, mode, params, config):
labels=labels, predictions=predictions["classes"])}
return tf.estimator.EstimatorSpec(
mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)
estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir=model_dir,
config=run_config)
'''
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment