Trainer.py 16.6 KB
Newer Older
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
1
2
3
4
5
6
7
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Tue 09 Aug 2016 15:25:22 CEST

import tensorflow as tf
from ..network import SequenceNetwork
8
9
10
import threading
import os
import bob.io.base
11
import bob.core
12
from ..analyzers import SoftmaxAnalizer
13
from tensorflow.core.framework import summary_pb2
14
import time
15
from bob.learn.tensorflow.datashuffler.OnlineSampling import OnLineSampling
16
from bob.learn.tensorflow.utils.session import Session
17
from .learning_rate import constant
18

19
# Package-wide logger; `Trainer.__init__` adjusts its verbosity via `bob.core.log.set_verbosity_level`
logger = bob.core.log.setup("bob.learn.tensorflow")
20

Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
21

22
23
24
25
26
27
class Trainer(object):
    """
    One graph trainer.
    Use this trainer when your CNN is composed by one graph

    **Parameters**

    architecture:
      The architecture that you want to run. Should be a :py:class:`bob.learn.tensorflow.network.SequenceNetwork`

    optimizer:
      One of the tensorflow optimizers https://www.tensorflow.org/versions/r0.10/api_docs/python/train.html
      If ``None`` (default), a new ``tf.train.AdamOptimizer()`` is created for this trainer.
      Note that ``train`` overwrites the optimizer's learning rate with ``learning_rate``.

    use_gpu: bool
      Use GPUs in the training (stored, but not read anywhere in this class)

    loss: :py:class:`bob.learn.tensorflow.loss.BaseLoss`
      Loss function, called as ``loss(network_graph, label_batch)``

    temp_dir: str
      The output directory (checkpoints and tensorboard summaries are written there)

    learning_rate: :py:class:`bob.learn.tensorflow.trainers.learningrate`
      Initial learning rate. If ``None`` (default), ``constant()`` is called when the trainer is built.

    convergence_threshold:
      Stored, but currently not used by this class

    iterations: int
      Maximum number of iterations

    snapshot: int
      Will take a snapshot of the network at every `n` iterations

    validation_snapshot: int
      Computes the validation loss at every `n` iterations (only when a validation data shuffler is given)

    prefetch: bool
      Use extra Threads to deal with the I/O

    model_from_file: str
      If you want to use a pretrained model

    analizer:
      Neural network analizer :py:mod:`bob.learn.tensorflow.analyzers`. Pass ``None`` to disable it.

    verbosity_level: int
      Passed to ``bob.core.log.set_verbosity_level`` for the package logger

    """

    def __init__(self,
                 architecture,
                 optimizer=None,
                 use_gpu=False,
                 loss=None,
                 temp_dir="cnn",

                 # Learning rate
                 learning_rate=None,

                 ###### training options ##########
                 convergence_threshold=0.01,
                 iterations=5000,
                 snapshot=500,
                 validation_snapshot=100,
                 prefetch=False,

                 ## Analizer
                 # NOTE(review): this default instance is created once at import time and
                 # shared by every trainer; confirm SoftmaxAnalizer keeps no per-run state.
                 analizer=SoftmaxAnalizer(),

                 ### Pretrained model
                 model_from_file="",

                 verbosity_level=2):

        if not isinstance(architecture, SequenceNetwork):
            raise ValueError("`architecture` should be instance of `SequenceNetwork`")

        # Defaults are built here, per instance, instead of in the signature:
        # default arguments are evaluated only once (at import time), so every
        # trainer would share -- and `train` would mutate -- the same optimizer,
        # and any graph ops created by `constant()` would belong to the graph
        # that existed at import time (which `__del__` may already have reset).
        if optimizer is None:
            optimizer = tf.train.AdamOptimizer()
        if learning_rate is None:
            learning_rate = constant()

        self.architecture = architecture
        self.optimizer_class = optimizer
        self.use_gpu = use_gpu
        self.loss = loss
        self.temp_dir = temp_dir

        self.learning_rate = learning_rate

        self.iterations = iterations
        self.snapshot = snapshot
        self.validation_snapshot = validation_snapshot
        self.convergence_threshold = convergence_threshold
        self.prefetch = prefetch

        # Training variables used in the fit
        self.optimizer = None
        self.training_graph = None
        self.train_data_shuffler = None
        self.summaries_train = None
        self.train_summary_writter = None
        self.thread_pool = None  # tf.train.Coordinator, only created when prefetching
        self.enqueue_op = None  # set by compute_graph when prefetching

        # Validation data
        self.validation_graph = None
        self.validation_summary_writter = None  # created lazily by compute_validation

        # Analizer
        self.analizer = analizer

        self.global_step = None

        self.model_from_file = model_from_file
        self.session = None

        bob.core.log.set_verbosity_level(logger, verbosity_level)

    def __del__(self):
        tf.reset_default_graph()

    def compute_graph(self, data_shuffler, prefetch=False, name="", training=True):
        """
        Computes the graph for the trainer.

        ** Parameters **

            data_shuffler: Data shuffler
            prefetch: Uses prefetch (batches are dequeued from a FIFO queue that is
                      filled by running `self.enqueue_op`, see `load_and_enqueue`)
            name: Name of the graph
            training: Is it a training graph?

        ** Returns **

            The output of `self.loss` applied to the network output and the labels

        """

        # Defining place holders
        if prefetch:
            [placeholder_data, placeholder_labels] = data_shuffler.get_placeholders_forprefetch(name=name)

            # Defining a placeholder queue for prefetching
            queue = tf.FIFOQueue(capacity=10,
                                 dtypes=[tf.float32, tf.int64],
                                 shapes=[placeholder_data.get_shape().as_list()[1:], []])

            # Fetching the place holders from the queue
            self.enqueue_op = queue.enqueue_many([placeholder_data, placeholder_labels])
            feature_batch, label_batch = queue.dequeue_many(data_shuffler.batch_size)

            # Creating the architecture for train and validation
            if not isinstance(self.architecture, SequenceNetwork):
                raise ValueError("The variable `architecture` must be an instance of "
                                 "`bob.learn.tensorflow.network.SequenceNetwork`")
        else:
            [feature_batch, label_batch] = data_shuffler.get_placeholders(name=name)

        # Creating graphs and defining the loss
        network_graph = self.architecture.compute_graph(feature_batch, training=training)
        graph = self.loss(network_graph, label_batch)

        return graph

    def get_feed_dict(self, data_shuffler):
        """
        Fetches one batch from a data shuffler and maps it onto the shuffler's
        placeholders, producing the dictionary to be injected in the graph

        ** Parameters **
            data_shuffler: Data shuffler providing `get_batch` and `get_placeholders`

        """
        [data, labels] = data_shuffler.get_batch()
        [data_placeholder, label_placeholder] = data_shuffler.get_placeholders()

        feed_dict = {data_placeholder: data,
                     label_placeholder: labels}
        return feed_dict

    def fit(self, step):
        """
        Run one iteration (`forward` and `backward`) in `self.session` and log the
        training loss and summaries

        ** Parameters **
            step: Iteration number

        """

        if self.prefetch:
            # Inputs come from the prefetch queue, so no feed_dict is needed
            _, l, lr, summary = self.session.run([self.optimizer, self.training_graph,
                                                  self.learning_rate, self.summaries_train])
        else:
            feed_dict = self.get_feed_dict(self.train_data_shuffler)
            _, l, lr, summary = self.session.run([self.optimizer, self.training_graph,
                                                  self.learning_rate, self.summaries_train], feed_dict=feed_dict)

        logger.info("Loss training set step={0} = {1}".format(step, l))
        self.train_summary_writter.add_summary(summary, step)

    def compute_validation(self, data_shuffler, step):
        """
        Computes the loss in the validation set and writes it to the validation
        tensorboard summary

        ** Parameters **
            data_shuffler: The data shuffler to be used
            step: Iteration number

        """
        # Evaluates the validation graph in the trainer's existing session
        feed_dict = self.get_feed_dict(data_shuffler)
        l = self.session.run(self.validation_graph, feed_dict=feed_dict)

        # The writer is created lazily, on the first validation run
        if self.validation_summary_writter is None:
            self.validation_summary_writter = tf.train.SummaryWriter(os.path.join(self.temp_dir, 'validation'), self.session.graph)

        summaries = [summary_pb2.Summary.Value(tag="loss", simple_value=float(l))]
        self.validation_summary_writter.add_summary(summary_pb2.Summary(value=summaries), step)
        logger.info("Loss VALIDATION set step={0} = {1}".format(step, l))

    def create_general_summary(self):
        """
        Creates a simple tensorboard summary with the value of the loss and learning rate
        """
        # Train summary
        tf.scalar_summary('loss', self.training_graph, name="train")
        tf.scalar_summary('lr', self.learning_rate, name="train")
        return tf.merge_all_summaries()

    def start_thread(self):
        """
        Start pool of (daemon) threads for pre-fetching

        ** Returns **
            The list of started threads, each running `load_and_enqueue`
        """

        threads = []
        for n in range(3):
            t = threading.Thread(target=self.load_and_enqueue, args=())
            t.daemon = True  # thread will close when parent quits
            t.start()
            threads.append(t)
        return threads

    def load_and_enqueue(self):
        """
        Injecting data in the place holder queue until the coordinator
        (`self.thread_pool`) requests a stop
        """

        while not self.thread_pool.should_stop():
            [train_data, train_labels] = self.train_data_shuffler.get_batch()
            [train_placeholder_data, train_placeholder_labels] = self.train_data_shuffler.get_placeholders()

            feed_dict = {train_placeholder_data: train_data,
                         train_placeholder_labels: train_labels}

            self.session.run(self.enqueue_op, feed_dict=feed_dict)

    def bootstrap_graphs(self, train_data_shuffler, validation_data_shuffler):
        """
        Create all the necessary graphs for training, validation and inference graphs,
        and register them in tensorflow collections so they can be restored later
        (see `bootstrap_graphs_fromfile`)

        ** Parameters **
           train_data_shuffler: Data shuffler for training
           validation_data_shuffler: Data shuffler for validation (may be None)
        """

        # Creating train graph
        self.training_graph = self.compute_graph(train_data_shuffler, prefetch=self.prefetch, name="train")
        tf.add_to_collection("training_graph", self.training_graph)

        # Creating inference graph
        self.architecture.compute_inference_placeholder(train_data_shuffler.deployment_shape)
        self.architecture.compute_inference_graph()
        tf.add_to_collection("inference_placeholder", self.architecture.inference_placeholder)
        tf.add_to_collection("inference_graph", self.architecture.inference_graph)

        # Creating validation graph
        if validation_data_shuffler is not None:
            self.validation_graph = self.compute_graph(validation_data_shuffler, name="validation", training=False)
            tf.add_to_collection("validation_graph", self.validation_graph)

        self.bootstrap_placeholders(train_data_shuffler, validation_data_shuffler)

    def bootstrap_placeholders(self, train_data_shuffler, validation_data_shuffler):
        """
        Persist the placeholders in tensorflow collections

         ** Parameters **
           train_data_shuffler: Data shuffler for training
           validation_data_shuffler: Data shuffler for validation (may be None)

        """

        # Persisting the placeholders
        if self.prefetch:
            batch, label = train_data_shuffler.get_placeholders_forprefetch()
        else:
            batch, label = train_data_shuffler.get_placeholders()

        tf.add_to_collection("train_placeholder_data", batch)
        tf.add_to_collection("train_placeholder_label", label)

        # Creating validation graph
        if validation_data_shuffler is not None:
            batch, label = validation_data_shuffler.get_placeholders()
            tf.add_to_collection("validation_placeholder_data", batch)
            tf.add_to_collection("validation_placeholder_label", label)

    def bootstrap_graphs_fromfile(self, train_data_shuffler, validation_data_shuffler):
        """
        Bootstrap all the necessary data from file (`self.model_from_file`)

         ** Parameters **
           train_data_shuffler: Data shuffler for training
           validation_data_shuffler: Data shuffler for validation (may be None)

         ** Returns **
           The saver returned by `self.architecture.load`

        """
        saver = self.architecture.load(self.model_from_file, clear_devices=True)

        # Loading training graph
        self.training_graph = tf.get_collection("training_graph")[0]

        # Loding other elements
        self.optimizer = tf.get_collection("optimizer")[0]
        self.learning_rate = tf.get_collection("learning_rate")[0]
        self.summaries_train = tf.get_collection("summaries_train")[0]
        self.global_step = tf.get_collection("global_step")[0]

        # NOTE(review): `self.enqueue_op` is not restored here, so combining
        # `model_from_file` with `prefetch=True` leaves the prefetch threads
        # running `session.run(None)` -- confirm whether that combination is supported.

        if validation_data_shuffler is not None:
            self.validation_graph = tf.get_collection("validation_graph")[0]

        self.bootstrap_placeholders_fromfile(train_data_shuffler, validation_data_shuffler)

        return saver

    def bootstrap_placeholders_fromfile(self, train_data_shuffler, validation_data_shuffler):
        """
        Load placeholders from file and hand them to the corresponding data shufflers

         ** Parameters **

           train_data_shuffler: Data shuffler for training
           validation_data_shuffler: Data shuffler for validation (may be None)

        """

        train_data_shuffler.set_placeholders(tf.get_collection("train_placeholder_data")[0],
                                             tf.get_collection("train_placeholder_label")[0])

        # The validation placeholders belong to the validation shuffler; setting
        # them on the train shuffler would overwrite the training placeholders.
        if validation_data_shuffler is not None:
            validation_data_shuffler.set_placeholders(tf.get_collection("validation_placeholder_data")[0],
                                                      tf.get_collection("validation_placeholder_label")[0])

    def train(self, train_data_shuffler, validation_data_shuffler=None):
        """
        Train the network:

         ** Parameters **

           train_data_shuffler: Data shuffler for training
           validation_data_shuffler: Data shuffler for validation

        """

        # Creating directory
        bob.io.base.create_directories_safe(self.temp_dir)
        self.train_data_shuffler = train_data_shuffler

        logger.info("Initializing !!")

        # Pickle the architecture to save
        self.architecture.pickle_net(train_data_shuffler.deployment_shape)

        Session.create()
        self.session = Session.instance().session

        # Loading a pretrained model
        if self.model_from_file != "":
            logger.info("Loading pretrained model from {0}".format(self.model_from_file))
            saver = self.bootstrap_graphs_fromfile(train_data_shuffler, validation_data_shuffler)

            # Resume counting from the step stored with the model
            start_step = self.global_step.eval(session=self.session)

        else:
            start_step = 0

            # Bootstraping all the graphs
            self.bootstrap_graphs(train_data_shuffler, validation_data_shuffler)

            # TODO: find an elegant way to provide this as a parameter of the trainer
            self.global_step = tf.Variable(0, trainable=False, name="global_step")
            tf.add_to_collection("global_step", self.global_step)

            # Preparing the optimizer
            # (the learning rate is injected through a private optimizer attribute)
            self.optimizer_class._learning_rate = self.learning_rate
            self.optimizer = self.optimizer_class.minimize(self.training_graph, global_step=self.global_step)
            tf.add_to_collection("optimizer", self.optimizer)
            tf.add_to_collection("learning_rate", self.learning_rate)

            # Train summary
            self.summaries_train = self.create_general_summary()
            tf.add_to_collection("summaries_train", self.summaries_train)

            tf.initialize_all_variables().run(session=self.session)

            # Original tensorflow saver object
            saver = tf.train.Saver(var_list=tf.all_variables())

        if isinstance(train_data_shuffler, OnLineSampling):
            train_data_shuffler.set_feature_extractor(self.architecture, session=self.session)

        # Start a thread to enqueue data asynchronously, and hide I/O latency.
        threads = []
        if self.prefetch:
            self.thread_pool = tf.train.Coordinator()
            tf.train.start_queue_runners(coord=self.thread_pool, sess=self.session)
            threads = self.start_thread()

        # TENSOR BOARD SUMMARY
        self.train_summary_writter = tf.train.SummaryWriter(os.path.join(self.temp_dir, 'train'), self.session.graph)
        for step in range(start_step, self.iterations):

            start = time.time()
            self.fit(step)
            end = time.time()
            summary = summary_pb2.Summary.Value(tag="elapsed_time", simple_value=float(end-start))
            self.train_summary_writter.add_summary(summary_pb2.Summary(value=[summary]), step)

            # Running validation
            if validation_data_shuffler is not None and step % self.validation_snapshot == 0:
                self.compute_validation(validation_data_shuffler, step)

                if self.analizer is not None:
                    self.validation_summary_writter.add_summary(self.analizer(
                         validation_data_shuffler, self.architecture, self.session), step)

            # Taking snapshot
            if step % self.snapshot == 0:
                logger.info("Taking snapshot")
                path = os.path.join(self.temp_dir, 'model_snapshot{0}.ckp'.format(step))
                self.architecture.save(saver, path)

        logger.info("Training finally finished")

        self.train_summary_writter.close()
        # The validation writer only exists if compute_validation ran at least
        # once (it does not when the training loop body never executes).
        if validation_data_shuffler is not None and self.validation_summary_writter is not None:
            self.validation_summary_writter.close()

        # Saving the final network
        path = os.path.join(self.temp_dir, 'model.ckp')
        self.architecture.save(saver, path)

        if self.prefetch:
            # now they should definetely stop
            self.thread_pool.request_stop()
            self.thread_pool.join(threads)