SequenceNetwork.py 1.32 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Thu 11 Aug 2016 09:39:36 CEST

"""
Base class for building sequential (layer-by-layer) network architectures with TensorFlow
"""

from ..util import *
import tensorflow as tf
import abc
import six

from collections import OrderedDict
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
16
from bob.learn.tensorflow.layers import Layer
17 18 19 20 21 22 23


class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
    """
    Base class to create architectures using TensorFlow.

    Layers are registered with :meth:`add` and chained, in insertion order,
    by :meth:`compute_graph`: each layer is built on top of the output of the
    previous one.
    """

    def __init__(self, feature_layer=None):
        """
        Base constructor

        **Parameters**
        feature_layer: Name of the layer whose output :meth:`compute_graph`
            returns when called with ``cut=True``. ``None`` (the default)
            means no cut point is defined.
        """

        # Layer name -> layer object. The OrderedDict keeps insertion order,
        # which is the order the layers are chained in compute_graph.
        self.sequence_net = OrderedDict()
        self.feature_layer = feature_layer

    def add(self, layer):
        """
        Append a layer to the end of the sequence.

        **Parameters**
        layer: A ``bob.learn.tensorflow.layers.Layer`` instance. It is stored
            under ``layer.name``.

        **Raises**
        ValueError: if ``layer`` is not a ``Layer`` instance.
        """
        if not isinstance(layer, Layer):
            raise ValueError("Input `layer` must be an instance of `bob.learn.tensorflow.layers.Layer`")
        # NOTE(review): re-using an existing name silently replaces the earlier
        # layer (keeping its original position) instead of appending a new one.
        self.sequence_net[layer.name] = layer

    def compute_graph(self, input_data, cut=False):
        """
        Build the graph by chaining every registered layer on ``input_data``.

        **Parameters**
        input_data: Input of the first layer (presumably a TensorFlow
            placeholder/tensor -- confirm against callers).
        cut: If ``True``, stop right after the layer named
            ``self.feature_layer`` and return its output.

        **Returns**
        The value of ``get_graph()`` of the last layer built. If no layers
        were added, ``input_data`` is returned unchanged. If ``cut`` is
        ``True`` but ``feature_layer`` matches no registered layer name, the
        full graph is returned.
        """
        output = input_data
        for name, layer in self.sequence_net.items():
            # Each layer creates its variables from the previous layer's output.
            layer.create_variables(output)
            output = layer.get_graph()

            if cut and name == self.feature_layer:
                return output

        return output