SequenceNetwork.py 1.76 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Thu 11 Aug 2016 09:39:36 CEST

"""
Base class for creating sequential (layer-by-layer) network architectures with TensorFlow
"""

from ..util import *
import tensorflow as tf
import abc
import six

from collections import OrderedDict
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
16
from bob.learn.tensorflow.layers import Layer
17 18 19 20 21 22 23


class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
    """
    Base class to create architectures using TensorFlow

    Layers are kept in insertion order and chained by :py:meth:`compute_graph`:
    each layer is built on the output of the previous one.
    """

    def __init__(self, feature_layer=None):
        """
        Base constructor

        **Parameters**

        feature_layer:
          Name of the layer whose output is returned when the graph is computed
          with ``cut=True`` (see :py:meth:`compute_projection_graph`). If it is
          ``None``, or no added layer has that name, the output of the last
          layer is returned instead.
        """

        # Layers keyed by their ``name``; insertion order defines the sequence.
        self.sequence_net = OrderedDict()
        self.feature_layer = feature_layer

    def add(self, layer):
        """
        Add a layer at the end of the sequence network

        **Parameters**

        layer: :py:class:`bob.learn.tensorflow.layers.Layer`
          The layer to append. Its ``name`` must not already be used by a
          layer of this network.

        **Raises**

        ValueError:
          If ``layer`` is not a ``Layer`` instance, or if a layer with the same
          name was already added.
        """
        if not isinstance(layer, Layer):
            raise ValueError("Input `layer` must be an instance of `bob.learn.tensorflow.layers.Layer`")
        # Assigning an existing key in an OrderedDict keeps the key's ORIGINAL
        # position, so a reused name would silently replace the earlier layer
        # in the middle of the sequence instead of appending a new one.
        if layer.name in self.sequence_net:
            raise ValueError("A layer named `{0}` was already added to the network".format(layer.name))
        self.sequence_net[layer.name] = layer

    def compute_graph(self, input_data, cut=False):
        """
        Given the current network, return the Tensorflow graph

        **Parameters**

        input_data:
          Input fed to the first layer (presumably a TensorFlow placeholder or
          tensor -- TODO confirm against callers).

        cut: bool
          If ``True``, stop as soon as the layer named ``self.feature_layer``
          has been built and return its output.

        **Returns**

        The output of the last layer processed. If no layers were added,
        ``input_data`` is returned unchanged.
        """
        input_offset = input_data
        for name, current_layer in self.sequence_net.items():
            # Each layer creates its variables from its input, then exposes its
            # output, which becomes the input of the next layer.
            current_layer.create_variables(input_offset)
            input_offset = current_layer.get_graph()

            if cut and name == self.feature_layer:
                return input_offset

        return input_offset

    def compute_projection_graph(self, placeholder):
        """
        Build the graph up to ``self.feature_layer`` (or the last layer)

        **Parameters**

        placeholder:
          Input fed to the first layer.

        **Returns**

        Same as :py:meth:`compute_graph` with ``cut=True``.
        """
        return self.compute_graph(placeholder, cut=True)

    def __call__(self, feed_dict, session):
        """
        Evaluate ``self.graph`` in ``session`` and return its value

        **Parameters**

        feed_dict: dict
          Passed unchanged as ``feed_dict`` to ``session.run``.

        session:
          Object exposing ``run(fetches, feed_dict=...)``, e.g. a TensorFlow
          session.

        **Raises**

        ValueError:
          If ``self.graph`` has not been set. This class never assigns it, so
          a subclass or caller must set it before the network is called.
        """
        graph = getattr(self, "graph", None)
        if graph is None:
            raise ValueError("`self.graph` is not set; assign the graph to evaluate before calling the network")
        return session.run([graph], feed_dict=feed_dict)[0]