"""Helpers for training reproducible networks."""

import os
import random as rn
import numpy as np
import tensorflow as tf
from tensorflow.core.protobuf import rewriter_config_pb2


def set_seed(
    seed=0,
    python_hash_seed=0,
    log_device_placement=False,
    allow_soft_placement=False,
    arithmetic_optimization=None,
    allow_growth=None,
    memory_optimization=None,
    per_process_gpu_memory_fraction=0.8,
):
    """Sets the seeds in python, numpy, and tensorflow in order to help
    training reproducible networks.

    Parameters
    ----------
    seed : :obj:`int`, optional
        The seed passed to :func:`random.seed`, :func:`numpy.random.seed`,
        ``tf.set_random_seed`` and used as the estimator's ``tf_random_seed``.
    python_hash_seed : :obj:`int`, optional
        Value written to the ``PYTHONHASHSEED`` environment variable. The
        running interpreter reads this variable only at start-up, so this
        affects Python processes launched after this call (e.g. workers),
        not the current one. See
        https://docs.python.org/3/using/cmdline.html#envvar-PYTHONHASHSEED
    log_device_placement : :obj:`bool`, optional
        Optionally, log device placement of tensorflow variables.
    allow_soft_placement : :obj:`bool`, optional
        Forwarded to ``tf.ConfigProto(allow_soft_placement=...)``.
    arithmetic_optimization : :obj:`str`, optional
        If ``"off"``, sets
        ``graph_options.rewrite_options.arithmetic_optimization`` to
        ``RewriterConfig.OFF``. Any other value leaves TensorFlow's default.
    allow_growth : :obj:`bool`, optional
        If not ``None``, sets ``gpu_options.allow_growth`` to this value and
        also applies ``per_process_gpu_memory_fraction`` (see below).
    memory_optimization : :obj:`str`, optional
        If ``"off"``, sets
        ``graph_options.rewrite_options.memory_optimization`` to
        ``RewriterConfig.OFF``. Any other value leaves TensorFlow's default.
    per_process_gpu_memory_fraction : :obj:`float`, optional
        Written to ``gpu_options.per_process_gpu_memory_fraction``, but only
        when ``allow_growth`` is not ``None``. Pass ``None`` to leave it
        untouched. Defaults to ``0.8``, the value previously hard-coded here.

    Returns
    -------
    :obj:`list`
        ``[session_config, run_config, None, None, None]`` where
        ``session_config`` is the :any:`tf.ConfigProto` built here and
        ``run_config`` is a :any:`tf.estimator.RunConfig` carrying that
        session config and ``seed``. The trailing ``None`` entries are
        placeholders.

    Notes
    -----
    This function returns a list whose length might change. Please use
    indices to select the returned values, for example
    ``sess_config, run_config = set_seed()[:2]``.
    """
    # Only child processes started after this point pick this up; the current
    # interpreter's hash seed is fixed at start-up.
    # https://github.com/fchollet/keras/issues/2280#issuecomment-306959926
    os.environ["PYTHONHASHSEED"] = str(python_hash_seed)

    # Start NumPy's and core Python's global generators in a well-defined state.
    np.random.seed(seed)
    rn.seed(seed)

    # Force TensorFlow to use a single thread: multiple threads are a
    # potential source of non-reproducible results. For further details, see:
    # https://stackoverflow.com/questions/42022950/which-seeds-have-to-be-set-where-to-realize-100-reproducibility-of-training-res
    session_config = tf.ConfigProto(
        intra_op_parallelism_threads=1,
        inter_op_parallelism_threads=1,
        log_device_placement=log_device_placement,
        allow_soft_placement=allow_soft_placement,
    )

    # Optionally switch off individual graph rewriters.
    rewrite_options = session_config.graph_options.rewrite_options
    off = rewriter_config_pb2.RewriterConfig.OFF
    if arithmetic_optimization == "off":
        rewrite_options.arithmetic_optimization = off
    if memory_optimization == "off":
        rewrite_options.memory_optimization = off

    if allow_growth is not None:
        session_config.gpu_options.allow_growth = allow_growth
        if per_process_gpu_memory_fraction is not None:
            session_config.gpu_options.per_process_gpu_memory_fraction = (
                per_process_gpu_memory_fraction
            )

    # Give TensorFlow's graph-level random seed a well-defined initial state.
    # https://www.tensorflow.org/api_docs/python/tf/set_random_seed
    tf.set_random_seed(seed)

    run_config = tf.estimator.RunConfig().replace(
        session_config=session_config, tf_random_seed=seed
    )

    # Trailing ``None`` entries are placeholders; see the Notes section.
    return [session_config, run_config, None, None, None]