#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Tue 09 Aug 2016 15:25:22 CEST

import tensorflow as tf
from ..network import SequenceNetwork
import threading
import os
import bob.io.base
import bob.core
from ..analyzers import SoftmaxAnalizer
from tensorflow.core.framework import summary_pb2
import time

logger = bob.core.log.setup("bob.learn.tensorflow")


class Trainer(object):
    """
    One graph trainer.
    Use this trainer when your CNN is composed by one graph

    **Parameters**
      architecture: The architecture that you want to run. Should be a :py:class`bob.learn.tensorflow.network.SequenceNetwork`
      optimizer: One of the tensorflow optimizers https://www.tensorflow.org/versions/r0.10/api_docs/python/train.html
      use_gpu: Use GPUs in the training
      loss: Loss
      temp_dir: The output directory

      base_learning_rate: Initial learning rate
      weight_decay:
      convergence_threshold:

      iterations: Maximum number of iterations
      snapshot: Will take a snapshot of the network at every `n` iterations
      prefetch: Use extra Threads to deal with the I/O
      analizer: Neural network analizer :py:mod:`bob.learn.tensorflow.analyzers`
      verbosity_level:

    """
    def __init__(self,
                 architecture,
                 optimizer=tf.train.AdamOptimizer(),
                 use_gpu=False,
                 loss=None,
                 temp_dir="cnn",

                 # Learning rate
                 base_learning_rate=0.001,
                 weight_decay=0.9,

                 ###### training options ##########
                 convergence_threshold=0.01,
                 iterations=5000,
                 snapshot=100,
                 prefetch=False,

                 ## Analizer
                 analizer=SoftmaxAnalizer(),

                 verbosity_level=2):

        if not isinstance(architecture, SequenceNetwork):
            raise ValueError("`architecture` should be instance of `SequenceNetwork`")

        self.architecture = architecture
        self.optimizer_class = optimizer
        self.use_gpu = use_gpu
        self.loss = loss
        self.temp_dir = temp_dir

        self.base_learning_rate = base_learning_rate
        self.weight_decay = weight_decay

        self.iterations = iterations
        self.snapshot = snapshot
        self.convergence_threshold = convergence_threshold
        self.prefetch = prefetch

        # Training variables used in the fit
        self.optimizer = None
        self.training_graph = None
        self.learning_rate = None
        self.train_data_shuffler = None
        self.summaries_train = None
        self.train_summary_writter = None
        self.thread_pool = None

        # Validation data
        self.validation_graph = None
        self.validation_summary_writter = None

        # Analizer
        self.analizer = analizer

        self.enqueue_op = None

        bob.core.log.set_verbosity_level(logger, verbosity_level)

    def __del__(self):
        tf.reset_default_graph()

    def compute_graph(self, data_shuffler, prefetch=False, name=""):
        """
        Computes the graph for the trainer.

        ** Parameters **

            data_shuffler: Data shuffler
            prefetch: If True, the graph consumes batches from the prefetch queue instead of plain placeholders
            name: Name of the graph
        """

        # Defining placeholders
        if prefetch:
            [placeholder_data, placeholder_labels] = data_shuffler.get_placeholders_forprefetch(name=name)

            # Defining a placeholder queue for prefetching
            queue = tf.FIFOQueue(capacity=10,
                                 dtypes=[tf.float32, tf.int64],
                                 shapes=[placeholder_data.get_shape().as_list()[1:], []])

            # Defining the enqueue operation and dequeuing training batches from the queue
            self.enqueue_op = queue.enqueue_many([placeholder_data, placeholder_labels])
            feature_batch, label_batch = queue.dequeue_many(data_shuffler.batch_size)

            # Sanity check: the architecture must be a SequenceNetwork
            if not isinstance(self.architecture, SequenceNetwork):
                raise ValueError("The variable `architecture` must be an instance of "
                                 "`bob.learn.tensorflow.network.SequenceNetwork`")
        else:
            [feature_batch, label_batch] = data_shuffler.get_placeholders(name=name)

        # Creating graphs and defining the loss
        network_graph = self.architecture.compute_graph(feature_batch)
        graph = self.loss(network_graph, label_batch)

        return graph

    def get_feed_dict(self, data_shuffler):
        """
146
        Given a data shuffler prepared the dictionary to be injected in the graph
147
148
149
150

        ** Parameters **
            data_shuffler:

Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
151
        """
152
153
        [data, labels] = data_shuffler.get_batch()
        [data_placeholder, label_placeholder] = data_shuffler.get_placeholders()
154
155
156
157
158

        feed_dict = {data_placeholder: data,
                     label_placeholder: labels}
        return feed_dict

    def fit(self, session, step):
        """
        Run one iteration (`forward` and `backward`)

        ** Parameters **
            session: Tensorflow session
            step: Iteration number

        """

        if self.prefetch:
            _, l, lr, summary = session.run([self.optimizer, self.training_graph,
                                             self.learning_rate, self.summaries_train])
        else:
            feed_dict = self.get_feed_dict(self.train_data_shuffler)
            _, l, lr, summary = session.run([self.optimizer, self.training_graph,
                                             self.learning_rate, self.summaries_train], feed_dict=feed_dict)

        logger.info("Loss training set step={0} = {1}".format(step, l))
        self.train_summary_writter.add_summary(summary, step)

    def compute_validation(self, session, data_shuffler, step):
        """
        Computes the loss in the validation set

        ** Parameters **
            session: Tensorflow session
            data_shuffler: The data shuffler to be used
            step: Iteration number

        """
        # Building the validation graph and computing the loss with the current session
        self.validation_graph = self.compute_graph(data_shuffler, name="validation")
        feed_dict = self.get_feed_dict(data_shuffler)
        l = session.run(self.validation_graph, feed_dict=feed_dict)

        if self.validation_summary_writter is None:
            self.validation_summary_writter = tf.train.SummaryWriter(os.path.join(self.temp_dir, 'validation'), session.graph)

        summaries = []
        summaries.append(summary_pb2.Summary.Value(tag="loss", simple_value=float(l)))
        self.validation_summary_writter.add_summary(summary_pb2.Summary(value=summaries), step)
        logger.info("Loss VALIDATION set step={0} = {1}".format(step, l))

    def create_general_summary(self):
        """
        Creates a simple tensorboard summary with the value of the loss and learning rate
        """

        # Train summary
        tf.scalar_summary('loss', self.training_graph, name="train")
        tf.scalar_summary('lr', self.learning_rate, name="train")
        return tf.merge_all_summaries()

    def start_thread(self, session):
        """
        Start pool of threads for pre-fetching

        **Parameters**
          session: Tensorflow session
        """

        threads = []
        for n in range(3):
            t = threading.Thread(target=self.load_and_enqueue, args=(session,))
            t.daemon = True  # thread will close when parent quits
            t.start()
            threads.append(t)
        return threads

    def load_and_enqueue(self, session):
        """
231
        Injecting data in the place holder queue
232
233
234

        **Parameters**
          session: Tensorflow session
235
        """

        while not self.thread_pool.should_stop():
            [train_data, train_labels] = self.train_data_shuffler.get_batch()
            [train_placeholder_data, train_placeholder_labels] = self.train_data_shuffler.get_placeholders()

            feed_dict = {train_placeholder_data: train_data,
                         train_placeholder_labels: train_labels}

            session.run(self.enqueue_op, feed_dict=feed_dict)

    def train(self, train_data_shuffler, validation_data_shuffler=None):
        """
248
        Train the network
249
250
251
252
253
        """

        # Creating directory
        bob.io.base.create_directories_safe(self.temp_dir)
        self.train_data_shuffler = train_data_shuffler

        # TODO: find an elegant way to provide this as a parameter of the trainer
        self.learning_rate = tf.train.exponential_decay(
            self.base_learning_rate,  # Learning rate
            train_data_shuffler.batch_size,
            train_data_shuffler.n_samples,
            self.weight_decay  # Decay rate
        )
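        # tf.train.exponential_decay computes
        #   decayed_lr = base_learning_rate * weight_decay ** (step / decay_steps),
        # with the second argument taken as the step and the third as decay_steps; note
        # that a fixed value (the batch size) is passed as the step here, so the resulting
        # rate does not change over the course of training.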

        self.training_graph = self.compute_graph(train_data_shuffler, prefetch=self.prefetch, name="train")

        # Preparing the optimizer
        self.optimizer_class._learning_rate = self.learning_rate
        self.optimizer = self.optimizer_class.minimize(self.training_graph, global_step=tf.Variable(0))

        # Train summary
        self.summaries_train = self.create_general_summary()

        logger.info("Initializing !!")
        # Training
        hdf5 = bob.io.base.HDF5File(os.path.join(self.temp_dir, 'model.hdf5'), 'w')

        with tf.Session() as session:
            tf.initialize_all_variables().run()

            # Start threads to enqueue data asynchronously and hide I/O latency.
            if self.prefetch:
                self.thread_pool = tf.train.Coordinator()
                tf.train.start_queue_runners(coord=self.thread_pool)
                threads = self.start_thread(session)

            # TENSOR BOARD SUMMARY
            self.train_summary_writter = tf.train.SummaryWriter(os.path.join(self.temp_dir, 'train'), session.graph)
            for step in range(self.iterations):

                start = time.time()
                self.fit(session, step)
                end = time.time()
                summary = summary_pb2.Summary.Value(tag="elapsed_time", simple_value=float(end-start))
                self.train_summary_writter.add_summary(summary_pb2.Summary(value=[summary]), step)

                if validation_data_shuffler is not None and step % self.snapshot == 0:
                    self.compute_validation(session, validation_data_shuffler, step)

                    if self.analizer is not None:
                        self.validation_summary_writter.add_summary(self.analizer(
                             validation_data_shuffler, self.architecture, session), step)

            logger.info("Training finally finished")

            self.train_summary_writter.close()
            if validation_data_shuffler is not None:
                self.validation_summary_writter.close()

            self.architecture.save(hdf5)
            del hdf5

            if self.prefetch:
                # now they should definitely stop
                self.thread_pool.request_stop()
                self.thread_pool.join(threads)

            session.close()  # For some reason the session is not closed after the context manager finishes