diff --git a/bob/learn/tensorflow/utils/reproducible.py b/bob/learn/tensorflow/utils/reproducible.py index 9331704ca8b000b29ca252b52163d4f3c3b67249..677994b9beae33667c43dcd2a0cf5069a512ce77 100644 --- a/bob/learn/tensorflow/utils/reproducible.py +++ b/bob/learn/tensorflow/utils/reproducible.py @@ -8,8 +8,13 @@ from tensorflow.core.protobuf import rewriter_config_pb2 def set_seed( - seed=0, python_hash_seed=0, log_device_placement=False, allow_soft_placement=False, - arithmetic_optimization=None, allow_growth=None, + seed=0, + python_hash_seed=0, + log_device_placement=False, + allow_soft_placement=False, + arithmetic_optimization=None, + allow_growth=None, + memory_optimization=None, ): """Sets the seeds in python, numpy, and tensorflow in order to help training reproducible networks. @@ -64,10 +69,13 @@ def set_seed( allow_soft_placement=allow_soft_placement, ) - if arithmetic_optimization == 'off': - off = rewriter_config_pb2.RewriterConfig.OFF + off = rewriter_config_pb2.RewriterConfig.OFF + if arithmetic_optimization == "off": session_config.graph_options.rewrite_options.arithmetic_optimization = off + if memory_optimization == "off": + session_config.graph_options.rewrite_options.memory_optimization = off + if allow_growth is not None: session_config.gpu_options.allow_growth = allow_growth session_config.gpu_options.per_process_gpu_memory_fraction = 0.8