From 422a72a6c7de83c2672a8cccc3733f2f61cd81ce Mon Sep 17 00:00:00 2001
From: Amir MOHAMMADI <amir.mohammadi@idiap.ch>
Date: Fri, 13 Oct 2017 11:00:32 +0200
Subject: [PATCH] Create a default runconfig

---
 bob/learn/tensorflow/examples/mnist/mnist_config.py | 6 ++----
 bob/learn/tensorflow/utils/reproducible.py          | 3 +++
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/bob/learn/tensorflow/examples/mnist/mnist_config.py b/bob/learn/tensorflow/examples/mnist/mnist_config.py
index 6227fcd4..bdddc1aa 100644
--- a/bob/learn/tensorflow/examples/mnist/mnist_config.py
+++ b/bob/learn/tensorflow/examples/mnist/mnist_config.py
@@ -17,16 +17,14 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from bob.learn.tensorflow.utils.reproducible import session_conf
+# create reproducible nets:
+from bob.learn.tensorflow.utils.reproducible import run_config
 import tensorflow as tf
 
 model_dir = '/tmp/mnist_model'
 train_tfrecords = ['/tmp/mnist_data/train.tfrecords']
 eval_tfrecords = ['/tmp/mnist_data/test.tfrecords']
 
-# by default create reproducible nets:
-run_config = tf.estimator.RunConfig()
-run_config = run_config.replace(session_config=session_conf)
 run_config = run_config.replace(keep_checkpoint_max=10**3)
 run_config = run_config.replace(save_checkpoints_secs=60)
 
diff --git a/bob/learn/tensorflow/utils/reproducible.py b/bob/learn/tensorflow/utils/reproducible.py
index 34cb4678..a2f7fb73 100644
--- a/bob/learn/tensorflow/utils/reproducible.py
+++ b/bob/learn/tensorflow/utils/reproducible.py
@@ -35,3 +35,6 @@ session_conf = tf.ConfigProto(intra_op_parallelism_threads=1,
 tf.set_random_seed(1234)
 # sess = tf.Session(graph=tf.get_default_graph(), config=session_conf)
 # keras.backend.set_session(sess)
+
+run_config = tf.estimator.RunConfig()
+run_config = run_config.replace(session_config=session_conf)
-- 
GitLab