Xavier.py 1.97 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Mon 05 Sep 2016 16:35 CEST

import logging
logger = logging.getLogger("bob.learn.tensorflow")

from .Initialization import Initialization
import tensorflow as tf


class Xavier(Initialization):
    """
    Implements the classic and well used Xavier initialization as in

    Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep feedforward neural networks." Aistats. Vol. 9. 2010.


    The weights are drawn from a truncated normal distribution
    (``tf.truncated_normal``) with mean 0 and standard deviation

        stddev(W) = sqrt(3 / (n_{in} + n_{out}))

    i.e. Var(W) = 3 / (n_{in} + n_{out}) before truncation, where n_{in} and
    n_{out} are the fan-in and fan-out of the layer (see ``__call__`` for how
    they are derived from the weight shape).

    NOTE(review): the paper's normalized initialization targets
    Var(W) = 2 / (n_{in} + n_{out}); the factor 3 used here (plus the extra
    variance reduction caused by truncation) deviates from that. It is kept
    unchanged so that existing experiments remain reproducible.
    """

    def __init__(self, seed=10., use_gpu=False):
        """
        Parameters
        ----------
        seed:
            Forwarded to the ``Initialization`` base class. ``__call__`` reads
            ``self.seed`` and passes it to ``tf.truncated_normal``; presumably
            the base class stores it there (TODO confirm).
        use_gpu:
            Forwarded to the base class and read back as ``self.use_gpu``.
            If true, variables are placed on "/gpu:0", otherwise on "/cpu".
        """
        super(Xavier, self).__init__(seed, use_gpu=use_gpu)

    def __call__(self, shape, name):
        """
        Create (or fetch) a float32 variable initialized with Xavier weights.

        Parameters
        ----------
        shape: sequence of int
            Shape of the weight tensor.
            A 4-D shape [d0, d1, d2, d3] uses fan_in = d0 * d1 * d2 and
            fan_out = d3. This is presumably a conv kernel laid out as
            [height, width, in_channels, out_channels] (TODO confirm
            against callers).
            Any other shape must have at least two dimensions, and uses
            fan_in = shape[0] and fan_out = shape[1].
        name: str
            Used both as the variable-scope name and as the variable name.

        Returns
        -------
        The variable returned by ``tf.get_variable``, placed on "/gpu:0" when
        ``self.use_gpu`` is true, otherwise on "/cpu". If the first attempt
        raises ``ValueError`` (e.g. the variable already exists in a
        non-reusing scope), it is fetched again with ``reuse=True``.
        """
        import math

        if len(shape) == 4:
            # fan_in (d0 * d1 * d2) + fan_out (only the last dimension).
            in_out = shape[0] * shape[1] * shape[2] + shape[3]
        else:
            in_out = shape[0] + shape[1]

        # Xavier-style scale; see the class docstring for how this relates
        # to the paper's 2 / (n_in + n_out).
        stddev = math.sqrt(3.0 / in_out)

        initializer = tf.truncated_normal(shape, stddev=stddev, seed=self.seed)
        device = "/gpu:0" if self.use_gpu else "/cpu"

        def _get_variable(reuse):
            # Single code path for both the first attempt and the reuse
            # retry; reuse=None is tf.variable_scope's default.
            with tf.variable_scope(name, reuse=reuse):
                with tf.device(device):
                    return tf.get_variable(name, initializer=initializer,
                                           dtype=tf.float32)

        try:
            return _get_variable(reuse=None)
        except ValueError:
            # Creation failed in a non-reusing scope (typically because the
            # variable already exists): retry in reuse mode to fetch it.
            return _get_variable(reuse=True)
58