SequenceNetwork.py 1.16 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Thu 11 Aug 2016 09:39:36 CEST

"""
Base class for building network architectures as an ordered sequence of layers
"""

from ..util import *
import tensorflow as tf
import abc
import six

from collections import OrderedDict
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
16
from bob.learn.tensorflow.layers import Layer
17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44


class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
    """
    Base class to create architectures using TensorFlow.

    Layers are kept in insertion order and chained by ``compute_graph``:
    the output of each layer becomes the input of the next one.
    """

    def __init__(self):
        """
        Base constructor.

        Creates an empty, insertion-ordered mapping of layer name -> layer.
        """

        self.sequence_net = OrderedDict()

    def add(self, layer):
        """
        Append a layer to the end of the sequence.

        **Parameters**
        layer: An instance of ``bob.learn.tensorflow.layers.Layer``. Its
            ``name`` is used as the key in the sequence and must be unique.

        **Raises**
        ValueError: if ``layer`` is not a ``Layer``, or if a layer with the
            same name was already added.
        """
        if not isinstance(layer, Layer):
            raise ValueError("Input `layer` must be an instance of `bob.learn.tensorflow.layers.Layer`")
        # Without this check a second layer sharing a name would silently
        # replace the first one in the OrderedDict, keeping the first one's
        # position, so the earlier layer would vanish from the graph.
        if layer.name in self.sequence_net:
            raise ValueError("A layer named `{0}` was already added to the sequence".format(layer.name))
        self.sequence_net[layer.name] = layer

    def compute_graph(self, input_data):
        """
        Build the graph by chaining every added layer in insertion order.

        **Parameters**
        input_data: Input fed to the first layer (presumably a TensorFlow
            placeholder/tensor -- confirm with callers).

        **Returns**
        The output of the last layer's ``get_graph()``, or ``input_data``
        unchanged when no layer has been added.
        """
        input_offset = input_data
        for current_layer in self.sequence_net.values():
            # Each layer is built on top of the previous layer's output.
            current_layer.create_variables(input_offset)
            input_offset = current_layer.get_graph()

        return input_offset