#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Thu 11 Aug 2016 09:39:36 CEST

"""
Class that creates the lenet architecture
"""

import tensorflow as tf
import abc
import six
import os

from collections import OrderedDict
from bob.learn.tensorflow.layers import Layer, MaxPooling


class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
    """
    Base class to create architectures using TensorFlow.

    Layers are kept in insertion order in ``self.sequence_net``; the forward
    graph is assembled by chaining each layer's output into the next one.
    """

    def __init__(self, feature_layer=None):
        """
        Base constructor

        **Parameters**

        feature_layer: Name of the layer whose output is used as the feature
          representation when the graph is cut (see :py:meth:`compute_graph`
          with ``cut=True``).
        """
        # Ordered mapping layer_name -> Layer; the order defines the forward pass
        self.sequence_net = OrderedDict()
        self.feature_layer = feature_layer
        # Lazily created tf.train.Saver (see save / load)
        self.saver = None

    def add(self, layer):
        """
        Add a layer in the sequence network

        **Parameters**

        layer: An instance of :py:class:`bob.learn.tensorflow.layers.Layer`

        **Raises**

        ValueError: If ``layer`` is not a ``Layer`` instance.
        """
        if not isinstance(layer, Layer):
            raise ValueError("Input `layer` must be an instance of `bob.learn.tensorflow.layers.Layer`")
        self.sequence_net[layer.name] = layer

    def compute_graph(self, input_data, cut=False):
        """
        Given the current network, return the Tensorflow graph

        **Parameters**

        input_data: Tensor (or placeholder) fed to the first layer.

        cut: If ``True``, stop and return the output of the layer whose name
          equals ``self.feature_layer`` instead of the output of the last
          layer.  (The original docstring wrongly described this boolean as a
          layer name.)
        """
        input_offset = input_data
        for k in self.sequence_net.keys():
            current_layer = self.sequence_net[k]
            current_layer.create_variables(input_offset)
            input_offset = current_layer.get_graph()

            # Early exit at the configured feature layer
            if cut and k == self.feature_layer:
                return input_offset

        return input_offset

    def compute_projection_graph(self, placeholder):
        """Return the graph cut at ``self.feature_layer`` (feature extractor)."""
        return self.compute_graph(placeholder, cut=True)

    def __call__(self, data, session=None):
        """
        Forward ``data`` through the network up to ``self.feature_layer``.

        **Parameters**

        data: 4D numpy array with shape (batch, width, height, channels).

        session: Optional ``tf.Session``.  When not given a temporary session
          is created and closed after use (the original code leaked it).
        """
        # Only close the session if we created it here; a caller-provided
        # session stays open and under the caller's control.
        own_session = session is None
        if own_session:
            session = tf.Session()

        try:
            batch_size = data.shape[0]
            width = data.shape[1]
            height = data.shape[2]
            channels = data.shape[3]

            # Feeding the placeholder
            feature_placeholder = tf.placeholder(tf.float32, shape=(batch_size, width, height, channels), name="feature")
            feed_dict = {feature_placeholder: data}

            return session.run([self.compute_projection_graph(feature_placeholder)], feed_dict=feed_dict)[0]
        finally:
            if own_session:
                session.close()

    def dump_variables(self):
        """
        Return a dict mapping variable names to the ``W`` and ``b`` variables
        of every trainable layer (pooling layers carry no variables).
        """
        variables = {}
        for k in self.sequence_net:
            # TODO: IT IS NOT SMART TESTING ALONG THIS PAGE
            if not isinstance(self.sequence_net[k], MaxPooling):
                variables[self.sequence_net[k].W.name] = self.sequence_net[k].W
                variables[self.sequence_net[k].b.name] = self.sequence_net[k].b

        return variables

    def save(self, session, path, step=None):
        """
        Checkpoint the network variables.

        **Parameters**

        session: The ``tf.Session`` holding the variable values.

        path: Directory where the checkpoint is written.

        step: Optional step number embedded in the checkpoint file name.

        Returns the path of the saved checkpoint.
        """
        if self.saver is None:
            self.saver = tf.train.Saver(self.dump_variables())

        if step is None:
            return self.saver.save(session, os.path.join(path, "model.ckpt"))
        else:
            return self.saver.save(session, os.path.join(path, "model" + str(step) + ".ckpt"))

    def load(self, path, shape, session=None):
        """
        Restore network variables from a checkpoint.

        **Parameters**

        path: Checkpoint path as produced by :py:meth:`save`.

        shape: Input placeholder shape used to build the graph before
          restoring.

        session: Optional ``tf.Session``; created when not given.  It is NOT
          closed here because it holds the restored variable values.
        """
        if session is None:
            session = tf.Session()

        # Build the graph first so the variables exist before restoring them
        place_holder = tf.placeholder(tf.float32, shape=shape, name="load")
        self.compute_graph(place_holder)
        tf.initialize_all_variables().run(session=session)

        if self.saver is None:
            self.saver = tf.train.Saver(self.dump_variables())

        self.saver.restore(session, path)