Trainer.py 15.8 KB
Newer Older
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
1
2
3
4
5
6
7
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Tue 09 Aug 2016 15:25:22 CEST

import tensorflow as tf
from ..network import SequenceNetwork
8
9
10
import threading
import os
import bob.io.base
11
import bob.core
12
from ..analyzers import SoftmaxAnalizer
13
from tensorflow.core.framework import summary_pb2
14
import time
15
from bob.learn.tensorflow.datashuffler.OnlineSampling import OnLineSampling
16
from .learning_rate import constant
17

18
logger = bob.core.log.setup("bob.learn.tensorflow")
19

Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
20

21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
class Trainer(object):
    """
    One graph trainer.
    Use this trainer when your CNN is composed by one graph

    **Parameters**
      architecture: The architecture that you want to run. Should be a :py:class:`bob.learn.tensorflow.network.SequenceNetwork`
      optimizer: One of the tensorflow optimizers https://www.tensorflow.org/versions/r0.10/api_docs/python/train.html
      use_gpu: Use GPUs in the training
      loss: Loss. Called as `loss(network_output, labels)` when the graphs are built
      temp_dir: The output directory (summaries and model snapshots are written here)

      learning_rate: Learning rate handed to the optimizer and logged in the train summary
      convergence_threshold:

      iterations: Maximum number of iterations
      snapshot: Will take a snapshot of the network at every `n` iterations
      validation_snapshot: Will compute the validation loss at every `n` iterations
      prefetch: Use extra Threads to deal with the I/O
      analizer: Neural network analizer :py:mod:`bob.learn.tensorflow.analyzers`. Use `None` to disable it
      model_from_file: Path of a pretrained model to resume from. An empty string means training from scratch
      verbosity_level: Verbosity level set on the `bob.learn.tensorflow` logger

    """
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
44
    def __init__(self,
                 architecture,
                 optimizer=tf.train.AdamOptimizer(),
                 use_gpu=False,
                 loss=None,
                 temp_dir="cnn",

                 # Learning rate
                 learning_rate=constant(),

                 ###### training options ##########
                 convergence_threshold=0.01,
                 iterations=5000,
                 snapshot=500,
                 validation_snapshot=100,
                 prefetch=False,

                 ## Analizer
                 analizer=SoftmaxAnalizer(),

                 ### Pretrained model
                 model_from_file="",

                 verbosity_level=2):
        """
        Stores the training configuration. No graph, session or thread is
        created here; that happens in :py:meth:`train`.

        **Raises**
          ValueError: If `architecture` is not a `SequenceNetwork`
        """
        # NOTE(review): the defaults of `optimizer`, `learning_rate` and `analizer`
        # are evaluated once, when this function is defined, so every trainer
        # relying on them shares the same objects -- confirm this is intended.

        if not isinstance(architecture, SequenceNetwork):
            raise ValueError("`architecture` should be instance of `SequenceNetwork`")

        # What to train and how to optimize it
        self.architecture = architecture
        self.loss = loss
        self.optimizer_class = optimizer
        self.learning_rate = learning_rate
        self.analizer = analizer

        # Where and on what to run
        self.use_gpu = use_gpu
        self.temp_dir = temp_dir
        self.model_from_file = model_from_file

        # Training schedule
        self.iterations = iterations
        self.snapshot = snapshot
        self.validation_snapshot = validation_snapshot
        self.convergence_threshold = convergence_threshold
        self.prefetch = prefetch

        # State filled in later, while the graphs are built and trained
        self.optimizer = None
        self.global_step = None
        self.training_graph = None
        self.train_data_shuffler = None
        self.summaries_train = None
        self.train_summary_writter = None

        # Validation state
        self.validation_graph = None
        self.validation_summary_writter = None

        # Prefetching state (coordinator and queue enqueue operation)
        self.thread_pool = None
        self.enqueue_op = None

        bob.core.log.set_verbosity_level(logger, verbosity_level)

Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
109
110
111
    def __del__(self):
        """
        Resets the tensorflow default graph when the trainer is destroyed
        """
        # NOTE(review): the default graph is global to the process, so this
        # presumably also discards graphs created by other objects -- confirm
        # that no other trainer/network is expected to be alive at this point.
        tf.reset_default_graph()

112
    def compute_graph(self, data_shuffler, prefetch=False, name="", training=True):
        """
        Builds the loss graph on top of the architecture.

        ** Parameters **

            data_shuffler: Data shuffler that provides the placeholders
            prefetch: If True, the placeholders feed a FIFO queue and the network
                      consumes the batches dequeued from it
            name: Name of the graph (forwarded to the placeholder getters)
            training: Is it a training graph?

        ** Returns **

            The output of `self.loss` applied to the network output and the labels
        """

        if not prefetch:
            # Plain feeding: the network reads directly from the placeholders
            features, labels = data_shuffler.get_placeholders(name=name)
        else:
            data_ph, labels_ph = data_shuffler.get_placeholders_forprefetch(name=name)

            # Shape of one queue element: the placeholder shape without its
            # first dimension (presumably the batch axis)
            sample_shape = data_ph.get_shape().as_list()[1:]
            prefetch_queue = tf.FIFOQueue(capacity=10,
                                          dtypes=[tf.float32, tf.int64],
                                          shapes=[sample_shape, []])

            # The enqueue operation is kept so the loader threads can run it
            self.enqueue_op = prefetch_queue.enqueue_many([data_ph, labels_ph])
            features, labels = prefetch_queue.dequeue_many(data_shuffler.batch_size)

            if not isinstance(self.architecture, SequenceNetwork):
                raise ValueError("The variable `architecture` must be an instance of "
                                 "`bob.learn.tensorflow.network.SequenceNetwork`")

        # Network output followed by the loss
        network_output = self.architecture.compute_graph(features, training=training)
        return self.loss(network_output, labels)

    def get_feed_dict(self, data_shuffler):
        """
        Fetches one batch from the data shuffler and maps it to the
        shuffler placeholders.

        ** Parameters **
            data_shuffler: Data shuffler providing `get_batch` and `get_placeholders`

        ** Returns **
            Dictionary `{data placeholder: data, label placeholder: labels}`

        """
        batch_data, batch_labels = data_shuffler.get_batch()
        data_ph, labels_ph = data_shuffler.get_placeholders()

        return {data_ph: batch_data, labels_ph: batch_labels}

165
166
167
168
169
170
171
172
173
174
    def fit(self, session, step):
        """
        Run one iteration (`forward` and `backward`), log the training loss
        and write the train summary.

        ** Parameters **
            session: Tensorflow session
            step: Iteration number

        """

        fetches = [self.optimizer, self.training_graph,
                   self.learning_rate, self.summaries_train]

        if self.prefetch:
            # The batch comes from the prefetch queue, nothing to feed
            outputs = session.run(fetches)
        else:
            outputs = session.run(fetches,
                                  feed_dict=self.get_feed_dict(self.train_data_shuffler))

        _, loss_value, _, summary = outputs

        logger.info("Loss training set step={0} = {1}".format(step, loss_value))
        self.train_summary_writter.add_summary(summary, step)
185

186
    def compute_validation(self, session, data_shuffler, step):
        """
        Computes the loss in the validation set, for one batch, and writes
        it in the validation summary.

        ** Parameters **
            session: Tensorflow session
            data_shuffler: The data shuffler to be used
            step: Iteration number

        """
        # One validation batch through the validation graph
        loss_value = session.run(self.validation_graph,
                                 feed_dict=self.get_feed_dict(data_shuffler))

        # The validation writer is created on first use
        if self.validation_summary_writter is None:
            log_dir = os.path.join(self.temp_dir, 'validation')
            self.validation_summary_writter = tf.train.SummaryWriter(log_dir, session.graph)

        loss_entry = summary_pb2.Summary.Value(tag="loss", simple_value=float(loss_value))
        self.validation_summary_writter.add_summary(summary_pb2.Summary(value=[loss_entry]), step)
        logger.info("Loss VALIDATION set step={0} = {1}".format(step, loss_value))

207
208
209
210
    def create_general_summary(self):
        """
        Registers scalar summaries for the training loss and the learning rate

        ** Returns **
            The result of `tf.merge_all_summaries()`
        """
        scalars = (('loss', self.training_graph),
                   ('lr', self.learning_rate))

        for tag, tensor in scalars:
            tf.scalar_summary(tag, tensor, name="train")

        return tf.merge_all_summaries()

216
    def start_thread(self, session):
        """
        Starts three daemon threads running :py:meth:`load_and_enqueue`

        **Parameters**
          session: Tensorflow session

        **Returns**
          The list with the started threads
        """

        workers = [threading.Thread(target=self.load_and_enqueue, args=(session,))
                   for _ in range(3)]

        for worker in workers:
            worker.daemon = True  # thread will close when parent quits
            worker.start()

        return workers
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
231

232
233
    def load_and_enqueue(self, session):
        """
        Keeps pushing training batches in the placeholder queue until the
        coordinator (`self.thread_pool`) asks to stop

        **Parameters**
          session: Tensorflow session
        """

        while True:
            if self.thread_pool.should_stop():
                break

            batch_data, batch_labels = self.train_data_shuffler.get_batch()
            data_ph, labels_ph = self.train_data_shuffler.get_placeholders()

            session.run(self.enqueue_op, feed_dict={data_ph: batch_data,
                                                    labels_ph: batch_labels})
248

Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
249
    def bootstrap_graphs(self, train_data_shuffler, validation_data_shuffler):
        """
        Builds the training, inference and (optionally) validation graphs and
        registers each of them in a tensorflow collection, so they can be
        retrieved by name later.

        **Parameters**
          train_data_shuffler: Data shuffler used for training
          validation_data_shuffler: Data shuffler used for validation, or `None`
        """
        net = self.architecture

        # Training graph
        train_graph = self.compute_graph(train_data_shuffler, prefetch=self.prefetch, name="train")
        self.training_graph = train_graph
        tf.add_to_collection("training_graph", train_graph)

        # Inference graph; the architecture exposes the results as the
        # attributes `inference_placeholder` and `inference_graph`
        net.compute_inference_placeholder(train_data_shuffler.deployment_shape)
        net.compute_inference_graph()
        for key in ("inference_placeholder", "inference_graph"):
            tf.add_to_collection(key, getattr(net, key))

        # Validation graph (never prefetched)
        if validation_data_shuffler is not None:
            validation_graph = self.compute_graph(validation_data_shuffler, name="validation", training=False)
            self.validation_graph = validation_graph
            tf.add_to_collection("validation_graph", validation_graph)

        self.bootstrap_placeholders(train_data_shuffler, validation_data_shuffler)

    def bootstrap_placeholders(self, train_data_shuffler, validation_data_shuffler):
        """
        Registers the data/label placeholders of the shufflers in tensorflow
        collections

        **Parameters**
          train_data_shuffler: Data shuffler used for training
          validation_data_shuffler: Data shuffler used for validation, or `None`
        """

        # The training placeholders depend on the feeding mode
        fetch_train = (train_data_shuffler.get_placeholders_forprefetch if self.prefetch
                       else train_data_shuffler.get_placeholders)
        train_data_ph, train_label_ph = fetch_train()

        tf.add_to_collection("train_placeholder_data", train_data_ph)
        tf.add_to_collection("train_placeholder_label", train_label_ph)

        if validation_data_shuffler is None:
            return

        validation_data_ph, validation_label_ph = validation_data_shuffler.get_placeholders()
        tf.add_to_collection("validation_placeholder_data", validation_data_ph)
        tf.add_to_collection("validation_placeholder_label", validation_label_ph)

    def bootstrap_graphs_fromfile(self, session, train_data_shuffler, validation_data_shuffler):
        """
        Loads the model pointed by `self.model_from_file` and recovers the
        graphs, optimizer, learning rate, summaries and placeholders from the
        tensorflow collections.

        **Parameters**
          session: Tensorflow session
          train_data_shuffler: Data shuffler used for training
          validation_data_shuffler: Data shuffler used for validation, or `None`

        **Returns**
          The saver returned by `self.architecture.load`
        """
        saver = self.architecture.load(session, self.model_from_file)

        # Each attribute is stored in a collection with the same name
        restored = ["training_graph", "optimizer", "learning_rate", "summaries_train"]
        for attribute in restored:
            setattr(self, attribute, tf.get_collection(attribute)[0])

        if validation_data_shuffler is not None:
            self.validation_graph = tf.get_collection("validation_graph")[0]

        self.bootstrap_placeholders_fromfile(train_data_shuffler, validation_data_shuffler)

        return saver

    def bootstrap_placeholders_fromfile(self, train_data_shuffler, validation_data_shuffler):
        """
        Load placeholders from file

        Reads the placeholders from the tensorflow collections and sets them
        in the corresponding data shuffler.

        **Parameters**
          train_data_shuffler: Data shuffler used for training
          validation_data_shuffler: Data shuffler used for validation, or `None`
        """

        train_data_shuffler.set_placeholders(tf.get_collection("train_placeholder_data")[0],
                                             tf.get_collection("train_placeholder_label")[0])

        if validation_data_shuffler is not None:
            # The validation placeholders belong to the validation shuffler.
            # Setting them on the training shuffler would overwrite the training
            # placeholders set just above.
            validation_data_shuffler.set_placeholders(tf.get_collection("validation_placeholder_data")[0],
                                                      tf.get_collection("validation_placeholder_label")[0])

324
325
    def train(self, train_data_shuffler, validation_data_shuffler=None):
        """
        Train the network

        Builds the graphs (or loads them from `self.model_from_file`), runs
        `self.iterations` training steps, computes the validation loss every
        `self.validation_snapshot` steps, saves a snapshot every `self.snapshot`
        steps and saves the final model in `self.temp_dir`.

        **Parameters**
          train_data_shuffler: Data shuffler used for training
          validation_data_shuffler: Data shuffler used for validation. If `None`,
                                    no validation is computed
        """

        # Creating directory
        bob.io.base.create_directories_safe(self.temp_dir)
        self.train_data_shuffler = train_data_shuffler

        logger.info("Initializing !!")

        config = tf.ConfigProto(log_device_placement=True,
                                gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=0.333))
        config.gpu_options.allow_growth = True

        # Pickle the architecture to save
        self.architecture.pickle_net(train_data_shuffler.deployment_shape)

        with tf.Session(config=config) as session:

            # Loading a pretrained model
            if self.model_from_file != "":
                logger.info("Loading pretrained model from {0}".format(self.model_from_file))
                saver = self.bootstrap_graphs_fromfile(session, train_data_shuffler, validation_data_shuffler)
            else:
                # Bootstraping all the graphs
                self.bootstrap_graphs(train_data_shuffler, validation_data_shuffler)

                # TODO: find an elegant way to provide this as a parameter of the trainer
                self.global_step = tf.Variable(0, trainable=False)

                # Preparing the optimizer
                self.optimizer_class._learning_rate = self.learning_rate
                self.optimizer = self.optimizer_class.minimize(self.training_graph, global_step=self.global_step)
                tf.add_to_collection("optimizer", self.optimizer)
                tf.add_to_collection("learning_rate", self.learning_rate)

                # Train summary
                self.summaries_train = self.create_general_summary()
                tf.add_to_collection("summaries_train", self.summaries_train)

                tf.initialize_all_variables().run()

                # Original tensorflow saver object
                saver = tf.train.Saver(var_list=tf.all_variables())

            if isinstance(train_data_shuffler, OnLineSampling):
                train_data_shuffler.set_feature_extractor(self.architecture, session=session)

            # Start a thread to enqueue data asynchronously, and hide I/O latency.
            if self.prefetch:
                self.thread_pool = tf.train.Coordinator()
                tf.train.start_queue_runners(coord=self.thread_pool)
                threads = self.start_thread(session)

            # TENSOR BOARD SUMMARY
            self.train_summary_writter = tf.train.SummaryWriter(os.path.join(self.temp_dir, 'train'), session.graph)
            for step in range(self.iterations):

                # One training step; its duration goes to the train summary
                start = time.time()
                self.fit(session, step)
                end = time.time()
                summary = summary_pb2.Summary.Value(tag="elapsed_time", simple_value=float(end-start))
                self.train_summary_writter.add_summary(summary_pb2.Summary(value=[summary]), step)

                # Running validation
                if validation_data_shuffler is not None and step % self.validation_snapshot == 0:
                    self.compute_validation(session, validation_data_shuffler, step)

                    if self.analizer is not None:
                        self.validation_summary_writter.add_summary(self.analizer(
                             validation_data_shuffler, self.architecture, session), step)

                # Taking snapshot
                if step % self.snapshot == 0:
                    logger.info("Taking snapshot")
                    path = os.path.join(self.temp_dir, 'model_snapshot{0}.ckp'.format(step))
                    self.architecture.save(session, saver, path)

            logger.info("Training finally finished")

            self.train_summary_writter.close()
            # The validation writer is created by the first validation step, so it
            # is still None when no step ran (e.g. `iterations=0`).
            if validation_data_shuffler is not None and self.validation_summary_writter is not None:
                self.validation_summary_writter.close()

            # Saving the final network
            path = os.path.join(self.temp_dir, 'model.ckp')
            self.architecture.save(session, saver, path)

            if self.prefetch:
                # now they should definetely stop
                # NOTE(review): the loader threads may still be inside
                # `session.run(self.enqueue_op)` at this point -- confirm that
                # `join` returns promptly in that case.
                self.thread_pool.request_stop()
                self.thread_pool.join(threads)