Trainer.py 16.3 KB
Newer Older
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
1
2
3
4
5
6
7
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Tue 09 Aug 2016 15:25:22 CEST

import tensorflow as tf
from ..network import SequenceNetwork
8
9
10
import threading
import os
import bob.io.base
11
import bob.core
12
from ..analyzers import SoftmaxAnalizer
13
from tensorflow.core.framework import summary_pb2
14
import time
15
from bob.learn.tensorflow.datashuffler.OnlineSampling import OnLineSampling
16
from bob.learn.tensorflow.utils.session import Session
17
from .learning_rate import constant
18

19
# Module-level logger registered under the "bob.learn.tensorflow" name; its
# verbosity is (re)configured every time a Trainer is constructed.
logger = bob.core.log.setup("bob.learn.tensorflow")
20

Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
21

22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
class Trainer(object):
    """
    One graph trainer.
    Use this trainer when your CNN is composed by one graph

    **Parameters**
      architecture: The architecture that you want to run. Must be an instance of
        :py:class:`bob.learn.tensorflow.network.SequenceNetwork` (a `ValueError` is raised otherwise)
      optimizer: One of the tensorflow optimizers https://www.tensorflow.org/versions/r0.10/api_docs/python/train.html
      use_gpu: Use GPUs in the training
      loss: Loss; called as `loss(network_graph, label_batch)` when the graphs are built
      temp_dir: The output directory (summaries and model snapshots are written there)

      learning_rate: Learning rate handed to the optimizer and logged in the train summary
      convergence_threshold: Stored on the instance (not read by the code of this class)

      iterations: Maximum number of iterations
      snapshot: Will take a snapshot of the network at every `n` iterations
      validation_snapshot: Will compute the validation loss at every `n` iterations
      prefetch: Use extra Threads to deal with the I/O
      analizer: Neural network analizer :py:mod:`bob.learn.tensorflow.analyzers`
      model_from_file: Path of a pretrained model to resume from; the empty string means
        "train from scratch"
      verbosity_level: Verbosity level applied to the module logger

    """
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
45
    def __init__(self,
                 architecture,
                 optimizer=tf.train.AdamOptimizer(),
                 use_gpu=False,
                 loss=None,
                 temp_dir="cnn",

                 # Learning rate
                 learning_rate=constant(),

                 ###### training options ##########
                 convergence_threshold=0.01,
                 iterations=5000,
                 snapshot=500,
                 validation_snapshot=100,
                 prefetch=False,

                 ## Analizer
                 analizer=SoftmaxAnalizer(),

                 ### Pretrained model
                 model_from_file="",

                 verbosity_level=2):
        """
        Stores the training configuration; no graph is built and no session is opened here.

        NOTE(review): the defaults for `optimizer`, `learning_rate` and `analizer` are
        evaluated once, when the class is defined, so every Trainer relying on them
        shares the same objects. Pass explicit instances if that matters.

        **Raises**
          ValueError: if `architecture` is not a `SequenceNetwork`
        """

        if not isinstance(architecture, SequenceNetwork):
            raise ValueError("`architecture` should be instance of `SequenceNetwork`")

        self.architecture = architecture
        # Kept under `optimizer_class`, although it is an optimizer *instance*;
        # `self.optimizer` is reserved for the minimize operation
        self.optimizer_class = optimizer
        self.use_gpu = use_gpu
        self.loss = loss
        self.temp_dir = temp_dir

        self.learning_rate = learning_rate

        self.iterations = iterations
        self.snapshot = snapshot
        self.validation_snapshot = validation_snapshot
        self.convergence_threshold = convergence_threshold
        self.prefetch = prefetch

        # Training variables used in the fit
        self.optimizer = None
        self.training_graph = None
        self.train_data_shuffler = None
        self.summaries_train = None
        self.train_summary_writter = None

        # Validation data
        self.validation_graph = None
        self.validation_summary_writter = None

        # Analizer
        self.analizer = analizer

        # Prefetching machinery (only populated when `prefetch` is True)
        self.thread_pool = None
        self.enqueue_op = None
        self.global_step = None

        self.model_from_file = model_from_file
        self.session = None

        bob.core.log.set_verbosity_level(logger, verbosity_level)

Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
111
112
113
    def __del__(self):
        """
        Resets the tensorflow default graph when the trainer is garbage collected.
        """
        # NOTE(review): this acts on the process-wide default graph, not only on
        # what this trainer created
        tf.reset_default_graph()

114
    def compute_graph(self, data_shuffler, prefetch=False, name="", training=True):
        """
        Computes the graph for the trainer.

        ** Parameters **

            data_shuffler: Data shuffler
            prefetch: Uses prefetch. When set, the placeholders feed a FIFO queue
              (its enqueue operation is kept in `self.enqueue_op`) and the network
              reads its batches from that queue
            name: Name of the graph
            training: Is it a training graph?

        ** Returns **

            The output of `self.loss` applied to the network output and the labels
        """

        # Defining place holders
        if prefetch:
            [placeholder_data, placeholder_labels] = data_shuffler.get_placeholders_forprefetch(name=name)

            # Defining a placeholder queue for prefetching
            queue = tf.FIFOQueue(capacity=10,
                                 dtypes=[tf.float32, tf.int64],
                                 shapes=[placeholder_data.get_shape().as_list()[1:], []])

            # Fetching the place holders from the queue
            self.enqueue_op = queue.enqueue_many([placeholder_data, placeholder_labels])
            feature_batch, label_batch = queue.dequeue_many(data_shuffler.batch_size)

            # Creating the architecture for train and validation
            if not isinstance(self.architecture, SequenceNetwork):
                raise ValueError("The variable `architecture` must be an instance of "
                                 "`bob.learn.tensorflow.network.SequenceNetwork`")
        else:
            [feature_batch, label_batch] = data_shuffler.get_placeholders(name=name)

        # Creating graphs and defining the loss
        network_graph = self.architecture.compute_graph(feature_batch, training=training)
        graph = self.loss(network_graph, label_batch)

        return graph

    def get_feed_dict(self, data_shuffler):
        """
        Given a data shuffler, prepares the dictionary to be injected in the graph

        ** Parameters **
            data_shuffler: Data shuffler providing `get_batch()` and `get_placeholders()`

        ** Returns **
            A dictionary mapping the data and label placeholders to a freshly drawn batch
        """
        [data, labels] = data_shuffler.get_batch()
        [data_placeholder, label_placeholder] = data_shuffler.get_placeholders()

        feed_dict = {data_placeholder: data,
                     label_placeholder: labels}
        return feed_dict

167
    def fit(self, step):
        """
        Run one iteration (`forward` and `backward`) on `self.session`

        ** Parameters **
            step: Iteration number

        """

        fetches = [self.optimizer, self.training_graph,
                   self.learning_rate, self.summaries_train]

        if self.prefetch:
            # The batch comes from the prefetching queue; nothing to feed
            _, loss_value, _, summary = self.session.run(fetches)
        else:
            feed_dict = self.get_feed_dict(self.train_data_shuffler)
            _, loss_value, _, summary = self.session.run(fetches, feed_dict=feed_dict)

        logger.info("Loss training set step={0} = {1}".format(step, loss_value))
        self.train_summary_writter.add_summary(summary, step)
187

188
    def compute_validation(self, data_shuffler, step):
        """
        Computes the loss in the validation set

        ** Parameters **
            data_shuffler: The data shuffler to be used
            step: Iteration number

        """
        # Evaluating the validation graph on one batch, in the trainer session
        feed_dict = self.get_feed_dict(data_shuffler)
        loss_value = self.session.run(self.validation_graph, feed_dict=feed_dict)

        # The validation writer is created lazily, the first time it is needed
        if self.validation_summary_writter is None:
            self.validation_summary_writter = tf.train.SummaryWriter(os.path.join(self.temp_dir, 'validation'), self.session.graph)

        summaries = [summary_pb2.Summary.Value(tag="loss", simple_value=float(loss_value))]
        self.validation_summary_writter.add_summary(summary_pb2.Summary(value=summaries), step)
        logger.info("Loss VALIDATION set step={0} = {1}".format(step, loss_value))

209
210
211
212
    def create_general_summary(self):
        """
        Creates a simple tensorboard summary with the value of the loss and learning rate

        ** Returns **
            The operation returned by `tf.merge_all_summaries()`
        """
        # Train summary
        tf.scalar_summary('loss', self.training_graph, name="train")
        tf.scalar_summary('lr', self.learning_rate, name="train")
        return tf.merge_all_summaries()

218
    def start_thread(self):
        """
        Start pool of threads for pre-fetching

        ** Returns **
          The list with the (already started) threads, each one running `load_and_enqueue`
        """

        threads = []
        for _ in range(3):
            t = threading.Thread(target=self.load_and_enqueue, args=())
            t.daemon = True  # thread will close when parent quits
            t.start()
            threads.append(t)
        return threads
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
233

234
    def load_and_enqueue(self):
        """
        Injecting data in the place holder queue

        Keeps drawing batches from `self.train_data_shuffler` and running
        `self.enqueue_op` on `self.session` until `self.thread_pool` asks to stop.
        """

        while not self.thread_pool.should_stop():
            [train_data, train_labels] = self.train_data_shuffler.get_batch()
            [train_placeholder_data, train_placeholder_labels] = self.train_data_shuffler.get_placeholders()

            feed_dict = {train_placeholder_data: train_data,
                         train_placeholder_labels: train_labels}

            self.session.run(self.enqueue_op, feed_dict=feed_dict)
250

Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
251
    def bootstrap_graphs(self, train_data_shuffler, validation_data_shuffler):
        """
        Create all the necessary graphs for training, validation and inference graphs

        Every graph is also registered in a tensorflow collection, under the
        names that `bootstrap_graphs_fromfile` looks up.

         ** Parameters **
           train_data_shuffler: Data shuffler for training
           validation_data_shuffler: Data shuffler for validation (`None` to skip the validation graph)
        """

        # Creating train graph
        self.training_graph = self.compute_graph(train_data_shuffler, prefetch=self.prefetch, name="train")
        tf.add_to_collection("training_graph", self.training_graph)

        # Creating inference graph
        self.architecture.compute_inference_placeholder(train_data_shuffler.deployment_shape)
        self.architecture.compute_inference_graph()
        tf.add_to_collection("inference_placeholder", self.architecture.inference_placeholder)
        tf.add_to_collection("inference_graph", self.architecture.inference_graph)

        # Creating validation graph
        if validation_data_shuffler is not None:
            self.validation_graph = self.compute_graph(validation_data_shuffler, name="validation", training=False)
            tf.add_to_collection("validation_graph", self.validation_graph)

        self.bootstrap_placeholders(train_data_shuffler, validation_data_shuffler)

    def bootstrap_placeholders(self, train_data_shuffler, validation_data_shuffler):
        """
        Persist the placeholders in tensorflow collections

         ** Parameters **
           train_data_shuffler: Data shuffler for training
           validation_data_shuffler: Data shuffler for validation (may be `None`)

        """

        # With prefetching, the training placeholders are the ones feeding the queue
        fetch_train_placeholders = (train_data_shuffler.get_placeholders_forprefetch
                                    if self.prefetch
                                    else train_data_shuffler.get_placeholders)
        train_data, train_label = fetch_train_placeholders()

        tf.add_to_collection("train_placeholder_data", train_data)
        tf.add_to_collection("train_placeholder_label", train_label)

        # Nothing else to persist without a validation set
        if validation_data_shuffler is None:
            return

        validation_data, validation_label = validation_data_shuffler.get_placeholders()
        tf.add_to_collection("validation_placeholder_data", validation_data)
        tf.add_to_collection("validation_placeholder_label", validation_label)

298
    def bootstrap_graphs_fromfile(self, train_data_shuffler, validation_data_shuffler):
        """
        Bootstrap all the necessary data from file

        The model in `self.model_from_file` is loaded into `self.session`, and the
        graphs, optimizer, learning rate and summaries are recovered from the
        tensorflow collections.

         ** Parameters **
           train_data_shuffler: Data shuffler for training
           validation_data_shuffler: Data shuffler for validation

         ** Returns **
           The saver object returned by `self.architecture.load`
        """
        saver = self.architecture.load(self.session, self.model_from_file)

        # Loading training graph
        self.training_graph = tf.get_collection("training_graph")[0]

        # Loading other elements
        self.optimizer = tf.get_collection("optimizer")[0]
        self.learning_rate = tf.get_collection("learning_rate")[0]
        self.summaries_train = tf.get_collection("summaries_train")[0]

        if validation_data_shuffler is not None:
            self.validation_graph = tf.get_collection("validation_graph")[0]

        self.bootstrap_placeholders_fromfile(train_data_shuffler, validation_data_shuffler)

        return saver

    def bootstrap_placeholders_fromfile(self, train_data_shuffler, validation_data_shuffler):
        """
        Load placeholders from file

        The placeholders are read from the tensorflow collections and handed to
        the matching data shuffler.

         ** Parameters **

           train_data_shuffler: Data shuffler for training
           validation_data_shuffler: Data shuffler for validation (may be `None`)

        """

        train_data_shuffler.set_placeholders(tf.get_collection("train_placeholder_data")[0],
                                             tf.get_collection("train_placeholder_label")[0])

        if validation_data_shuffler is not None:
            # The validation placeholders belong to the validation shuffler; setting
            # them on the training one would overwrite its own placeholders
            validation_data_shuffler.set_placeholders(tf.get_collection("validation_placeholder_data")[0],
                                                      tf.get_collection("validation_placeholder_label")[0])

344
345
    def train(self, train_data_shuffler, validation_data_shuffler=None):
        """
        Train the network:

        Builds the graphs (or restores them when `self.model_from_file` is set), runs
        `self.iterations` optimisation steps and writes the tensorboard summaries and
        the model snapshots under `self.temp_dir`.

         ** Parameters **

           train_data_shuffler: Data shuffler for training
           validation_data_shuffler: Data shuffler for validation. When given, the validation
             loss is computed at every `self.validation_snapshot` iterations
        """

        # Creating directory
        bob.io.base.create_directories_safe(self.temp_dir)
        self.train_data_shuffler = train_data_shuffler

        logger.info("Initializing !!")

        # Pickle the architecture to save
        self.architecture.pickle_net(train_data_shuffler.deployment_shape)

        # The session comes from the project-level `Session` holder
        self.session = Session.instance().session

        # Loading a pretrained model
        if self.model_from_file != "":
            logger.info("Loading pretrained model from {0}".format(self.model_from_file))
            saver = self.bootstrap_graphs_fromfile(train_data_shuffler, validation_data_shuffler)
        else:
            # Bootstraping all the graphs
            self.bootstrap_graphs(train_data_shuffler, validation_data_shuffler)

            # TODO: find an elegant way to provide this as a parameter of the trainer
            self.global_step = tf.Variable(0, trainable=False)

            # Preparing the optimizer
            self.optimizer_class._learning_rate = self.learning_rate
            self.optimizer = self.optimizer_class.minimize(self.training_graph, global_step=self.global_step)
            tf.add_to_collection("optimizer", self.optimizer)
            tf.add_to_collection("learning_rate", self.learning_rate)

            # Train summary
            self.summaries_train = self.create_general_summary()
            tf.add_to_collection("summaries_train", self.summaries_train)

            tf.initialize_all_variables().run(session=self.session)

            # Original tensorflow saver object
            saver = tf.train.Saver(var_list=tf.all_variables())

        if isinstance(train_data_shuffler, OnLineSampling):
            train_data_shuffler.set_feature_extractor(self.architecture, session=self.session)

        # Start a thread to enqueue data asynchronously, and hide I/O latency.
        if self.prefetch:
            self.thread_pool = tf.train.Coordinator()
            tf.train.start_queue_runners(coord=self.thread_pool)
            threads = self.start_thread()

        # TENSOR BOARD SUMMARY
        self.train_summary_writter = tf.train.SummaryWriter(os.path.join(self.temp_dir, 'train'), self.session.graph)
        for step in range(self.iterations):

            start = time.time()
            self.fit(step)
            end = time.time()
            summary = summary_pb2.Summary.Value(tag="elapsed_time", simple_value=float(end-start))
            self.train_summary_writter.add_summary(summary_pb2.Summary(value=[summary]), step)

            # Running validation
            if validation_data_shuffler is not None and step % self.validation_snapshot == 0:
                self.compute_validation(validation_data_shuffler, step)

                if self.analizer is not None:
                    self.validation_summary_writter.add_summary(self.analizer(
                         validation_data_shuffler, self.architecture, self.session), step)

            # Taking snapshot
            if step % self.snapshot == 0:
                logger.info("Taking snapshot")
                path = os.path.join(self.temp_dir, 'model_snapshot{0}.ckp'.format(step))
                self.architecture.save(saver, path)

        logger.info("Training finally finished")

        self.train_summary_writter.close()
        # The validation writer only exists once a validation step has been run
        if validation_data_shuffler is not None and self.validation_summary_writter is not None:
            self.validation_summary_writter.close()

        # Saving the final network
        path = os.path.join(self.temp_dir, 'model.ckp')
        self.architecture.save(saver, path)

        if self.prefetch:
            # now they should definetely stop
            self.thread_pool.request_stop()
            self.thread_pool.join(threads)