#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Thu 11 Aug 2016 09:39:36 CEST

"""
Class that creates the lenet architecture
"""

import tensorflow as tf
import abc
import six
import os

from collections import OrderedDict

from bob.learn.tensorflow.layers import Layer, MaxPooling


class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
    """
    Base class to create architectures using TensorFlow.

    A network is modelled as an ordered sequence of
    :py:class:`bob.learn.tensorflow.layers.Layer` objects; the graph is
    built by chaining each layer's output into the next layer's input.
    """

    def __init__(self,
                 default_feature_layer=None):
        """
        Base constructor

        **Parameters**

        default_feature_layer: Name of the layer whose output is used as
          the feature (embedding) when no explicit layer is requested in
          :py:meth:`__call__`.
        """
        # Layers must be applied in insertion order, hence the OrderedDict.
        self.sequence_net = OrderedDict()
        self.default_feature_layer = default_feature_layer

        # Input normalization parameters, persisted together with the
        # model in save()/load() so projection uses the same statistics.
        self.input_divide = 1.
        self.input_subtract = 0.

    def add(self, layer):
        """
        Append a layer to the end of the sequence network.

        **Parameters**

        layer: A :py:class:`bob.learn.tensorflow.layers.Layer` instance.

        Raises ``ValueError`` when ``layer`` is not a ``Layer``.
        """
        if not isinstance(layer, Layer):
            raise ValueError("Input `layer` must be an instance of `bob.learn.tensorflow.layers.Layer`")
        # Keyed by layer name: load() looks variables up by these names.
        self.sequence_net[layer.name] = layer

    def compute_graph(self, input_data, feature_layer=None):
        """
        Given the current network, return the TensorFlow graph.

        **Parameters**

        input_data: Tensor or placeholder fed to the first layer.

        feature_layer: Name of the layer where the graph should be cut;
          when given, the output of that layer is returned instead of the
          output of the last layer.
        """
        input_offset = input_data
        for name, current_layer in self.sequence_net.items():
            current_layer.create_variables(input_offset)
            input_offset = current_layer.get_graph()

            # Early exit: the caller asked for an intermediate layer.
            if feature_layer is not None and name == feature_layer:
                return input_offset

        return input_offset

    def compute_projection_graph(self, placeholder):
        """Return the full (uncut) graph used to project ``placeholder``."""
        return self.compute_graph(placeholder)

    def __call__(self, data, session=None, feature_layer=None):
        """
        Forward ``data`` through the network and return the outputs of
        ``feature_layer`` (or of the default feature layer).

        **Parameters**

        data: numpy array; a placeholder with ``data.shape`` is created
          and fed with it.

        session: Optional ``tf.Session``.  When ``None`` a temporary
          session is created and closed after use (the previous
          implementation leaked it).

        feature_layer: Name of the layer to read the output from;
          defaults to ``self.default_feature_layer``.
        """
        # Only close sessions this method created itself.
        own_session = session is None
        if own_session:
            session = tf.Session()

        try:
            # Feeding the placeholder
            feature_placeholder = tf.placeholder(tf.float32, shape=data.shape, name="feature")
            feed_dict = {feature_placeholder: data}

            if feature_layer is None:
                feature_layer = self.default_feature_layer

            return session.run([self.compute_graph(feature_placeholder, feature_layer)], feed_dict=feed_dict)[0]
        finally:
            if own_session:
                session.close()

    def dump_variables(self):
        """
        Return a dict mapping variable names to the ``tf.Variable``
        objects (weights ``W`` and biases ``b``) of each layer.

        Pooling layers carry no variables and are skipped.
        """
        variables = {}
        for layer in self.sequence_net.values():
            # NOTE(review): isinstance-based exclusion is brittle; any new
            # parameterless layer type would need to be added here.
            if not isinstance(layer, MaxPooling):
                variables[layer.W.name] = layer.W
                variables[layer.b.name] = layer.b

        return variables

    def save(self, hdf5, step=None):
        """
        Save the state of the network in HDF5 format

        **Parameters**

        hdf5: An open, writable ``bob.io.base.HDF5File``-like object.

        step: Optional training-step number; when given, variables are
          stored under a per-step subgroup.
        """
        # Directory that stores the tensorflow variables
        hdf5.create_group('/tensor_flow')
        hdf5.cd('/tensor_flow')

        if step is not None:
            # NOTE(review): the leading '/' makes this look like an
            # absolute path, so the step group may not end up nested under
            # '/tensor_flow' -- kept byte-identical for compatibility;
            # verify against the HDF5 wrapper's cd() semantics.
            group_name = '/step_{0}'.format(step)
            hdf5.create_group(group_name)
            hdf5.cd(group_name)

        # Iterating the variables of the model.  Collect the mapping once
        # instead of rebuilding it on every loop iteration as before.
        variables = self.dump_variables()
        for name in variables.keys():
            hdf5.set(name, variables[name].eval())

        hdf5.cd('..')
        if step is not None:
            hdf5.cd('..')

        # Normalization parameters live at the file root.
        hdf5.set('input_divide', self.input_divide)
        hdf5.set('input_subtract', self.input_subtract)

    def load(self, hdf5, shape, session=None):
        """
        Restore a network previously written with :py:meth:`save`.

        **Parameters**

        hdf5: An open, readable ``bob.io.base.HDF5File``-like object.

        shape: Shape of the placeholder used to rebuild the graph, so the
          layer variables exist before they are assigned.

        session: Optional ``tf.Session``; a new one is created when
          ``None``.
        """
        if session is None:
            session = tf.Session()

        # Loading the normalization parameters
        self.input_divide = hdf5.read('input_divide')
        self.input_subtract = hdf5.read('input_subtract')

        # Rebuild the graph so every layer variable exists, give them
        # initial values, then overwrite with the stored ones below.
        place_holder = tf.placeholder(tf.float32, shape=shape, name="load")
        self.compute_graph(place_holder)
        tf.initialize_all_variables().run(session=session)

        hdf5.cd('/tensor_flow')
        for layer in self.sequence_net.values():
            if not isinstance(layer, MaxPooling):
                # assign(...).eval() executes the assignment op in the
                # given session; the following session.run only fetches
                # the already-assigned value.
                layer.W.assign(hdf5.read(layer.W.name)).eval(session=session)
                session.run(layer.W)
                layer.b.assign(hdf5.read(layer.b.name)).eval(session=session)
                session.run(layer.b)