Trainer.py 16.8 KB
Newer Older
1 2 3 4 5 6 7
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Tue 09 Aug 2016 15:25:22 CEST

import tensorflow as tf
from ..network import SequenceNetwork
8 9 10
import threading
import os
import bob.io.base
11
import bob.core
12
from ..analyzers import SoftmaxAnalizer
13
from tensorflow.core.framework import summary_pb2
14
import time
15
from bob.learn.tensorflow.datashuffler import OnlineSampling
16
from bob.learn.tensorflow.utils.session import Session
17
from .learning_rate import constant
18

19 20 21 22 23
# Module-level logger used by `Trainer`. Its verbosity is adjusted in
# `Trainer.__init__` through `bob.core.log.set_verbosity_level`.
# The bob.core-based setup below is kept for reference but is disabled:
#logger = bob.core.log.setup("bob.learn.tensorflow")

import logging
logger = logging.getLogger("bob.learn")

24

25 26 27 28 29 30
class Trainer(object):
    """
    One graph trainer.
    Use this trainer when your CNN is composed by one graph

    **Parameters**

    architecture:
      The architecture that you want to run. Should be a
      :py:class:`bob.learn.tensorflow.network.SequenceNetwork`

    optimizer:
      One of the tensorflow optimizers
      https://www.tensorflow.org/versions/r0.10/api_docs/python/train.html
      When `None` (default) a new `tf.train.AdamOptimizer()` is created for
      this trainer.

    use_gpu: bool
      Use GPUs in the training (stored on the instance; not read by this class itself)

    loss: :py:class:`bob.learn.tensorflow.loss.BaseLoss`
      Loss function. Called as `loss(network_output, labels)`

    temp_dir: str
      The output directory (summaries and model snapshots are written there)

    learning_rate: `bob.learn.tensorflow.trainers.learning_rate`
      Initial learning rate. When `None` and no pretrained model is given,
      `constant()` is used

    convergence_threshold:
      Stored on the instance; not read by this class itself

    iterations: int
      Maximum number of iterations

    snapshot: int
      Will take a snapshot of the network at every `n` iterations

    validation_snapshot: int
      Will run the validation at every `n` iterations (only when a validation
      data shuffler is passed to :py:meth:`train`)

    prefetch: bool
      Use extra Threads to deal with the I/O

    analizer:
      Neural network analizer :py:mod:`bob.learn.tensorflow.analyzers`.
      Use `None` to disable it

    model_from_file: str
      If you want to use a pretrained model

    verbosity_level:
      Forwarded to `bob.core.log.set_verbosity_level` for the module logger

    """

    def __init__(self,
                 architecture,
                 optimizer=None,
                 use_gpu=False,
                 loss=None,
                 temp_dir="cnn",

                 # Learning rate
                 learning_rate=None,

                 ###### training options ##########
                 convergence_threshold=0.01,
                 iterations=5000,
                 snapshot=500,
                 validation_snapshot=100,
                 prefetch=False,

                 ## Analizer
                 # NOTE(review): this default is evaluated once, so every Trainer
                 # built without an explicit `analizer` shares the same object.
                 # `None` already means "no analizer", so it cannot be used as
                 # the sentinel here -- confirm the analizer is safe to share.
                 analizer=SoftmaxAnalizer(),

                 ### Pretrained model
                 model_from_file="",

                 verbosity_level=2):

        if not isinstance(architecture, SequenceNetwork):
            raise ValueError("`architecture` should be instance of `SequenceNetwork`")

        # One optimizer per trainer: `train` mutates it (`_learning_rate`,
        # `minimize`), so a default instance shared through the signature
        # would leak state between trainers.
        if optimizer is None:
            optimizer = tf.train.AdamOptimizer()

        self.architecture = architecture
        self.optimizer_class = optimizer
        self.use_gpu = use_gpu
        self.loss = loss
        self.temp_dir = temp_dir

        # When resuming from a file the learning rate is restored from the
        # saved graph in `bootstrap_graphs_fromfile`, so no default is forced.
        if learning_rate is None and model_from_file == "":
            self.learning_rate = constant()
        else:
            self.learning_rate = learning_rate

        self.iterations = iterations
        self.snapshot = snapshot
        self.validation_snapshot = validation_snapshot
        self.convergence_threshold = convergence_threshold
        self.prefetch = prefetch

        # Training variables used in the fit
        self.optimizer = None
        self.training_graph = None
        self.train_data_shuffler = None
        self.summaries_train = None
        self.train_summary_writter = None

        # Validation data
        self.validation_graph = None
        self.validation_summary_writter = None

        # Analizer
        self.analizer = analizer

        # Prefetching (coordinator of the loader threads and the enqueue op)
        self.thread_pool = None
        self.enqueue_op = None
        self.global_step = None

        self.model_from_file = model_from_file
        self.session = None

        bob.core.log.set_verbosity_level(logger, verbosity_level)

    def __del__(self):
        # Resets TensorFlow's default graph when the trainer is destroyed.
        tf.reset_default_graph()

    def compute_graph(self, data_shuffler, prefetch=False, name="", training=True):
        """
        Computes the graph for the trainer.

        ** Parameters **

            data_shuffler: Data shuffler
            prefetch: Uses prefetch (data is read through a FIFO queue and
                      `self.enqueue_op` is set as a side effect)
            name: Name of the graph
            training: Is it a training graph?

        ** Returns **

            The output of `self.loss` applied to the network output and the labels

        ** Raises **

            ValueError: if `prefetch` is set and `self.architecture` is not a `SequenceNetwork`
        """

        # Defining place holders
        if prefetch:
            [placeholder_data, placeholder_labels] = data_shuffler.get_placeholders_forprefetch(name=name)

            # Defining a placeholder queue for prefetching.
            # Each element has the shape of one sample (placeholder shape
            # without its first dimension) and a scalar label.
            queue = tf.FIFOQueue(capacity=10,
                                 dtypes=[tf.float32, tf.int64],
                                 shapes=[placeholder_data.get_shape().as_list()[1:], []])

            # Fetching the place holders from the queue
            self.enqueue_op = queue.enqueue_many([placeholder_data, placeholder_labels])
            feature_batch, label_batch = queue.dequeue_many(data_shuffler.batch_size)

            # Creating the architecture for train and validation
            if not isinstance(self.architecture, SequenceNetwork):
                raise ValueError("The variable `architecture` must be an instance of "
                                 "`bob.learn.tensorflow.network.SequenceNetwork`")
        else:
            [feature_batch, label_batch] = data_shuffler.get_placeholders(name=name)

        # Creating graphs and defining the loss
        network_graph = self.architecture.compute_graph(feature_batch, training=training)
        graph = self.loss(network_graph, label_batch)

        return graph

    def get_feed_dict(self, data_shuffler):
        """
        Given a data shuffler, prepares the dictionary to be injected in the graph

        ** Parameters **
            data_shuffler: Data shuffler providing `get_batch` and `get_placeholders`

        ** Returns **
            Dictionary mapping the data/label placeholders to a fresh batch
        """
        [data, labels] = data_shuffler.get_batch()
        [data_placeholder, label_placeholder] = data_shuffler.get_placeholders()

        feed_dict = {data_placeholder: data,
                     label_placeholder: labels}
        return feed_dict

    def fit(self, step):
        """
        Run one iteration (`forward` and `backward`)

        ** Parameters **
            step: Iteration number

        """

        fetches = [self.optimizer, self.training_graph,
                   self.learning_rate, self.summaries_train]

        if self.prefetch:
            # Data comes from the prefetch queue; nothing to feed
            _, loss_value, lr, summary = self.session.run(fetches)
        else:
            feed_dict = self.get_feed_dict(self.train_data_shuffler)
            _, loss_value, lr, summary = self.session.run(fetches, feed_dict=feed_dict)

        logger.info("Loss training set step={0} = {1}".format(step, loss_value))
        self.train_summary_writter.add_summary(summary, step)

    def compute_validation(self, data_shuffler, step):
        """
        Computes the loss in the validation set

        ** Parameters **
            data_shuffler: The data shuffler to be used
            step: Iteration number

        """
        # Validation runs in the same session used for training
        feed_dict = self.get_feed_dict(data_shuffler)
        loss_value = self.session.run(self.validation_graph, feed_dict=feed_dict)

        # The validation writer is created lazily, on the first validation step
        if self.validation_summary_writter is None:
            self.validation_summary_writter = tf.train.SummaryWriter(os.path.join(self.temp_dir, 'validation'), self.session.graph)

        summaries = [summary_pb2.Summary.Value(tag="loss", simple_value=float(loss_value))]
        self.validation_summary_writter.add_summary(summary_pb2.Summary(value=summaries), step)
        logger.info("Loss VALIDATION set step={0} = {1}".format(step, loss_value))

    def create_general_summary(self):
        """
        Creates a simple tensorboard summary with the value of the loss and learning rate

        ** Returns **
            The result of `tf.merge_all_summaries()`
        """
        # Train summary
        tf.scalar_summary('loss', self.training_graph, name="train")
        tf.scalar_summary('lr', self.learning_rate, name="train")
        return tf.merge_all_summaries()

    def start_thread(self):
        """
        Start pool of threads for pre-fetching

        ** Returns **
            List with the started (daemon) threads
        """

        threads = []
        for n in range(3):
            t = threading.Thread(target=self.load_and_enqueue, args=())
            t.daemon = True  # thread will close when parent quits
            t.start()
            threads.append(t)
        return threads

    def load_and_enqueue(self):
        """
        Injecting data in the place holder queue.

        Keeps feeding training batches to `self.enqueue_op` until the
        coordinator in `self.thread_pool` requests a stop.
        """

        while not self.thread_pool.should_stop():
            feed_dict = self.get_feed_dict(self.train_data_shuffler)
            self.session.run(self.enqueue_op, feed_dict=feed_dict)

    def bootstrap_graphs(self, train_data_shuffler, validation_data_shuffler):
        """
        Create all the necessary graphs for training, validation and inference graphs

         ** Parameters **
           train_data_shuffler: Data shuffler for training
           validation_data_shuffler: Data shuffler for validation (may be `None`)
        """

        # Creating train graph
        self.training_graph = self.compute_graph(train_data_shuffler, prefetch=self.prefetch, name="train")
        tf.add_to_collection("training_graph", self.training_graph)

        # Creating inference graph
        self.architecture.compute_inference_placeholder(train_data_shuffler.deployment_shape)
        self.architecture.compute_inference_graph()
        tf.add_to_collection("inference_placeholder", self.architecture.inference_placeholder)
        tf.add_to_collection("inference_graph", self.architecture.inference_graph)

        # Creating validation graph
        if validation_data_shuffler is not None:
            self.validation_graph = self.compute_graph(validation_data_shuffler, name="validation", training=False)
            tf.add_to_collection("validation_graph", self.validation_graph)

        self.bootstrap_placeholders(train_data_shuffler, validation_data_shuffler)

    def bootstrap_placeholders(self, train_data_shuffler, validation_data_shuffler):
        """
        Persist the placeholders

         ** Parameters **
           train_data_shuffler: Data shuffler for training
           validation_data_shuffler: Data shuffler for validation (may be `None`)

        """

        # Persisting the placeholders
        if self.prefetch:
            batch, label = train_data_shuffler.get_placeholders_forprefetch()
        else:
            batch, label = train_data_shuffler.get_placeholders()

        tf.add_to_collection("train_placeholder_data", batch)
        tf.add_to_collection("train_placeholder_label", label)

        # Persisting the validation placeholders
        if validation_data_shuffler is not None:
            batch, label = validation_data_shuffler.get_placeholders()
            tf.add_to_collection("validation_placeholder_data", batch)
            tf.add_to_collection("validation_placeholder_label", label)

    def bootstrap_graphs_fromfile(self, train_data_shuffler, validation_data_shuffler):
        """
        Bootstrap all the necessary data from file

         ** Parameters **
           train_data_shuffler: Data shuffler for training
           validation_data_shuffler: Data shuffler for validation (may be `None`)

         ** Returns **
           The saver returned by `self.architecture.load`
        """
        saver = self.architecture.load(self.model_from_file, clear_devices=False)

        # Loading training graph
        self.training_graph = tf.get_collection("training_graph")[0]

        # Loading other elements
        self.optimizer = tf.get_collection("optimizer")[0]
        self.learning_rate = tf.get_collection("learning_rate")[0]
        self.summaries_train = tf.get_collection("summaries_train")[0]
        self.global_step = tf.get_collection("global_step")[0]

        # NOTE(review): `self.enqueue_op` is not restored here, so resuming
        # with `prefetch=True` appears to leave it as None -- confirm.

        if validation_data_shuffler is not None:
            self.validation_graph = tf.get_collection("validation_graph")[0]

        self.bootstrap_placeholders_fromfile(train_data_shuffler, validation_data_shuffler)

        return saver

    def bootstrap_placeholders_fromfile(self, train_data_shuffler, validation_data_shuffler):
        """
        Load placeholders from file

         ** Parameters **

           train_data_shuffler: Data shuffler for training
           validation_data_shuffler: Data shuffler for validation (may be `None`)

        """

        train_data_shuffler.set_placeholders(tf.get_collection("train_placeholder_data")[0],
                                             tf.get_collection("train_placeholder_label")[0])

        # The validation placeholders belong to the validation shuffler;
        # assigning them to the training shuffler would overwrite the
        # training placeholders set just above.
        if validation_data_shuffler is not None:
            validation_data_shuffler.set_placeholders(tf.get_collection("validation_placeholder_data")[0],
                                                      tf.get_collection("validation_placeholder_label")[0])

    def train(self, train_data_shuffler, validation_data_shuffler=None):
        """
        Train the network:

        Builds the graphs (or loads them from `model_from_file`), then runs
        `fit` from the starting step up to `self.iterations`, running the
        validation and saving snapshots at the configured intervals. The
        final model is saved as `model.ckp` inside `temp_dir`.

         ** Parameters **

           train_data_shuffler: Data shuffler for training
           validation_data_shuffler: Data shuffler for validation
        """

        # Creating directory
        bob.io.base.create_directories_safe(self.temp_dir)
        self.train_data_shuffler = train_data_shuffler

        logger.info("Initializing !!")

        # Pickle the architecture to save
        self.architecture.pickle_net(train_data_shuffler.deployment_shape)

        self.session = Session.instance(new=True).session

        # Loading a pretrained model
        if self.model_from_file != "":
            logger.info("Loading pretrained model from {0}".format(self.model_from_file))
            saver = self.bootstrap_graphs_fromfile(train_data_shuffler, validation_data_shuffler)

            # Resume from the step stored with the model
            start_step = int(self.global_step.eval(session=self.session))

        else:
            start_step = 0
            # Bootstraping all the graphs
            self.bootstrap_graphs(train_data_shuffler, validation_data_shuffler)

            # TODO: find an elegant way to provide this as a parameter of the trainer
            self.global_step = tf.Variable(0, trainable=False, name="global_step")
            tf.add_to_collection("global_step", self.global_step)

            # Preparing the optimizer
            # NOTE(review): this overrides a private attribute; optimizers that
            # keep the rate under another attribute name (AdamOptimizer appears
            # to use `_lr`) would ignore it -- confirm.
            self.optimizer_class._learning_rate = self.learning_rate
            self.optimizer = self.optimizer_class.minimize(self.training_graph, global_step=self.global_step)
            tf.add_to_collection("optimizer", self.optimizer)
            tf.add_to_collection("learning_rate", self.learning_rate)

            # Train summary (registered once in the collection)
            self.summaries_train = self.create_general_summary()
            tf.add_to_collection("summaries_train", self.summaries_train)

            tf.initialize_all_variables().run(session=self.session)

            # Original tensorflow saver object
            saver = tf.train.Saver(var_list=tf.all_variables())

        if isinstance(train_data_shuffler, OnlineSampling):
            train_data_shuffler.set_feature_extractor(self.architecture, session=self.session)

        # Start a thread to enqueue data asynchronously, and hide I/O latency.
        threads = []
        if self.prefetch:
            self.thread_pool = tf.train.Coordinator()
            tf.train.start_queue_runners(coord=self.thread_pool, sess=self.session)
            threads = self.start_thread()

        try:
            # TENSOR BOARD SUMMARY
            self.train_summary_writter = tf.train.SummaryWriter(os.path.join(self.temp_dir, 'train'), self.session.graph)
            for step in range(start_step, self.iterations):

                start = time.time()
                self.fit(step)
                end = time.time()
                summary = summary_pb2.Summary.Value(tag="elapsed_time", simple_value=float(end-start))
                self.train_summary_writter.add_summary(summary_pb2.Summary(value=[summary]), step)

                # Running validation
                if validation_data_shuffler is not None and step % self.validation_snapshot == 0:
                    self.compute_validation(validation_data_shuffler, step)

                    if self.analizer is not None:
                        self.validation_summary_writter.add_summary(self.analizer(
                             validation_data_shuffler, self.architecture, self.session), step)

                # Taking snapshot
                if step % self.snapshot == 0:
                    logger.info("Taking snapshot")
                    path = os.path.join(self.temp_dir, 'model_snapshot{0}.ckp'.format(step))
                    self.architecture.save(saver, path)

            logger.info("Training finally finished")

            # Saving the final network
            path = os.path.join(self.temp_dir, 'model.ckp')
            self.architecture.save(saver, path)

        finally:
            # Release writers and loader threads even when training fails.
            if self.train_summary_writter is not None:
                self.train_summary_writter.close()

            # The validation writer only exists if a validation step ran
            if self.validation_summary_writter is not None:
                self.validation_summary_writter.close()

            if self.prefetch:
                # now they should definetely stop
                self.thread_pool.request_stop()
                self.thread_pool.join(threads)