Trainer.py 11.8 KB
Newer Older
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
1
2
3
4
5
6
7
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Tue 09 Aug 2016 15:25:22 CEST

import tensorflow as tf
from ..network import SequenceNetwork
8
9
10
import threading
import os
import bob.io.base
11
import bob.core
12
from ..analyzers import SoftmaxAnalizer
13
from tensorflow.core.framework import summary_pb2
14
import time
15
16
from bob.learn.tensorflow.datashuffler.OnlineSampling import OnLineSampling

17
18
#os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,3,0"
os.environ["CUDA_VISIBLE_DEVICES"] = ""
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
19

20
logger = bob.core.log.setup("bob.learn.tensorflow")
21

Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
22

23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
class Trainer(object):
    """
    One graph trainer.
    Use this trainer when your CNN is composed by one graph

    **Parameters**
      architecture: The architecture that you want to run. Should be a :py:class`bob.learn.tensorflow.network.SequenceNetwork`
      optimizer: One of the tensorflow optimizers https://www.tensorflow.org/versions/r0.10/api_docs/python/train.html
      use_gpu: Use GPUs in the training
      loss: Loss
      temp_dir: The output directory

      base_learning_rate: Initial learning rate
      weight_decay: Decay rate fed to the exponential learning-rate schedule
      decay_steps: Number of steps between learning-rate decays
      convergence_threshold:

      iterations: Maximum number of iterations
      snapshot: Will take a snapshot of the network at every `n` iterations
      prefetch: Use extra Threads to deal with the I/O
      analizer: Neural network analizer :py:mod:`bob.learn.tensorflow.analyzers`
      verbosity_level:

    """

    def __init__(self,
                 architecture,
                 # NOTE(review): these defaults are single instances created at
                 # import time and shared by every Trainer — confirm intended.
                 optimizer=tf.train.AdamOptimizer(),
                 use_gpu=False,
                 loss=None,
                 temp_dir="cnn",

                 # Learning rate
                 base_learning_rate=0.1,
                 weight_decay=0.9,
                 decay_steps=1000,

                 ###### training options ##########
                 convergence_threshold=0.01,
                 iterations=5000,
                 snapshot=100,
                 prefetch=False,

                 ## Analizer
                 analizer=SoftmaxAnalizer(),

                 verbosity_level=2):

        # Fail early: every other method assumes a SequenceNetwork.
        if not isinstance(architecture, SequenceNetwork):
            raise ValueError("`architecture` should be instance of `SequenceNetwork`")

        self.architecture = architecture
        self.optimizer_class = optimizer
        self.use_gpu = use_gpu
        self.loss = loss
        self.temp_dir = temp_dir

        self.base_learning_rate = base_learning_rate
        self.weight_decay = weight_decay
        self.decay_steps = decay_steps

        self.iterations = iterations
        self.snapshot = snapshot
        self.convergence_threshold = convergence_threshold
        self.prefetch = prefetch

        # Training variables used in the fit
        # (the originals assigned `training_graph` and `thread_pool` twice;
        # the duplicates were removed)
        self.optimizer = None
        self.training_graph = None
        self.learning_rate = None
        self.train_data_shuffler = None
        self.summaries_train = None
        self.train_summary_writter = None
        self.thread_pool = None

        # Validation data
        self.validation_graph = None
        self.validation_summary_writter = None

        # Analizer
        self.analizer = analizer

        self.enqueue_op = None
        self.global_step = None

        bob.core.log.set_verbosity_level(logger, verbosity_level)

    def __del__(self):
        """Reset the default TensorFlow graph when this trainer is destroyed."""
        # Prevents graph nodes from accumulating across repeated Trainer
        # instances within one process.
        tf.reset_default_graph()

113
    def compute_graph(self, data_shuffler, prefetch=False, name=""):
        """
        Computes the graph for the trainer.

        ** Parameters **

            data_shuffler: Data shuffler
            prefetch: If True, feed the network from a FIFO queue filled by
                background threads (see `load_and_enqueue`); otherwise feed it
                directly from the shuffler's placeholders.
            name: Name of the graph

        Returns the loss tensor built on top of the architecture's output.
        """

        # Defining place holders
        if prefetch:
            [placeholder_data, placeholder_labels] = data_shuffler.get_placeholders_forprefetch(name=name)

            # Defining a placeholder queue for prefetching
            queue = tf.FIFOQueue(capacity=10,
                                 dtypes=[tf.float32, tf.int64],
                                 shapes=[placeholder_data.get_shape().as_list()[1:], []])

            # Fetching the place holders from the queue
            self.enqueue_op = queue.enqueue_many([placeholder_data, placeholder_labels])
            feature_batch, label_batch = queue.dequeue_many(data_shuffler.batch_size)

            # NOTE: the redundant isinstance(self.architecture, SequenceNetwork)
            # re-check that used to live here was removed; the constructor
            # already enforces it for every code path.
        else:
            [feature_batch, label_batch] = data_shuffler.get_placeholders(name=name)

        # Creating graphs and defining the loss
        network_graph = self.architecture.compute_graph(feature_batch)
        graph = self.loss(network_graph, label_batch)

        return graph

    def get_feed_dict(self, data_shuffler):
        """
        Given a data shuffler prepared the dictionary to be injected in the graph

        ** Parameters **
            data_shuffler: Shuffler providing one mini-batch and the matching
                placeholders

        """
        batch_data, batch_labels = data_shuffler.get_batch()
        placeholder_data, placeholder_labels = data_shuffler.get_placeholders()

        # Map each placeholder to the freshly drawn batch.
        return {placeholder_data: batch_data,
                placeholder_labels: batch_labels}

    def fit(self, session, step):
        """
        Run one iteration (`forward` and `backward`)

        ** Parameters **
            session: Tensorflow session
            step: Iteration number

        """

        if self.prefetch:
            # Data comes from the prefetch queue; no feed_dict is required.
            # (The unused `lr` fetch was removed — the learning rate is already
            # recorded through the 'lr' scalar summary.)
            _, l, summary = session.run([self.optimizer, self.training_graph,
                                         self.summaries_train])
        else:
            feed_dict = self.get_feed_dict(self.train_data_shuffler)
            _, l, summary = session.run([self.optimizer, self.training_graph,
                                         self.summaries_train], feed_dict=feed_dict)

        logger.info("Loss training set step={0} = {1}".format(step, l))
        self.train_summary_writter.add_summary(summary, step)

    def compute_validation(self, session, data_shuffler, step):
        """
        Computes the loss in the validation set

        ** Parameters **
            session: Tensorflow session
            data_shuffler: The data shuffler to be used
            step: Iteration number

        """
        # Build the validation graph only once; rebuilding it on every snapshot
        # kept adding duplicated nodes to the default graph without bound.
        if self.validation_graph is None:
            self.validation_graph = self.compute_graph(data_shuffler, name="validation")
        feed_dict = self.get_feed_dict(data_shuffler)
        l = session.run(self.validation_graph, feed_dict=feed_dict)

        # Lazily create the validation writer on first use.
        if self.validation_summary_writter is None:
            self.validation_summary_writter = tf.train.SummaryWriter(os.path.join(self.temp_dir, 'validation'), session.graph)

        summaries = []
        summaries.append(summary_pb2.Summary.Value(tag="loss", simple_value=float(l)))
        self.validation_summary_writter.add_summary(summary_pb2.Summary(value=summaries), step)
        logger.info("Loss VALIDATION set step={0} = {1}".format(step, l))

    def create_general_summary(self):
        """
        Creates a simple tensorboard summary with the value of the loss and learning rate
        """

        # Register one scalar summary per tracked tensor, then merge every
        # summary in the graph into a single op.
        for tag, tensor in (('loss', self.training_graph), ('lr', self.learning_rate)):
            tf.scalar_summary(tag, tensor, name="train")
        return tf.merge_all_summaries()

    def start_thread(self, session, n_threads=3):
        """
        Start pool of threads for pre-fetching

        **Parameters**
          session: Tensorflow session
          n_threads: Number of daemon feeder threads to spawn (the default of 3
            preserves the previous hard-coded behaviour)

        Returns the list of started threads so the caller can join them.
        """

        threads = []
        for n in range(n_threads):
            t = threading.Thread(target=self.load_and_enqueue, args=(session,))
            t.daemon = True  # thread will close when parent quits
            t.start()
            threads.append(t)
        return threads

    def load_and_enqueue(self, session):
        """
        Injecting data in the place holder queue

        Runs until the coordinator (`self.thread_pool`) requests a stop.

        **Parameters**
          session: Tensorflow session
        """

        # Hoisted out of the loop: the placeholders are fixed graph tensors
        # (assumes get_placeholders() returns the same objects on every call,
        # as get_feed_dict's usage implies — TODO confirm).
        [train_placeholder_data, train_placeholder_labels] = self.train_data_shuffler.get_placeholders()

        while not self.thread_pool.should_stop():
            [train_data, train_labels] = self.train_data_shuffler.get_batch()

            feed_dict = {train_placeholder_data: train_data,
                         train_placeholder_labels: train_labels}

            session.run(self.enqueue_op, feed_dict=feed_dict)

    def train(self, train_data_shuffler, validation_data_shuffler=None):
        """
        Train the network

        ** Parameters **
            train_data_shuffler: Shuffler that feeds the training mini-batches
            validation_data_shuffler: Optional shuffler; when given, the
                validation loss is computed every `snapshot` steps

        """

        # The output directory must exist before any writer/file is opened in it.
        bob.io.base.create_directories_safe(self.temp_dir)
        self.train_data_shuffler = train_data_shuffler

        # TODO: find an elegant way to provide this as a parameter of the trainer
        self.global_step = tf.Variable(0, trainable=False)
        self.learning_rate = tf.train.exponential_decay(
            learning_rate=self.base_learning_rate,  # Learning rate
            global_step=self.global_step,
            decay_steps=self.decay_steps,
            decay_rate=self.weight_decay,  # Decay step
            staircase=False
        )
        self.training_graph = self.compute_graph(train_data_shuffler, prefetch=self.prefetch, name="train")

        # Preparing the optimizer
        self.optimizer_class._learning_rate = self.learning_rate
        self.optimizer = self.optimizer_class.minimize(self.training_graph, global_step=self.global_step)

        # Train summary
        self.summaries_train = self.create_general_summary()

        logger.info("Initializing !!")
        # File that will receive the trained model at the end
        model_file = bob.io.base.HDF5File(os.path.join(self.temp_dir, 'model.hdf5'), 'w')

        session_config = tf.ConfigProto(log_device_placement=True)
        session_config.gpu_options.allow_growth = True
        with tf.Session(config=session_config) as session:

            tf.initialize_all_variables().run()

            # On-line samplers pick their samples with the current network.
            if isinstance(train_data_shuffler, OnLineSampling):
                train_data_shuffler.set_feature_extractor(self.architecture, session=session)

            # Start a thread to enqueue data asynchronously, and hide I/O latency.
            if self.prefetch:
                self.thread_pool = tf.train.Coordinator()
                tf.train.start_queue_runners(coord=self.thread_pool)
                prefetch_threads = self.start_thread(session)

            # TENSOR BOARD SUMMARY
            self.train_summary_writter = tf.train.SummaryWriter(os.path.join(self.temp_dir, 'train'), session.graph)

            for step in range(self.iterations):
                t0 = time.time()
                self.fit(session, step)
                t1 = time.time()

                elapsed = summary_pb2.Summary.Value(tag="elapsed_time", simple_value=float(t1 - t0))
                self.train_summary_writter.add_summary(summary_pb2.Summary(value=[elapsed]), step)

                if validation_data_shuffler is not None and step % self.snapshot == 0:
                    self.compute_validation(session, validation_data_shuffler, step)

                    if self.analizer is not None:
                        self.validation_summary_writter.add_summary(self.analizer(
                             validation_data_shuffler, self.architecture, session), step)

            logger.info("Training finally finished")

            self.train_summary_writter.close()
            if validation_data_shuffler is not None:
                self.validation_summary_writter.close()

            self.architecture.save(model_file)
            del model_file

            if self.prefetch:
                # now they should definetely stop
                self.thread_pool.request_stop()
                self.thread_pool.join(prefetch_threads)

            session.close() # For some reason the session is not closed after the context manager finishes