SequenceNetwork.py 1.23 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Thu 11 Aug 2016 09:39:36 CEST

"""
Base class for networks built as an ordered sequence of layers
"""

from ..util import *
import tensorflow as tf
import abc
import six

from collections import OrderedDict
from bob.learn.tensorflow.layers import *


class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
    """
    Base class to create architectures using TensorFlow

    Layers are stored in the order they are added, keyed by their ``name``,
    and are chained one after the other by :py:meth:`compute_graph`.
    """

    def __init__(self):
        """
        Base constructor

        Creates an empty, ordered container of layers; populate it with
        :py:meth:`add`.
        """

        # layer name -> layer, preserving insertion order (= execution order)
        self.sequence_net = OrderedDict()

    def add(self, layer):
        """
        Append a layer to the end of the sequence.

        **Parameters**
        layer: an instance of `bob.learn.tensorflow.layers.Layer`

        **Raises**
        ValueError: if ``layer`` is not a `Layer`, or if a layer with the same
        ``name`` was already added. Re-using a name would otherwise silently
        replace the earlier layer while keeping its original position.
        """
        if not isinstance(layer, Layer):
            raise ValueError("Input `layer` must be an instance of `bob.learn.tensorflow.layers.Layer`")
        if layer.name in self.sequence_net:
            raise ValueError("A layer named `{0}` was already added to this network".format(layer.name))
        self.sequence_net[layer.name] = layer

    def compute_graph(self, input_data):
        """
        Chain all layers, in insertion order, starting from ``input_data``.

        For each layer, its variables are created from the previous layer's
        output (the first layer receives ``input_data``). Its output then
        becomes the input of the next layer.

        **Parameters**
        input_data: input of the first layer (presumably a TensorFlow
        placeholder/tensor -- not checked here)

        **Returns**
        The output of the last layer, or ``input_data`` unchanged when no
        layer has been added.
        """
        input_offset = input_data
        for current_layer in self.sequence_net.values():
            current_layer.create_variables(input_offset)
            # NOTE(review): assumes `Layer.get_graph` is a method returning the
            # layer output; without the call, a bound method would be fed to
            # the next layer instead of a tensor.
            input_offset = current_layer.get_graph()

        return input_offset