diff --git a/bob/learn/tensorflow/examples/mnist/mnist_config.py b/bob/learn/tensorflow/examples/mnist/mnist_config.py index 6227fcd462822776320db3c4e24a646f3ef2c721..bdddc1aab134df74dd498b0e65f48aa7642803fd 100644 --- a/bob/learn/tensorflow/examples/mnist/mnist_config.py +++ b/bob/learn/tensorflow/examples/mnist/mnist_config.py @@ -17,16 +17,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from bob.learn.tensorflow.utils.reproducible import session_conf +# create reproducible nets: +from bob.learn.tensorflow.utils.reproducible import run_config import tensorflow as tf model_dir = '/tmp/mnist_model' train_tfrecords = ['/tmp/mnist_data/train.tfrecords'] eval_tfrecords = ['/tmp/mnist_data/test.tfrecords'] -# by default create reproducible nets: -run_config = tf.estimator.RunConfig() -run_config = run_config.replace(session_config=session_conf) run_config = run_config.replace(keep_checkpoint_max=10**3) run_config = run_config.replace(save_checkpoints_secs=60) diff --git a/bob/learn/tensorflow/utils/reproducible.py b/bob/learn/tensorflow/utils/reproducible.py index 34cb4678258c75d40c889580bb30eff42c8f5242..a2f7fb73bd0c67959a59bd6bba6faa1d83067335 100644 --- a/bob/learn/tensorflow/utils/reproducible.py +++ b/bob/learn/tensorflow/utils/reproducible.py @@ -35,3 +35,6 @@ session_conf = tf.ConfigProto(intra_op_parallelism_threads=1, tf.set_random_seed(1234) # sess = tf.Session(graph=tf.get_default_graph(), config=session_conf) # keras.backend.set_session(sess) + +run_config = tf.estimator.RunConfig() +run_config = run_config.replace(session_config=session_conf)