#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Thu 11 Aug 2016 09:39:36 CEST

"""
Class that creates the lenet architecture
"""

import tensorflow as tf
import abc
import six
import os

from collections import OrderedDict
from bob.learn.tensorflow.layers import Layer, MaxPooling


class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
    """
    Base class to create architectures using TensorFlow.

    Layers are kept in an ordered mapping, so the order in which they are
    appended with :py:meth:`add` defines the topology of the graph.
    """

    def __init__(self, feature_layer=None):
        """
        Base constructor

        **Parameters**

        feature_layer: str
          Name of the layer whose output is returned by
          :py:meth:`compute_graph` when called with ``cut=True``
          (i.e. the layer used for feature extraction).
        """
        # name -> Layer; insertion order defines the graph order
        self.sequence_net = OrderedDict()
        self.feature_layer = feature_layer

        # Normalization parameters persisted by save()/load().
        # NOTE(review): they are stored with the model but NOT applied by
        # __call__ itself; presumably the caller normalizes -- confirm.
        self.input_divide = 1.
        self.input_subtract = 0.

    def add(self, layer):
        """
        Append a layer to the end of the network.

        **Parameters**

        layer: :py:class:`bob.learn.tensorflow.layers.Layer`
          Layer to append; its ``name`` attribute is used as the key.

        **Raises**

        ValueError
          If ``layer`` is not an instance of ``Layer``.
        """
        if not isinstance(layer, Layer):
            raise ValueError("Input `layer` must be an instance of `bob.learn.tensorflow.layers.Layer`")
        # NOTE(review): a layer with a duplicate name silently replaces the
        # previous one -- confirm this is intended.
        self.sequence_net[layer.name] = layer

    def compute_graph(self, input_data, cut=False):
        """
        Given the current network, return the TensorFlow graph.

        **Parameters**

        input_data:
          Tensor fed to the first layer of the sequence.

        cut:
          If ``True``, stop at ``self.feature_layer`` and return its output
          instead of the output of the last layer.
        """
        input_offset = input_data
        for name, current_layer in self.sequence_net.items():
            current_layer.create_variables(input_offset)
            input_offset = current_layer.get_graph()

            if cut and name == self.feature_layer:
                return input_offset

        return input_offset

    def compute_projection_graph(self, placeholder):
        """Return the graph cut at ``self.feature_layer`` (feature extractor)."""
        return self.compute_graph(placeholder, cut=True)

    def __call__(self, data, session=None):
        """
        Forward ``data`` through the projection graph and return the result.

        **Parameters**

        data:
          4D numpy array ``(batch, width, height, channels)``.

        session:
          Optional ``tf.Session``. A new one is created -- and closed after
          use (bug fix: the original leaked the implicit session) -- when
          not given.
        """
        owns_session = session is None
        if owns_session:
            session = tf.Session()

        try:
            batch_size = data.shape[0]
            width = data.shape[1]
            height = data.shape[2]
            channels = data.shape[3]

            # Feeding the placeholder
            feature_placeholder = tf.placeholder(
                tf.float32, shape=(batch_size, width, height, channels), name="feature")
            feed_dict = {feature_placeholder: data}

            return session.run(
                [self.compute_projection_graph(feature_placeholder)],
                feed_dict=feed_dict)[0]
        finally:
            # Close only sessions created by this call, never caller-owned ones.
            if owns_session:
                session.close()

    def dump_variables(self):
        """
        Return a dict mapping variable names to the ``W`` and ``b``
        variables of every layer that has them (pooling layers do not).
        """
        variables = {}
        for layer in self.sequence_net.values():
            # TODO: IT IS NOT SMART TESTING ALONG THIS PAGE
            if not isinstance(layer, MaxPooling):
                variables[layer.W.name] = layer.W
                variables[layer.b.name] = layer.b

        return variables

    def save(self, hdf5, step=None):
        """
        Save the state of the network in HDF5 format.

        **Parameters**

        hdf5:
          An open, writable HDF5 file object (``bob.io.base.HDF5File``-like).

        step:
          Optional training step; variables are stored under a
          ``step_<step>`` sub-group when given.
        """
        # Directory that stores the tensorflow variables
        hdf5.create_group('/tensor_flow')
        hdf5.cd('/tensor_flow')

        if step is not None:
            group_name = '/step_{0}'.format(step)
            hdf5.create_group(group_name)
            hdf5.cd(group_name)

        # Iterating the variables of the model. Bug fix: dump the variables
        # once instead of rebuilding the whole dict on every iteration.
        variables = self.dump_variables()
        for name, variable in variables.items():
            hdf5.set(name, variable.eval())

        hdf5.cd('..')
        if step is not None:
            hdf5.cd('..')

        hdf5.set('input_divide', self.input_divide)
        hdf5.set('input_subtract', self.input_subtract)

    def load(self, hdf5, shape, session=None):
        """
        Load a network previously stored with :py:meth:`save`.

        **Parameters**

        hdf5:
          An open, readable HDF5 file positioned at the model root.

        shape:
          Shape of the placeholder used to (re)build the graph so that the
          variables exist before the stored values are assigned.

        session:
          Optional ``tf.Session``; a new one is created when not given.
        """
        if session is None:
            session = tf.Session()

        # Loading the normalization parameters
        self.input_divide = hdf5.read('input_divide')
        self.input_subtract = hdf5.read('input_subtract')

        # Build the graph so the variables exist, then initialize them
        # before overwriting with the stored values.
        place_holder = tf.placeholder(tf.float32, shape=shape, name="load")
        self.compute_graph(place_holder)
        tf.initialize_all_variables().run(session=session)

        hdf5.cd('/tensor_flow')
        for layer in self.sequence_net.values():
            # Pooling layers carry no variables to restore
            if not isinstance(layer, MaxPooling):
                # .eval() already executes the assign op; the original's
                # follow-up session.run(W)/session.run(b) fetches were
                # redundant and have been removed.
                layer.W.assign(hdf5.read(layer.W.name)).eval(session=session)
                layer.b.assign(hdf5.read(layer.b.name)).eval(session=session)