Layer.py 1.46 KB
Newer Older
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
1 2 3 4 5
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Wed 11 May 2016 17:38 CEST

6
import tensorflow as tf
7 8
from bob.learn.tensorflow.initialization import Xavier
from bob.learn.tensorflow.initialization import Constant
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
9 10 11 12 13 14 15 16


class Layer(object):

    """
    Layer base class.

    Stores the configuration shared by all layers (name, activation,
    weight/bias initializers and the ``use_gpu`` flag). Derived classes must
    implement :py:meth:`create_variables` and :py:meth:`get_graph`; the base
    implementations raise ``NotImplementedError``.
    """

    def __init__(self, name,
                 activation=None,
                 weights_initialization=None,
                 bias_initialization=None,
                 use_gpu=False):
        """
        Base constructor

        **Parameters**

          name: Name of the layer
          activation: Tensorflow activation operation (https://www.tensorflow.org/versions/r0.10/api_docs/python/nn.html)
          weights_initialization: Initialization for the weights. When ``None``
            (the default), a new ``Xavier()`` instance is created for this layer.
          bias_initialization: Initialization for the biases. When ``None``
            (the default), a new ``Constant()`` instance is created for this layer.
          use_gpu: Device flag stored as ``self.use_gpu`` for derived classes
            (presumably selects GPU placement -- this base class only stores it).
        """
        # Create the default initializers per instance. A default such as
        # ``Xavier()`` in the signature is evaluated only once, at class
        # definition time, and that single object would be shared by every
        # layer constructed without an explicit initializer.
        if weights_initialization is None:
            weights_initialization = Xavier()
        if bias_initialization is None:
            bias_initialization = Constant()

        self.name = name
        self.weights_initialization = weights_initialization
        self.bias_initialization = bias_initialization
        self.use_gpu = use_gpu

        # Starts unset; presumably assigned by derived classes when the
        # layer is connected (e.g. in create_variables) -- verify in subclasses.
        self.input_layer = None
        self.activation = activation

    def create_variables(self, input_layer):
        """
        Create the variables of this layer for ``input_layer``.

        Must be overridden by derived classes.

        **Raises**

          NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("Please implement this function in derived classes")

    def get_graph(self):
        """
        Return the graph (output) of this layer.

        Must be overridden by derived classes.

        **Raises**

          NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("Please implement this function in derived classes")