From 422a72a6c7de83c2672a8cccc3733f2f61cd81ce Mon Sep 17 00:00:00 2001 From: Amir MOHAMMADI <amir.mohammadi@idiap.ch> Date: Fri, 13 Oct 2017 11:00:32 +0200 Subject: [PATCH] Create a default runconfig --- bob/learn/tensorflow/examples/mnist/mnist_config.py | 6 ++---- bob/learn/tensorflow/utils/reproducible.py | 3 +++ 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/bob/learn/tensorflow/examples/mnist/mnist_config.py b/bob/learn/tensorflow/examples/mnist/mnist_config.py index 6227fcd4..bdddc1aa 100644 --- a/bob/learn/tensorflow/examples/mnist/mnist_config.py +++ b/bob/learn/tensorflow/examples/mnist/mnist_config.py @@ -17,16 +17,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from bob.learn.tensorflow.utils.reproducible import session_conf +# create reproducible nets: +from bob.learn.tensorflow.utils.reproducible import run_config import tensorflow as tf model_dir = '/tmp/mnist_model' train_tfrecords = ['/tmp/mnist_data/train.tfrecords'] eval_tfrecords = ['/tmp/mnist_data/test.tfrecords'] -# by default create reproducible nets: -run_config = tf.estimator.RunConfig() -run_config = run_config.replace(session_config=session_conf) run_config = run_config.replace(keep_checkpoint_max=10**3) run_config = run_config.replace(save_checkpoints_secs=60) diff --git a/bob/learn/tensorflow/utils/reproducible.py b/bob/learn/tensorflow/utils/reproducible.py index 34cb4678..a2f7fb73 100644 --- a/bob/learn/tensorflow/utils/reproducible.py +++ b/bob/learn/tensorflow/utils/reproducible.py @@ -35,3 +35,6 @@ session_conf = tf.ConfigProto(intra_op_parallelism_threads=1, tf.set_random_seed(1234) # sess = tf.Session(graph=tf.get_default_graph(), config=session_conf) # keras.backend.set_session(sess) + +run_config = tf.estimator.RunConfig() +run_config = run_config.replace(session_config=session_conf) -- GitLab