Trainer.py 16.7 KB
Newer Older
1 2 3 4 5 6 7
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Tue 09 Aug 2016 15:25:22 CEST

import tensorflow as tf
from ..network import SequenceNetwork
8 9 10
import threading
import os
import bob.io.base
11
import bob.core
12
from ..analyzers import SoftmaxAnalizer
13
from tensorflow.core.framework import summary_pb2
14
import time
15
from bob.learn.tensorflow.datashuffler import OnlineSampling
16
from bob.learn.tensorflow.utils.session import Session
17
from .learning_rate import constant
18

19
# Module-level logger used by the trainer; its verbosity is set in `Trainer.__init__`
logger = bob.core.log.setup("bob.learn.tensorflow")
20

21

22 23 24 25 26 27
class Trainer(object):
    """
    One graph trainer.
    Use this trainer when your CNN is composed of a single graph.

    **Parameters**
28

Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
29 30 31 32 33 34 35 36 37 38 39 40 41 42 43
    architecture:
      The architecture that you want to run. Should be a :py:class:`bob.learn.tensorflow.network.SequenceNetwork`

    optimizer:
      One of the tensorflow optimizers https://www.tensorflow.org/versions/r0.10/api_docs/python/train.html

    use_gpu: bool
      Use GPUs in the training

    loss: :py:class:`bob.learn.tensorflow.loss.BaseLoss`
      Loss function

    temp_dir: str
      The output directory

44
    learning_rate: `bob.learn.tensorflow.trainers.learning_rate`
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64
      Initial learning rate

    convergence_threshold:

    iterations: int
      Maximum number of iterations

    snapshot: int
      Will take a snapshot of the network at every `n` iterations

    prefetch: bool
      Use extra Threads to deal with the I/O

    model_from_file: str
      If you want to use a pretrained model

    analizer:
      Neural network analyzer; see :py:mod:`bob.learn.tensorflow.analyzers`

    verbosity_level:
65 66

    """
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
67

68
    def __init__(self,
                 architecture,
                 optimizer=None,
                 use_gpu=False,
                 loss=None,
                 temp_dir="cnn",

                 # Learning rate
                 learning_rate=None,

                 ###### training options ##########
                 convergence_threshold=0.01,
                 iterations=5000,
                 snapshot=500,
                 validation_snapshot=100,
                 prefetch=False,

                 ## Analizer
                 # NOTE(review): this default instance is created once, when the class is
                 # defined, and is shared by every trainer that does not pass its own.
                 # `None` cannot be used as the default because `train` treats
                 # `analizer=None` as "do not run any analizer".
                 analizer=SoftmaxAnalizer(),

                 ### Pretrained model
                 model_from_file="",

                 verbosity_level=2):
        """
        Validates the architecture and stores the training configuration.

        ** Parameters **

            architecture: Instance of `SequenceNetwork`
            optimizer: Tensorflow optimizer. If `None`, a new `tf.train.AdamOptimizer`
                       is created for this trainer
            use_gpu: Stored in `self.use_gpu`
            loss: Callable applied to the network output and the labels
            temp_dir: Output directory
            learning_rate: Learning rate. If `None` and no `model_from_file` is given,
                           `constant()` is used
            convergence_threshold: Stored in `self.convergence_threshold`
            iterations: Maximum number of iterations
            snapshot: A snapshot of the network is saved at every `snapshot` iterations
            validation_snapshot: The validation is computed at every `validation_snapshot` iterations
            prefetch: Use extra threads to feed the data through a queue
            analizer: Analizer called after each validation step (`None` disables it)
            model_from_file: Path of a pretrained model ("" to train from scratch)
            verbosity_level: Verbosity level set in the module logger

        ** Raises **

            ValueError: If `architecture` is not an instance of `SequenceNetwork`
        """

        if not isinstance(architecture, SequenceNetwork):
            raise ValueError("`architecture` should be instance of `SequenceNetwork`")

        self.architecture = architecture

        # The default optimizer is created per trainer: `train` overwrites its
        # `_learning_rate`, so an instance created in the signature would be
        # shared (and mutated) by every trainer relying on the default.
        if optimizer is None:
            optimizer = tf.train.AdamOptimizer()
        self.optimizer_class = optimizer

        self.use_gpu = use_gpu
        self.loss = loss
        self.temp_dir = temp_dir

        # When a pretrained model is given, the learning rate is left as provided
        # here; `bootstrap_graphs_fromfile` replaces it with the one stored in the model
        if learning_rate is None and model_from_file == "":
            self.learning_rate = constant()
        else:
            self.learning_rate = learning_rate

        self.iterations = iterations
        self.snapshot = snapshot
        self.validation_snapshot = validation_snapshot
        self.convergence_threshold = convergence_threshold
        self.prefetch = prefetch

        # Training variables used in the fit
        self.optimizer = None
        self.training_graph = None
        self.train_data_shuffler = None
        self.summaries_train = None
        self.train_summary_writter = None

        # Validation data
        self.validation_graph = None
        self.validation_summary_writter = None

        # Analizer
        self.analizer = analizer

        # Prefetching (coordinator of the threads and the enqueue operation)
        self.thread_pool = None
        self.enqueue_op = None

        self.global_step = None

        self.model_from_file = model_from_file
        self.session = None

        bob.core.log.set_verbosity_level(logger, verbosity_level)

Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
137 138 139
    def __del__(self):
        """
        Resets the tensorflow default graph when the trainer is destroyed.

        NOTE(review): `tf.reset_default_graph` acts on the global default graph,
        not only on what this trainer created - confirm that this is intended
        when several trainers live in the same process.
        """
        tf.reset_default_graph()

140
    def compute_graph(self, data_shuffler, prefetch=False, name="", training=True):
        """
        Builds the loss graph fed by one data shuffler.

        ** Parameters **

            data_shuffler: Data shuffler that provides the placeholders
            prefetch: If set, the batches are read from a FIFO queue instead of
                      directly from the placeholders; `self.enqueue_op` is set
                      as a side effect
            name: Name forwarded to the placeholders of the data shuffler
            training: Forwarded to `architecture.compute_graph`

        ** Returns **

            The output of `self.loss` applied to the network output and the labels

        ** Raises **

            ValueError: In prefetch mode, if `self.architecture` is not a `SequenceNetwork`
        """

        if not prefetch:
            # The placeholders are plugged directly into the network
            feature_batch, label_batch = data_shuffler.get_placeholders(name=name)
        else:
            data_input, label_input = data_shuffler.get_placeholders_forprefetch(name=name)

            # Queue sitting between the placeholders and the network; each queue
            # element has the shape of the data placeholder without its first dimension
            sample_shape = data_input.get_shape().as_list()[1:]
            queue = tf.FIFOQueue(capacity=10,
                                 dtypes=[tf.float32, tf.int64],
                                 shapes=[sample_shape, []])

            # The placeholders fill the queue and the network consumes from it
            self.enqueue_op = queue.enqueue_many([data_input, label_input])
            feature_batch, label_batch = queue.dequeue_many(data_shuffler.batch_size)

            if not isinstance(self.architecture, SequenceNetwork):
                raise ValueError("The variable `architecture` must be an instance of "
                                 "`bob.learn.tensorflow.network.SequenceNetwork`")

        # Network output followed by the loss
        network_output = self.architecture.compute_graph(feature_batch, training=training)
        return self.loss(network_output, label_batch)

    def get_feed_dict(self, data_shuffler):
        """
        Builds the dictionary that feeds one batch of `data_shuffler` to its placeholders

        ** Parameters **
            data_shuffler: Data shuffler providing both the batch and the placeholders

        ** Returns **
            Dictionary mapping the data and label placeholders to the batch values
        """
        batch_data, batch_labels = data_shuffler.get_batch()
        data_input, label_input = data_shuffler.get_placeholders()

        return {data_input: batch_data, label_input: batch_labels}

193
    def fit(self, step):
        """
        Runs one training iteration (`forward` and `backward`), logs the loss
        and writes the training summary

        ** Parameters **
            step: Iteration number

        """

        fetches = [self.optimizer, self.training_graph,
                   self.learning_rate, self.summaries_train]

        if not self.prefetch:
            # Without prefetching the batch is fed through the placeholders
            batch_feed = self.get_feed_dict(self.train_data_shuffler)
            _, loss_value, _, summary = self.session.run(fetches, feed_dict=batch_feed)
        else:
            # With prefetching nothing is fed here
            _, loss_value, _, summary = self.session.run(fetches)

        logger.info("Loss training set step={0} = {1}".format(step, loss_value))
        self.train_summary_writter.add_summary(summary, step)
213

214
    def compute_validation(self, data_shuffler, step):
        """
        Computes the loss for one batch of the validation set, writes it in the
        validation summary and logs it

        ** Parameters **
            data_shuffler: The data shuffler to be used
            step: Iteration number

        """
        loss_value = self.session.run(self.validation_graph,
                                      feed_dict=self.get_feed_dict(data_shuffler))

        # The validation writer is created only when the first validation is computed
        if self.validation_summary_writter is None:
            validation_dir = os.path.join(self.temp_dir, 'validation')
            self.validation_summary_writter = tf.train.SummaryWriter(validation_dir, self.session.graph)

        loss_entry = summary_pb2.Summary.Value(tag="loss", simple_value=float(loss_value))
        self.validation_summary_writter.add_summary(summary_pb2.Summary(value=[loss_entry]), step)
        logger.info("Loss VALIDATION set step={0} = {1}".format(step, loss_value))

235 236 237 238
    def create_general_summary(self):
        """
        Creates the scalar summaries of the training loss and of the learning rate

        ** Returns **
            The output of `tf.merge_all_summaries`
        """
        tracked_scalars = (('loss', self.training_graph),
                           ('lr', self.learning_rate))

        for tag, tensor in tracked_scalars:
            tf.scalar_summary(tag, tensor, name="train")

        return tf.merge_all_summaries()

244
    def start_thread(self, num_threads=3):
        """
        Starts the pool of threads used for pre-fetching

        Each thread runs `load_and_enqueue`.

        **Parameters**
          num_threads: Number of threads to start (3 by default)

        **Returns**
          List with the started `threading.Thread` objects
        """

        threads = []
        for _ in range(num_threads):
            worker = threading.Thread(target=self.load_and_enqueue, args=())
            worker.daemon = True  # thread will close when parent quits
            worker.start()
            threads.append(worker)
        return threads
259

260
    def load_and_enqueue(self):
        """
        Keeps injecting training batches in the placeholder queue until the
        coordinator in `self.thread_pool` asks to stop
        """

        while True:
            if self.thread_pool.should_stop():
                break

            shuffler = self.train_data_shuffler
            batch_data, batch_labels = shuffler.get_batch()
            data_input, label_input = shuffler.get_placeholders()

            self.session.run(self.enqueue_op,
                             feed_dict={data_input: batch_data,
                                        label_input: batch_labels})
276

Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
277
    def bootstrap_graphs(self, train_data_shuffler, validation_data_shuffler):
        """
        Creates the training, inference and (optionally) validation graphs and
        registers each of them in a tensorflow collection

        ** Parameters **
           train_data_shuffler: Data shuffler for training
           validation_data_shuffler: Data shuffler for validation (`None` to skip it)
        """

        # Training
        training_graph = self.compute_graph(train_data_shuffler, prefetch=self.prefetch, name="train")
        self.training_graph = training_graph
        tf.add_to_collection("training_graph", training_graph)

        # Inference
        network = self.architecture
        network.compute_inference_placeholder(train_data_shuffler.deployment_shape)
        network.compute_inference_graph()
        tf.add_to_collection("inference_placeholder", network.inference_placeholder)
        tf.add_to_collection("inference_graph", network.inference_graph)

        # Validation (only when a validation data shuffler is available)
        if validation_data_shuffler is not None:
            validation_graph = self.compute_graph(validation_data_shuffler, name="validation", training=False)
            self.validation_graph = validation_graph
            tf.add_to_collection("validation_graph", validation_graph)

        self.bootstrap_placeholders(train_data_shuffler, validation_data_shuffler)

    def bootstrap_placeholders(self, train_data_shuffler, validation_data_shuffler):
        """
        Registers the placeholders of the data shufflers in tensorflow collections

         ** Parameters **
           train_data_shuffler: Data shuffler for training
           validation_data_shuffler: Data shuffler for validation (`None` to skip it)

        """

        # The training placeholders depend on the prefetch mode
        get_train_placeholders = (train_data_shuffler.get_placeholders_forprefetch
                                  if self.prefetch
                                  else train_data_shuffler.get_placeholders)
        train_data, train_label = get_train_placeholders()

        tf.add_to_collection("train_placeholder_data", train_data)
        tf.add_to_collection("train_placeholder_label", train_label)

        if validation_data_shuffler is None:
            return

        validation_data, validation_label = validation_data_shuffler.get_placeholders()
        tf.add_to_collection("validation_placeholder_data", validation_data)
        tf.add_to_collection("validation_placeholder_label", validation_label)

324
    def bootstrap_graphs_fromfile(self, train_data_shuffler, validation_data_shuffler):
        """
        Loads the model in `self.model_from_file` and recovers the graphs, the
        optimizer, the learning rate, the summaries and the global step from
        the tensorflow collections

         ** Parameters **
           train_data_shuffler: Data shuffler for training
           validation_data_shuffler: Data shuffler for validation (`None` to skip it)

         ** Returns **
           The object returned by `architecture.load`

        """
        saver = self.architecture.load(self.model_from_file, clear_devices=False)

        def first_in_collection(key):
            # Each collection is expected to hold the element in its first position
            return tf.get_collection(key)[0]

        # Training graph
        self.training_graph = first_in_collection("training_graph")

        # Remaining elements needed to resume the training
        self.optimizer = first_in_collection("optimizer")
        self.learning_rate = first_in_collection("learning_rate")
        self.summaries_train = first_in_collection("summaries_train")
        self.global_step = first_in_collection("global_step")

        if validation_data_shuffler is not None:
            self.validation_graph = first_in_collection("validation_graph")

        self.bootstrap_placeholders_fromfile(train_data_shuffler, validation_data_shuffler)

        return saver

    def bootstrap_placeholders_fromfile(self, train_data_shuffler, validation_data_shuffler):
        """
        Sets in the data shufflers the placeholders stored in the tensorflow collections

         ** Parameters **

           train_data_shuffler: Data shuffler for training
           validation_data_shuffler: Data shuffler for validation (`None` to skip it)

        """

        train_data_shuffler.set_placeholders(tf.get_collection("train_placeholder_data")[0],
                                             tf.get_collection("train_placeholder_label")[0])

        # The validation placeholders belong to the validation data shuffler;
        # setting them in the training one would overwrite the placeholders set above
        if validation_data_shuffler is not None:
            validation_data_shuffler.set_placeholders(tf.get_collection("validation_placeholder_data")[0],
                                                      tf.get_collection("validation_placeholder_label")[0])

371 372
    def train(self, train_data_shuffler, validation_data_shuffler=None):
        """
        Train the network:

        Creates the output directory and a new session, builds the graphs (or
        loads them from `self.model_from_file`), runs `fit` until
        `self.iterations` is reached and saves the final model in
        `<temp_dir>/model.ckp`. Snapshots are saved at every `self.snapshot`
        iterations and the validation is computed at every
        `self.validation_snapshot` iterations.

         ** Parameters **

           train_data_shuffler: Data shuffler for training
           validation_data_shuffler: Data shuffler for validation (`None` to skip the validation)
        """

        # Creating directory
        bob.io.base.create_directories_safe(self.temp_dir)
        self.train_data_shuffler = train_data_shuffler

        logger.info("Initializing !!")

        # Pickle the architecture to save
        self.architecture.pickle_net(train_data_shuffler.deployment_shape)

        self.session = Session.instance(new=True).session

        # Loading a pretrained model
        if self.model_from_file != "":
            logger.info("Loading pretrained model from {0}".format(self.model_from_file))
            saver = self.bootstrap_graphs_fromfile(train_data_shuffler, validation_data_shuffler)

            # The training resumes from the step stored in the model
            start_step = self.global_step.eval(session=self.session)

        else:
            start_step = 0
            # Bootstraping all the graphs
            self.bootstrap_graphs(train_data_shuffler, validation_data_shuffler)

            # TODO: find an elegant way to provide this as a parameter of the trainer
            self.global_step = tf.Variable(0, trainable=False, name="global_step")
            tf.add_to_collection("global_step", self.global_step)

            # Preparing the optimizer
            self.optimizer_class._learning_rate = self.learning_rate
            self.optimizer = self.optimizer_class.minimize(self.training_graph, global_step=self.global_step)
            tf.add_to_collection("optimizer", self.optimizer)
            tf.add_to_collection("learning_rate", self.learning_rate)

            # Train summary (added only once to its collection)
            self.summaries_train = self.create_general_summary()
            tf.add_to_collection("summaries_train", self.summaries_train)

            tf.initialize_all_variables().run(session=self.session)

            # Original tensorflow saver object
            saver = tf.train.Saver(var_list=tf.all_variables())

        if isinstance(train_data_shuffler, OnlineSampling):
            train_data_shuffler.set_feature_extractor(self.architecture, session=self.session)

        # Start a thread to enqueue data asynchronously, and hide I/O latency.
        if self.prefetch:
            self.thread_pool = tf.train.Coordinator()
            tf.train.start_queue_runners(coord=self.thread_pool, sess=self.session)
            threads = self.start_thread()

        # TENSOR BOARD SUMMARY
        self.train_summary_writter = tf.train.SummaryWriter(os.path.join(self.temp_dir, 'train'), self.session.graph)
        for step in range(start_step, self.iterations):

            # One iteration, with its elapsed time written in the training summary
            start = time.time()
            self.fit(step)
            end = time.time()
            summary = summary_pb2.Summary.Value(tag="elapsed_time", simple_value=float(end-start))
            self.train_summary_writter.add_summary(summary_pb2.Summary(value=[summary]), step)

            # Running validation
            if validation_data_shuffler is not None and step % self.validation_snapshot == 0:
                self.compute_validation(validation_data_shuffler, step)

                if self.analizer is not None:
                    self.validation_summary_writter.add_summary(self.analizer(
                         validation_data_shuffler, self.architecture, self.session), step)

            # Taking snapshot
            if step % self.snapshot == 0:
                logger.info("Taking snapshot")
                path = os.path.join(self.temp_dir, 'model_snapshot{0}.ckp'.format(step))
                self.architecture.save(saver, path)

        logger.info("Training finally finished")

        self.train_summary_writter.close()
        # The validation writer is created by the first call to `compute_validation`,
        # so it is still `None` when no validation step was executed
        if validation_data_shuffler is not None and self.validation_summary_writter is not None:
            self.validation_summary_writter.close()

        # Saving the final network
        path = os.path.join(self.temp_dir, 'model.ckp')
        self.architecture.save(saver, path)

        if self.prefetch:
            # now they should definetely stop
            self.thread_pool.request_stop()
            self.thread_pool.join(threads)