Gaussian.py 1.78 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Mon 05 Sep 2016 16:35 CEST

import logging
logger = logging.getLogger("bob.learn.tensorflow")

from .Initialization import Initialization
import tensorflow as tf


class Gaussian(Initialization):
    """
    Initializes variables with values drawn from a truncated normal distribution.

    Samples are produced by ``tf.truncated_normal`` using the configured
    ``mean`` and ``std``; the variable is created with ``tf.get_variable``
    (dtype ``tf.float32``) and pinned to ``/gpu:0`` or ``/cpu``.
    """

    def __init__(self, mean=0.,
                 std=1.,
                 seed=10.,
                 use_gpu=False):
        """
        :param mean: mean of the truncated normal distribution.
        :param std: standard deviation of the truncated normal distribution.
        :param seed: op-level seed forwarded to ``tf.truncated_normal``; it is
            handed to :class:`Initialization`, which presumably stores it as
            ``self.seed`` (TODO confirm against the base class).
        :param use_gpu: if True the variable is created under ``/gpu:0``,
            otherwise under ``/cpu``. Also handed to the base class, which
            presumably stores it as ``self.use_gpu``.
        """
        self.mean = mean
        self.std = std
        super(Gaussian, self).__init__(seed, use_gpu=use_gpu)

    def __call__(self, shape, name):
        """
        Create (or reuse) a float32 variable initialized from a truncated normal.

        :param shape: shape of the variable; any rank is accepted, since the
            distribution is fully specified by ``mean`` and ``std``.
        :param name: used both as the variable scope and as the variable name,
            so the resulting variable is ``<name>/<name>`` within the current scope.
        :return: the ``tf.Variable`` returned by ``tf.get_variable``.
        """
        device = "/gpu:0" if self.use_gpu else "/cpu"
        initializer = tf.truncated_normal(shape,
                                          mean=self.mean,
                                          stddev=self.std,
                                          seed=self.seed)

        def _create(reuse):
            # Single place for scope -> device -> get_variable so the first
            # attempt and the reuse retry cannot drift apart.
            with tf.variable_scope(name, reuse=reuse):
                with tf.device(device):
                    return tf.get_variable(name, initializer=initializer, dtype=tf.float32)

        try:
            return _create(reuse=None)
        except ValueError:
            # tf.get_variable raises ValueError when the variable already
            # exists in a non-reusing scope; retry with reuse enabled to
            # return the existing variable instead.
            return _create(reuse=True)
56