#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Tue 09 Aug 2016 15:25:22 CEST

import logging
import tensorflow as tf
from ..network import SequenceNetwork
import threading
import numpy
import os
import bob.io.base
import bob.core
from ..analyzers import SoftmaxAnalizer
from tensorflow.core.framework import summary_pb2

# Module logger shared by every Trainer; its verbosity is set in `Trainer.__init__`.
logger = bob.core.log.setup("bob.learn.tensorflow")


class Trainer(object):
    """
    Trainer for :py:class:`bob.learn.tensorflow.network.SequenceNetwork` architectures.

    Builds the loss graph for the training data and minimizes it with a TensorFlow optimizer
    under an exponentially decaying learning rate. It writes TensorBoard summaries to
    ``<temp_dir>/train`` and ``<temp_dir>/validation`` and finally saves the architecture to
    ``<temp_dir>/model.hdf5``.
    """

    # Private sentinel for the `analizer` default. `None` already means "no analizer", so a
    # sentinel is needed to build a fresh default analizer per trainer, instead of sharing
    # one instance created at import time.
    _DEFAULT_ANALIZER = object()

    def __init__(self,
                 architecture,
                 optimizer=None,
                 use_gpu=False,
                 loss=None,
                 temp_dir="cnn",

                 # Learning rate
                 base_learning_rate=0.001,
                 weight_decay=0.9,

                 ###### training options ##########
                 convergence_threshold=0.01,
                 iterations=5000,
                 snapshot=100,
                 prefetch=False,

                 ## Analizer
                 analizer=_DEFAULT_ANALIZER,

                 verbosity_level=2):
        """
        **Parameters**
          architecture: The architecture that you want to run. Should be a :py:class`bob.learn.tensorflow.network.SequenceNetwork`
          optimizer: One of the tensorflow optimizers https://www.tensorflow.org/versions/r0.10/api_docs/python/train.html
                     If ``None`` (default), a new ``tf.train.AdamOptimizer()`` is created for this trainer.
          use_gpu: Use GPUs in the training (stored only; not read by this class)
          loss: Callable ``loss(network_graph, labels)`` returning the scalar loss tensor
          temp_dir: Directory receiving the TensorBoard summaries and ``model.hdf5``
          base_learning_rate: Initial learning rate of the exponential decay
          weight_decay: Decay rate of the exponential learning-rate decay
          convergence_threshold: Stored only; not read by this class
          iterations: Number of training steps
          snapshot: Validation (and the analizer) run every ``snapshot`` steps
          prefetch: If True, training batches are fed by a background thread through a FIFO queue
          analizer: Called on the validation data at every snapshot; ``None`` disables it.
                    Defaults to a new ``SoftmaxAnalizer()`` for this trainer.
          verbosity_level: Verbosity level applied to the module logger
        """
        if not isinstance(architecture, SequenceNetwork):
            raise ValueError("`architecture` should be instance of `SequenceNetwork`")

        self.architecture = architecture
        # Do not use an optimizer instance as the signature default: it would be created once
        # at import time and shared by every trainer, and `train` mutates it.
        self.optimizer_class = tf.train.AdamOptimizer() if optimizer is None else optimizer
        self.use_gpu = use_gpu
        self.loss = loss
        self.temp_dir = temp_dir

        self.base_learning_rate = base_learning_rate
        self.weight_decay = weight_decay

        self.iterations = iterations
        self.snapshot = snapshot
        self.convergence_threshold = convergence_threshold
        self.prefetch = prefetch

        # Training variables used in the fit
        self.optimizer = None
        self.training_graph = None
        self.learning_rate = None
        self.global_step = None
        self.train_data_shuffler = None
        self.summaries_train = None
        self.train_summary_writter = None

        # Validation data
        self.validation_graph = None
        self.validation_summary_writter = None

        # Analizer
        if analizer is Trainer._DEFAULT_ANALIZER:
            analizer = SoftmaxAnalizer()
        self.analizer = analizer

        # Prefetching machinery (only used when `prefetch` is True)
        self.thread_pool = None
        self.enqueue_op = None
        self.prefetch_queue = None

        bob.core.log.set_verbosity_level(logger, verbosity_level)

    def compute_graph(self, data_shuffler, name="", prefetch=None):
        """
        Computes the graph for the trainer.

        ** Parameters **

            data_shuffler: Data shuffler
            name: Name of the graph
            prefetch: Read the batches from a prefetch queue filled by `load_and_enqueue`.
                      ``None`` (default) uses ``self.prefetch``. Graphs that nobody enqueues
                      for (e.g. validation) must pass ``False``. Otherwise they would block on
                      an empty queue and replace the training ``enqueue_op``.

        ** Returns **

            The tensor returned by ``self.loss(network_graph, labels)``.
        """
        if prefetch is None:
            prefetch = self.prefetch

        if prefetch:
            placeholder_data, placeholder_labels = data_shuffler.get_placeholders_forprefetch(name=name)

            # Queue of single samples; `load_and_enqueue` pushes whole batches into it and
            # the graph pops `batch_size` samples at a time.
            queue = tf.FIFOQueue(capacity=10,
                                 dtypes=[tf.float32, tf.int64],
                                 shapes=[placeholder_data.get_shape().as_list()[1:], []])
            self.prefetch_queue = queue
            self.enqueue_op = queue.enqueue_many([placeholder_data, placeholder_labels])
            feature_batch, label_batch = queue.dequeue_many(data_shuffler.batch_size)
        else:
            feature_batch, label_batch = data_shuffler.get_placeholders(name=name)

        # Creating graphs and defining the loss
        network_graph = self.architecture.compute_graph(feature_batch)
        return self.loss(network_graph, label_batch)

    def get_feed_dict(self, data_shuffler):
        """
        Computes the feed_dict for the graph

        ** Parameters **

            data_shuffler: Shuffler providing ``get_batch()`` and ``get_placeholders()``.
                           Presumably ``get_placeholders()`` returns the placeholders that were
                           already created for the graph (TODO confirm in the shuffler).

        ** Returns **

            ``{data_placeholder: data, label_placeholder: labels}`` for one batch.
        """
        data, labels = data_shuffler.get_batch()
        data_placeholder, label_placeholder = data_shuffler.get_placeholders()

        return {data_placeholder: data,
                label_placeholder: labels}

    def __fit(self, session, step):
        """Runs one optimization step, logs the training loss and writes the train summary."""
        fetches = [self.optimizer, self.training_graph, self.learning_rate, self.summaries_train]
        if self.prefetch:
            # Batches come from the prefetch queue; nothing has to be fed
            _, l, lr, summary = session.run(fetches)
        else:
            feed_dict = self.get_feed_dict(self.train_data_shuffler)
            _, l, lr, summary = session.run(fetches, feed_dict=feed_dict)

        logger.info("Loss training set step={0} = {1}".format(step, l))
        self.train_summary_writter.add_summary(summary, step)

    def __compute_validation(self, session, data_shuffler, step):
        """Evaluates the loss on one validation batch and writes it as a summary."""
        if self.validation_summary_writter is None:
            self.validation_summary_writter = tf.train.SummaryWriter(os.path.join(self.temp_dir, 'validation'), session.graph)

        # Build the validation graph once: rebuilding it at every snapshot keeps adding ops to
        # the graph. It is always placeholder-fed, because no thread fills a validation queue.
        if self.validation_graph is None:
            self.validation_graph = self.compute_graph(data_shuffler, name="validation", prefetch=False)

        feed_dict = self.get_feed_dict(data_shuffler)
        l = session.run(self.validation_graph, feed_dict=feed_dict)

        summaries = [summary_pb2.Summary.Value(tag="loss", simple_value=float(l))]
        self.validation_summary_writter.add_summary(summary_pb2.Summary(value=summaries), step)
        logger.info("Loss VALIDATION set step={0} = {1}".format(step, l))

    def __create_general_summary(self):
        """Registers the training loss / learning-rate scalar summaries and merges all summaries."""
        # Train summary
        tf.scalar_summary('loss', self.training_graph, name="train")
        tf.scalar_summary('lr', self.learning_rate, name="train")
        return tf.merge_all_summaries()

    def __set_optimizer_learning_rate(self, learning_rate):
        """
        Injects `learning_rate` into the optimizer instance before `minimize` builds its ops.

        TensorFlow optimizers keep the rate in a private attribute whose name differs between
        classes. ``GradientDescentOptimizer`` uses ``_learning_rate`` and ``AdamOptimizer`` uses
        ``_lr``, so both are set. Setting only ``_learning_rate`` left Adam on its constructor
        rate.
        """
        self.optimizer_class._learning_rate = learning_rate
        if hasattr(self.optimizer_class, "_lr"):
            self.optimizer_class._lr = learning_rate

    def start_thread(self, session):
        """Starts the daemon loader thread(s) running `load_and_enqueue`; returns them as a list."""
        threads = []
        for _ in range(1):
            t = threading.Thread(target=self.load_and_enqueue, args=(session, ))
            t.daemon = True  # thread will close when parent quits
            t.start()
            threads.append(t)
        return threads

    def load_and_enqueue(self, session):
        """
        Injecting data in the place holder queue

        Loops until the coordinator (``self.thread_pool``) requests a stop. During shutdown,
        `train` closes the queue, and the cancelled or rejected enqueue ends the loop quietly.
        """
        while not self.thread_pool.should_stop():
            train_data, train_labels = self.train_data_shuffler.get_batch()
            train_placeholder_data, train_placeholder_labels = self.train_data_shuffler.get_placeholders()

            feed_dict = {train_placeholder_data: train_data,
                         train_placeholder_labels: train_labels}
            try:
                session.run(self.enqueue_op, feed_dict=feed_dict)
            except (tf.errors.CancelledError, tf.errors.AbortedError):
                # Expected only when the queue was closed on purpose after a stop request
                if self.thread_pool.should_stop():
                    return
                raise

    def __stop_prefetch_threads(self, session, threads):
        """Asks the loader threads to stop, unblocks any pending enqueue and joins them."""
        self.thread_pool.request_stop()
        # A loader blocked on a full queue would never see the stop request; closing the
        # queue with `cancel_pending_enqueues` releases it.
        if self.prefetch_queue is not None:
            session.run(self.prefetch_queue.close(cancel_pending_enqueues=True))
        self.thread_pool.join(threads)

    def train(self, train_data_shuffler, validation_data_shuffler=None):
        """
        Do the loop forward --> backward --|
                      ^--------------------|

        ** Parameters **

            train_data_shuffler: Shuffler providing the training batches
            validation_data_shuffler: Optional shuffler; when given, the validation loss (and
                                      the analizer, if any) is computed every ``snapshot`` steps

        The learning rate follows
        ``base_learning_rate * weight_decay ** (step / train_data_shuffler.n_samples)``,
        where ``step`` is the number of optimizer steps taken so far.
        """
        # Creating directory
        bob.io.base.create_directories_safe(self.temp_dir)
        self.train_data_shuffler = train_data_shuffler

        # Do not reuse state from a previous call (its summary writers were closed).
        self.validation_graph = None
        self.validation_summary_writter = None
        self.train_summary_writter = None

        # One step counter shared by the learning-rate decay and the optimizer; `minimize`
        # increments it once per training step.
        self.global_step = tf.Variable(0, trainable=False, name="global_step")

        # TODO: find an elegant way to provide this as a parameter of the trainer
        self.learning_rate = tf.train.exponential_decay(
            self.base_learning_rate,         # Initial learning rate
            self.global_step,                # Current step
            train_data_shuffler.n_samples,   # Decay steps
            self.weight_decay                # Decay rate
        )

        self.training_graph = self.compute_graph(train_data_shuffler, name="train")

        # Preparing the optimizer
        self.__set_optimizer_learning_rate(self.learning_rate)
        self.optimizer = self.optimizer_class.minimize(self.training_graph, global_step=self.global_step)

        # Train summary
        self.summaries_train = self.__create_general_summary()

        logger.info("Initializing !!")

        with tf.Session() as session:
            tf.initialize_all_variables().run()

            # Start a thread to enqueue data asynchronously, and hide I/O latency.
            threads = []
            if self.prefetch:
                self.thread_pool = tf.train.Coordinator()
                tf.train.start_queue_runners(coord=self.thread_pool)
                threads = self.start_thread(session)

            try:
                # TENSOR BOARD SUMMARY
                self.train_summary_writter = tf.train.SummaryWriter(os.path.join(self.temp_dir, 'train'), session.graph)

                for step in range(self.iterations):
                    self.__fit(session, step)
                    if validation_data_shuffler is not None and step % self.snapshot == 0:
                        self.__compute_validation(session, validation_data_shuffler, step)

                        if self.analizer is not None:
                            self.validation_summary_writter.add_summary(self.analizer(
                                validation_data_shuffler, self.architecture, session), step)

                logger.info("Training finally finished")
            finally:
                # Always release writers and loader threads, even if a step failed.
                if self.train_summary_writter is not None:
                    self.train_summary_writter.close()
                # Only exists once a validation snapshot has actually run
                if self.validation_summary_writter is not None:
                    self.validation_summary_writter.close()
                if self.prefetch:
                    self.__stop_prefetch_threads(session, threads)

            # Open the model file only now, so a failed run does not truncate a previous model.
            hdf5 = bob.io.base.HDF5File(os.path.join(self.temp_dir, 'model.hdf5'), 'w')
            self.architecture.save(hdf5)
            del hdf5