Skip to content
Snippets Groups Projects
Commit efa29019 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

adjustments

parent 72ce48c0
No related branches found
No related tags found
2 merge requests!22Add a prediction script,!21Resolve "Adopt to the Estimators API"
......@@ -150,14 +150,14 @@ def main(argv=None):
logger.info('Processing file %d out of %d', i + 1, n_files)
path = f.make_path(data_dir, data_extension)
data = reader(path)
data = reader(path)
if data is None:
if allow_missing_files:
logger.debug("... Processing original data file '{0}' was not successful".format(path))
continue
else:
raise RuntimeError("Preprocessing of file '{0}' was not successful".format(path))
label = file_to_label(f)
if one_file_one_sample:
......
......@@ -63,7 +63,7 @@ def main(argv=None):
model_fn = config.model_fn
eval_input_fn = config.eval_input_fn
eval_interval_secs = getattr(config, 'eval_interval_secs', 300)
eval_interval_secs = getattr(config, 'eval_interval_secs', 60)
run_once = getattr(config, 'run_once', False)
run_config = getattr(config, 'run_config', None)
model_params = getattr(config, 'model_params', None)
......@@ -75,7 +75,7 @@ def main(argv=None):
nn = tf.estimator.Estimator(model_fn=model_fn, model_dir=model_dir,
params=model_params, config=run_config)
if name:
real_name = name + '_eval'
real_name = 'eval_' + name
else:
real_name = 'eval'
evaluated_file = os.path.join(nn.model_dir, real_name, 'evaluated')
......
......@@ -66,9 +66,7 @@ def main(argv=None):
if run_config is None:
# by default create reproducible nets:
from bob.learn.tensorflow.utils.reproducible import session_conf
run_config = tf.estimator.RunConfig()
run_config.replace(session_config=session_conf)
from bob.learn.tensorflow.utils.reproducible import run_config
# Instantiate Estimator
nn = tf.estimator.Estimator(model_fn=model_fn, model_dir=model_dir,
......
......@@ -32,9 +32,11 @@ session_conf = tf.ConfigProto(intra_op_parallelism_threads=1,
# in the TensorFlow backend have a well-defined initial state.
# For further details, see:
# https://www.tensorflow.org/api_docs/python/tf/set_random_seed
tf.set_random_seed(1234)
tf_random_seed = 1234
tf.set_random_seed(tf_random_seed)
# sess = tf.Session(graph=tf.get_default_graph(), config=session_conf)
# keras.backend.set_session(sess)
run_config = tf.estimator.RunConfig()
run_config = run_config.replace(session_config=session_conf)
run_config = run_config.replace(tf_random_seed=tf_random_seed)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment