Trainer.py 12.7 KB
Newer Older
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
1
2
3
4
5
6
7
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Tue 09 Aug 2016 15:25:22 CEST

import tensorflow as tf
from ..network import SequenceNetwork
8
9
10
import threading
import os
import bob.io.base
11
import bob.core
12
from ..analyzers import SoftmaxAnalizer
13
from tensorflow.core.framework import summary_pb2
14
import time
15
from bob.learn.tensorflow.datashuffler.OnlineSampling import OnLineSampling
16
from .learning_rate import constant
17

Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
18
19
#os.environ["CUDA_VISIBLE_DEVICES"] = "1,3,0,2"
os.environ["CUDA_VISIBLE_DEVICES"] = ""
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
20

21
logger = bob.core.log.setup("bob.learn.tensorflow")
22

Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
23

24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
class Trainer(object):
    """
    One graph trainer.
    Use this trainer when your CNN is composed by one graph

    **Parameters**
      architecture: The architecture that you want to run. Should be a :py:class`bob.learn.tensorflow.network.SequenceNetwork`
      optimizer: One of the tensorflow optimizers https://www.tensorflow.org/versions/r0.10/api_docs/python/train.html
      use_gpu: Use GPUs in the training
      loss: Loss
      temp_dir: The output directory

      learning_rate: Learning rate schedule (default: :py:func:`bob.learn.tensorflow.trainers.learning_rate.constant`)
      convergence_threshold:

      iterations: Maximum number of iterations
      snapshot: Will take a snapshot of the network at every `n` iterations
      validation_snapshot: Will compute the validation loss at every `n` iterations
      prefetch: Use extra Threads to deal with the I/O
      analizer: Neural network analizer :py:mod:`bob.learn.tensorflow.analyzers`
      model_from_file: Path of an HDF5 file with pretrained variables to load before training
      verbosity_level:

    """

    # NOTE(review): `optimizer=tf.train.AdamOptimizer()` and
    # `analizer=SoftmaxAnalizer()` are evaluated once at class-definition time,
    # so every Trainer created with the defaults shares the SAME optimizer and
    # analizer instances. Kept as-is for interface compatibility — confirm this
    # sharing is intended before changing it.
    def __init__(self,
                 architecture,
                 optimizer=tf.train.AdamOptimizer(),
                 use_gpu=False,
                 loss=None,
                 temp_dir="cnn",

                 # Learning rate
                 learning_rate=constant(),

                 ###### training options ##########
                 convergence_threshold=0.01,
                 iterations=5000,
                 snapshot=500,
                 validation_snapshot=100,
                 prefetch=False,

                 ## Analizer
                 analizer=SoftmaxAnalizer(),

                 ### Pretrained model
                 model_from_file="",

                 verbosity_level=2):

        # Fail early: the rest of the trainer assumes the SequenceNetwork API
        # (compute_graph, save, load_variables_only, pickle_net, ...).
        if not isinstance(architecture, SequenceNetwork):
            raise ValueError("`architecture` should be instance of `SequenceNetwork`")

        self.architecture = architecture
        self.optimizer_class = optimizer
        self.use_gpu = use_gpu
        self.loss = loss
        self.temp_dir = temp_dir

        self.learning_rate = learning_rate

        self.iterations = iterations
        self.snapshot = snapshot
        self.validation_snapshot = validation_snapshot
        self.convergence_threshold = convergence_threshold
        self.prefetch = prefetch

        # Training variables populated later by `train` / `compute_graph`
        # (duplicate assignments of `training_graph` and `thread_pool` removed).
        self.optimizer = None
        self.training_graph = None
        self.train_data_shuffler = None
        self.summaries_train = None
        self.train_summary_writter = None
        self.thread_pool = None
        self.enqueue_op = None
        self.global_step = None

        # Validation data
        self.validation_graph = None
        self.validation_summary_writter = None

        # Analizer
        self.analizer = analizer

        self.model_from_file = model_from_file

        bob.core.log.set_verbosity_level(logger, verbosity_level)

Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
116
117
118
    def __del__(self):
        # Clear the default graph when the trainer is garbage collected so that
        # successive trainers in the same process do not accumulate stale ops.
        tf.reset_default_graph()

119
    def compute_graph(self, data_shuffler, prefetch=False, name=""):
        """
        Computes the graph for the trainer: the architecture forward pass
        followed by the loss.

        ** Parameters **

            data_shuffler: Data shuffler providing the input placeholders/batches
            prefetch: If True, feed the graph from a FIFO queue filled by
                      background threads instead of direct placeholders
            name: Name of the graph
        """

        # Validate up-front for BOTH paths. The original code only checked on
        # the prefetch branch, so the non-prefetch path could fail later with a
        # much less clear error.
        if not isinstance(self.architecture, SequenceNetwork):
            raise ValueError("The variable `architecture` must be an instance of "
                             "`bob.learn.tensorflow.network.SequenceNetwork`")

        # Defining place holders
        if prefetch:
            [placeholder_data, placeholder_labels] = data_shuffler.get_placeholders_forprefetch(name=name)

            # Defining a placeholder queue for prefetching; element shape drops
            # the batch dimension because enqueue_many splits along axis 0.
            queue = tf.FIFOQueue(capacity=10,
                                 dtypes=[tf.float32, tf.int64],
                                 shapes=[placeholder_data.get_shape().as_list()[1:], []])

            # Fetching the place holders from the queue
            self.enqueue_op = queue.enqueue_many([placeholder_data, placeholder_labels])
            feature_batch, label_batch = queue.dequeue_many(data_shuffler.batch_size)
        else:
            [feature_batch, label_batch] = data_shuffler.get_placeholders(name=name)

        # Creating graphs and defining the loss
        network_graph = self.architecture.compute_graph(feature_batch)
        graph = self.loss(network_graph, label_batch)

        return graph

    def get_feed_dict(self, data_shuffler):
        """
        Given a data shuffler prepared the dictionary to be injected in the graph

        ** Parameters **
            data_shuffler: shuffler providing one batch and the matching placeholders

        """
        batch_data, batch_labels = data_shuffler.get_batch()
        data_placeholder, label_placeholder = data_shuffler.get_placeholders()

        # Map each placeholder to the freshly drawn batch.
        return {data_placeholder: batch_data,
                label_placeholder: batch_labels}

172
173
174
175
176
177
178
179
180
181
    def fit(self, session, step):
        """
        Run one iteration (`forward` and `backward`)

        ** Parameters **
            session: Tensorflow session
            step: Iteration number

        """
        # Same tensors are fetched on both paths; only the data injection differs.
        fetches = [self.optimizer, self.training_graph,
                   self.learning_rate, self.summaries_train]

        if self.prefetch:
            # The FIFO queue feeds the graph; no feed dict is needed.
            _, l, lr, summary = session.run(fetches)
        else:
            _, l, lr, summary = session.run(
                fetches, feed_dict=self.get_feed_dict(self.train_data_shuffler))

        logger.info("Loss training set step={0} = {1}".format(step, l))
        self.train_summary_writter.add_summary(summary, step)
192

193
    def compute_validation(self, session, data_shuffler, step):
        """
        Computes the loss in the validation set

        ** Parameters **
            session: Tensorflow session
            data_shuffler: The data shuffler to be used
            step: Iteration number

        """
        # Build the validation graph only once. The previous code rebuilt it on
        # every call, appending duplicated nodes to the default graph at each
        # validation step (unbounded graph growth over a long training run).
        if self.validation_graph is None:
            self.validation_graph = self.compute_graph(data_shuffler, name="validation")

        feed_dict = self.get_feed_dict(data_shuffler)
        l = session.run(self.validation_graph, feed_dict=feed_dict)

        # Lazily create the validation summary writer on first use.
        if self.validation_summary_writter is None:
            self.validation_summary_writter = tf.train.SummaryWriter(os.path.join(self.temp_dir, 'validation'), session.graph)

        summaries = [summary_pb2.Summary.Value(tag="loss", simple_value=float(l))]
        self.validation_summary_writter.add_summary(summary_pb2.Summary(value=summaries), step)
        logger.info("Loss VALIDATION set step={0} = {1}".format(step, l))

216
217
218
219
220
    def create_general_summary(self):
        """
        Creates a simple tensorboard summary with the value of the loss and learning rate
        """

        # Train summary: scalar summaries for the loss tensor and the learning
        # rate, then merge everything registered in the default graph.
        # NOTE(review): merge_all_summaries() picks up ALL summaries in the
        # graph, not just the two above — confirm no other summaries exist.
        tf.scalar_summary('loss', self.training_graph, name="train")
        tf.scalar_summary('lr', self.learning_rate, name="train")
        return tf.merge_all_summaries()

226
    def start_thread(self, session, n_threads=3):
        """
        Start pool of threads for pre-fetching

        **Parameters**
          session: Tensorflow session
          n_threads: Number of background feeder threads to spawn
            (default 3, matching the previous hard-coded value)
        """

        threads = []
        for _ in range(n_threads):
            t = threading.Thread(target=self.load_and_enqueue, args=(session,))
            t.daemon = True  # thread will close when parent quits
            t.start()
            threads.append(t)
        return threads
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
241

242
243
    def load_and_enqueue(self, session):
        """
        Injecting data in the place holder queue

        **Parameters**
          session: Tensorflow session
        """

        # Keep feeding until the coordinator asks the pool to stop.
        while not self.thread_pool.should_stop():
            batch_data, batch_labels = self.train_data_shuffler.get_batch()
            data_placeholder, label_placeholder = self.train_data_shuffler.get_placeholders()

            session.run(self.enqueue_op,
                        feed_dict={data_placeholder: batch_data,
                                   label_placeholder: batch_labels})
258
259
260

    def train(self, train_data_shuffler, validation_data_shuffler=None):
        """
        Train the network.

        Builds the training graph, opens a session, optionally loads a
        pretrained model, runs `self.iterations` optimization steps (with
        periodic validation and snapshots) and saves the final model under
        `self.temp_dir`.

        **Parameters**
          train_data_shuffler: shuffler feeding the training batches
          validation_data_shuffler: optional shuffler; when given, validation
            loss is computed every `self.validation_snapshot` steps
        """

        # Creating directory
        bob.io.base.create_directories_safe(self.temp_dir)
        self.train_data_shuffler = train_data_shuffler

        # Pickle the architecture to save
        self.architecture.pickle_net(train_data_shuffler.deployment_shape)

        # TODO: find an elegant way to provide this as a parameter of the trainer
        self.global_step = tf.Variable(0, trainable=False)

        #self.learning_rate = tf.Variable(self.base_learning_rate)
        #self.learning_rate = tf.train.exponential_decay(
        #    learning_rate=self.base_learning_rate,  # Learning rate
        #    global_step=self.global_step,
        #    decay_steps=self.decay_steps,
        #    decay_rate=self.weight_decay,  # Decay step
        #    staircase=False
        #)
        self.training_graph = self.compute_graph(train_data_shuffler, prefetch=self.prefetch, name="train")

        # Preparing the optimizer.
        # NOTE(review): this pokes the optimizer's PRIVATE `_learning_rate`
        # attribute to inject the schedule — fragile across TF versions and
        # optimizer types; confirm it holds for every optimizer passed in.
        self.optimizer_class._learning_rate = self.learning_rate
        self.optimizer = self.optimizer_class.minimize(self.training_graph, global_step=self.global_step)

        # Train summary
        self.summaries_train = self.create_general_summary()

        logger.info("Initializing !!")

        # allow_growth avoids grabbing all GPU memory up front;
        # log_device_placement makes TF log op placement (verbose).
        config = tf.ConfigProto(log_device_placement=True)
        config.gpu_options.allow_growth = True

        with tf.Session(config=config) as session:
            tf.initialize_all_variables().run()

            # Loading a pretrained model (empty string means train from scratch)
            if self.model_from_file != "":
                logger.info("Loading pretrained model from {0}".format(self.model_from_file))
                hdf5 = bob.io.base.HDF5File(self.model_from_file)
                self.architecture.load_variables_only(hdf5, session)

            # Online-sampling shufflers need the live network to pick samples.
            if isinstance(train_data_shuffler, OnLineSampling):
                train_data_shuffler.set_feature_extractor(self.architecture, session=session)

            # Start a thread to enqueue data asynchronously, and hide I/O latency.
            if self.prefetch:
                self.thread_pool = tf.train.Coordinator()
                tf.train.start_queue_runners(coord=self.thread_pool)
                threads = self.start_thread(session)

            # TENSOR BOARD SUMMARY
            self.train_summary_writter = tf.train.SummaryWriter(os.path.join(self.temp_dir, 'train'), session.graph)
            for step in range(self.iterations):

                # One optimization step, timed for the elapsed_time summary.
                start = time.time()
                self.fit(session, step)
                end = time.time()
                summary = summary_pb2.Summary.Value(tag="elapsed_time", simple_value=float(end-start))
                self.train_summary_writter.add_summary(summary_pb2.Summary(value=[summary]), step)

                # Running validation (also fires at step 0)
                if validation_data_shuffler is not None and step % self.validation_snapshot == 0:
                    self.compute_validation(session, validation_data_shuffler, step)

                    # The analizer writes its own summaries to the validation
                    # writer created inside compute_validation above.
                    if self.analizer is not None:
                        self.validation_summary_writter.add_summary(self.analizer(
                             validation_data_shuffler, self.architecture, session), step)

                # Taking snapshot (also fires at step 0)
                if step % self.snapshot == 0:
                    logger.info("Taking snapshot")
                    hdf5 = bob.io.base.HDF5File(os.path.join(self.temp_dir, 'model_snapshot{0}.hdf5'.format(step)), 'w')
                    self.architecture.save(hdf5)
                    del hdf5

            logger.info("Training finally finished")

            self.train_summary_writter.close()
            if validation_data_shuffler is not None:
                self.validation_summary_writter.close()

            # Saving the final network
            hdf5 = bob.io.base.HDF5File(os.path.join(self.temp_dir, 'model.hdf5'), 'w')
            self.architecture.save(hdf5)
            del hdf5

            if self.prefetch:
                # now they should definitely stop
                self.thread_pool.request_stop()
                self.thread_pool.join(threads)