Trainer.py 10.3 KB
Newer Older
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
1
2
3
4
5
6
7
8
9
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Tue 09 Aug 2016 15:25:22 CEST

# Standard library
import logging
import os
import threading

# Third party
import numpy
import tensorflow as tf
from tensorflow.core.framework import summary_pb2

# Bob
import bob.core
import bob.io.base

from ..network import SequenceNetwork

# Single module-level logger, configured through bob.core so handler setup is
# consistent with the rest of the bob.learn.tensorflow package.
# NOTE: the previous plain ``logging.getLogger("bob.learn.tensorflow")``
# assignment was immediately shadowed by this one and has been removed.
logger = bob.core.log.setup("bob.learn.tensorflow")
18

Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
19
20
21
class Trainer(object):
    """Plain (single-loss) CNN trainer.

    Wires an architecture, a loss and an optimizer together and runs the
    training loop (see :py:meth:`train`), writing TensorBoard summaries and
    the final model to ``temp_dir``.
    """

    def __init__(self,
                 architecture,
                 optimizer=None,
                 use_gpu=False,
                 loss=None,
                 temp_dir="cnn",

                 # Learning rate
                 base_learning_rate=0.001,
                 weight_decay=0.9,

                 ###### training options ##########
                 convergence_threshold=0.01,
                 iterations=5000,
                 snapshot=100,
                 prefetch=False,
                 verbosity_level=2):
        """
        **Parameters**
          architecture: The architecture that you want to run. Should be a :py:class`bob.learn.tensorflow.network.SequenceNetwork`
          optimizer: One of the tensorflow optimizers https://www.tensorflow.org/versions/r0.10/api_docs/python/train.html.
            If ``None`` (the default), a fresh ``tf.train.AdamOptimizer()`` is created.
          use_gpu: Use GPUs in the training
          loss: Loss function applied to the network output and the labels
          temp_dir: Directory where summaries and the final model are written
          base_learning_rate: Initial learning rate for the exponential decay schedule
          weight_decay: Decay rate of the learning-rate schedule
          convergence_threshold: Stopping threshold (currently stored only; not used by the loop)
          iterations: Number of training steps
          snapshot: Validation is run every ``snapshot`` steps
          prefetch: Use a queue-based prefetch pipeline (not supported by this trainer's fit step)
          verbosity_level: bob.core logging verbosity
        """
        if not isinstance(architecture, SequenceNetwork):
            raise ValueError("`architecture` should be instance of `SequenceNetwork`")

        self.architecture = architecture
        # BUGFIX: the old default ``optimizer=tf.train.AdamOptimizer()`` was
        # evaluated once at import time and therefore SHARED between every
        # Trainer instance; train() mutates the optimizer (_learning_rate),
        # so state leaked across trainers.  A ``None`` sentinel gives each
        # instance its own optimizer while keeping the same default behavior.
        self.optimizer_class = tf.train.AdamOptimizer() if optimizer is None else optimizer
        self.use_gpu = use_gpu
        self.loss = loss
        self.temp_dir = temp_dir

        self.base_learning_rate = base_learning_rate
        self.weight_decay = weight_decay

        self.iterations = iterations
        self.snapshot = snapshot
        self.convergence_threshold = convergence_threshold
        self.prefetch = prefetch

        # Training variables used in the fit
        # (the duplicated ``self.training_graph = None`` assignment was removed)
        self.optimizer = None
        self.training_graph = None
        self.learning_rate = None
        self.train_data_shuffler = None
        self.summaries_train = None
        self.train_summary_writter = None

        # Validation data
        self.validation_graph = None
        self.validation_summary_writter = None

        bob.core.log.set_verbosity_level(logger, verbosity_level)
    def compute_graph(self, data_shuffler, name=""):
        """
        Build and return the loss graph for the given data shuffler.

        ** Parameters **

            data_shuffler: Data shuffler
            name: Name of the graph
        """
        if self.prefetch:
            # Prefetch path: the shuffler feeds a FIFO queue and the network
            # consumes dequeued mini-batches instead of raw placeholders.
            ph_data, ph_labels = data_shuffler.get_placeholders_forprefetch(name=name)

            sample_shape = ph_data.get_shape().as_list()[1:]
            queue = tf.FIFOQueue(capacity=10,
                                 dtypes=[tf.float32, tf.int64],
                                 shapes=[sample_shape, []])

            # Enqueue op is created for its graph side effect; batches come
            # out of the queue.
            enqueue_op = queue.enqueue_many([ph_data, ph_labels])
            feature_batch, label_batch = queue.dequeue_many(data_shuffler.batch_size)

            if not isinstance(self.architecture, SequenceNetwork):
                raise ValueError("The variable `architecture` must be an instance of "
                                 "`bob.learn.tensorflow.network.SequenceNetwork`")
        else:
            # Direct path: batches are fed through plain placeholders.
            feature_batch, label_batch = data_shuffler.get_placeholders(name=name)

        # Forward pass followed by the loss on the resulting logits.
        network_output = self.architecture.compute_graph(feature_batch)
        return self.loss(network_output, label_batch)

    def get_feed_dict(self, data_shuffler):
        """
        Fetch one mini-batch from the shuffler and map it onto the
        shuffler's placeholders, returning the resulting feed_dict.

        ** Parameters **

            data_shuffler: Data shuffler providing ``get_batch()`` and
                ``get_placeholders()``
        """
        batch_data, batch_labels = data_shuffler.get_batch()
        ph_data, ph_labels = data_shuffler.get_placeholders()
        return {ph_data: batch_data, ph_labels: batch_labels}

    def __fit(self, session, step):
        """Run a single optimization step and log/record the training loss."""
        # Guard clause: this trainer has no prefetch-capable fit step.
        if self.prefetch:
            raise ValueError("Prefetch not implemented for such trainer")

        feed_dict = self.get_feed_dict(self.train_data_shuffler)
        fetches = [self.optimizer, self.training_graph,
                   self.learning_rate, self.summaries_train]
        _, l, lr, summary = session.run(fetches, feed_dict=feed_dict)

        logger.info("Loss training set step={0} = {1}".format(step, l))
        self.train_summary_writter.add_summary(summary, step)

    def __compute_validation(self, session, data_shuffler, step):
        """Evaluate the loss on one validation batch and write it to the
        validation summary writer.

        ** Parameters **

            session: Active TF session
            data_shuffler: Validation data shuffler
            step: Global training step (used as the summary step)
        """
        if self.validation_summary_writter is None:
            self.validation_summary_writter = tf.train.SummaryWriter(os.path.join(self.temp_dir, 'validation'), session.graph)

        # BUGFIX: build the validation graph only once.  The previous code
        # called compute_graph() on EVERY snapshot, adding a duplicate set of
        # nodes to the TF graph each time (unbounded graph/memory growth over
        # a long training run).
        if self.validation_graph is None:
            self.validation_graph = self.compute_graph(data_shuffler, name="validation")

        feed_dict = self.get_feed_dict(data_shuffler)
        l = session.run(self.validation_graph, feed_dict=feed_dict)

        summaries = [summary_pb2.Summary.Value(tag="loss", simple_value=float(l))]
        self.validation_summary_writter.add_summary(summary_pb2.Summary(value=summaries), step)
        logger.info("Loss VALIDATION set step={0} = {1}".format(step, l))

    def __create_general_summary(self):
        """Register the scalar training summaries (loss and learning rate)
        and return the merged summary op fetched during each fit step."""
        # Train summary
        tf.scalar_summary('loss', self.training_graph, name="train")
        tf.scalar_summary('lr', self.learning_rate, name="train")
        return tf.merge_all_summaries()


    # NOTE(review): the triple-quoted block below is DISABLED code kept as a
    # reference for the queue/thread based prefetch pipeline (see the
    # commented-out Coordinator lines in train()).  It is a bare class-level
    # string, never executed.  Consider deleting it or moving it to version
    # control history.
    """
    def start_thread(self):
        threads = []
        for n in range(1):
            t = threading.Thread(target=self.load_and_enqueue)
            t.daemon = True  # thread will close when parent quits
            t.start()
            threads.append(t)
        return threads

    def load_and_enqueue(self):
        Injecting data in the place holder queue

        #while not thread_pool.should_stop():
        #for i in range(self.iterations):
        while not thread_pool.should_stop():
            train_data, train_labels = train_data_shuffler.get_batch()

            feed_dict = {train_placeholder_data: train_data,
                         train_placeholder_labels: train_labels}

            session.run(enqueue_op, feed_dict=feed_dict)

    """

    def train(self, train_data_shuffler, validation_data_shuffler=None):
        """
        Do the loop forward --> backward --|
                      ^--------------------|

        Builds the learning-rate schedule, the training graph, the optimizer
        and the summaries, then runs ``self.iterations`` fit steps inside a
        single TF session.  Every ``self.snapshot`` steps the validation loss
        is computed (when ``validation_data_shuffler`` is given).  The final
        architecture is saved to ``<temp_dir>/model.hdf5``.

        ** Parameters **

            train_data_shuffler: Shuffler feeding the training mini-batches
            validation_data_shuffler: Optional shuffler for periodic validation
        """

        # Creating directory
        bob.io.base.create_directories_safe(self.temp_dir)
        self.train_data_shuffler = train_data_shuffler

        # TODO: find an elegant way to provide this as a parameter of the trainer
        # NOTE(review): batch_size is passed as the decay's global_step slot and
        # n_samples as decay_steps — confirm this ordering is intended.
        self.learning_rate = tf.train.exponential_decay(
            self.base_learning_rate,  # Learning rate
            train_data_shuffler.batch_size,
            train_data_shuffler.n_samples,
            self.weight_decay  # Decay step
        )

        self.training_graph = self.compute_graph(train_data_shuffler, name="train")

        # Preparing the optimizer
        # NOTE(review): pokes the optimizer's private _learning_rate attribute;
        # only valid for optimizers exposing that field (e.g. Adam).
        self.optimizer_class._learning_rate = self.learning_rate
        self.optimizer = self.optimizer_class.minimize(self.training_graph, global_step=tf.Variable(0))

        # Train summary
        self.summaries_train = self.__create_general_summary()

        logger.info("Initializing !!")
        # Training
        # NOTE(review): the HDF5 file is opened before training starts; if the
        # loop raises, the file is left open/empty — consider opening it only
        # at save time.
        hdf5 = bob.io.base.HDF5File(os.path.join(self.temp_dir, 'model.hdf5'), 'w')

        with tf.Session() as session:

            tf.initialize_all_variables().run()

            # Start a thread to enqueue data asynchronously, and hide I/O latency.
            #thread_pool = tf.train.Coordinator()
            #tf.train.start_queue_runners(coord=thread_pool)
            #threads = start_thread()

            # TENSOR BOARD SUMMARY
            self.train_summary_writter = tf.train.SummaryWriter(os.path.join(self.temp_dir, 'train'), session.graph)

            for step in range(self.iterations):

                self.__fit(session, step)
                # Periodic validation snapshot (step 0 included).
                if validation_data_shuffler is not None and step % self.snapshot == 0:
                    self.__compute_validation(session, validation_data_shuffler, step)

                #    validation_data, validation_labels = validation_data_shuffler.get_batch()

                #    feed_dict = {validation_placeholder_data: validation_data,
                #                 validation_placeholder_labels: validation_labels}

                    #l, predictions = session.run([loss_validation, validation_prediction, ], feed_dict=feed_dict)
                    #l, summary = session.run([loss_validation, merged_validation], feed_dict=feed_dict)
                    #import ipdb; ipdb.set_trace();

                #    l = session.run(loss_validation, feed_dict=feed_dict)
                #    summaries = []
                #    summaries.append(summary_pb2.Summary.Value(tag="loss", simple_value=float(l)))
                #    validation_writer.add_summary(summary_pb2.Summary(value=summaries), step)

                    #l = session.run([loss_validation], feed_dict=feed_dict)
                    #accuracy = 100. * numpy.sum(numpy.argmax(predictions, 1) == validation_labels) / predictions.shape[0]
                    #validation_writer.add_summary(summary, step)
                    #print "Step {0}. Loss = {1}, Acc Validation={2}".format(step, l, accuracy)
                #    print "Step {0}. Loss = {1}".format(step, l)

            logger.info("Training finally finished")

            self.train_summary_writter.close()
            if validation_data_shuffler is not None:
                self.validation_summary_writter.close()

            # Persist the trained architecture; deleting the HDF5File handle
            # closes/flushes the file in bob.io.base.
            self.architecture.save(hdf5)
            del hdf5

            # now they should definetely stop
            #thread_pool.request_stop()
            #thread_pool.join(threads)
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
280