Trainer.py 16.7 KB
Newer Older
1 2 3 4 5 6 7
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Tue 09 Aug 2016 15:25:22 CEST

import tensorflow as tf
from ..network import SequenceNetwork
8 9 10
import threading
import os
import bob.io.base
11
import bob.core
12
from ..analyzers import SoftmaxAnalizer
13
from tensorflow.core.framework import summary_pb2
14
import time
15
from bob.learn.tensorflow.datashuffler import OnlineSampling
16
from bob.learn.tensorflow.utils.session import Session
17
from .learning_rate import constant
18

19 20 21 22 23
#logger = bob.core.log.setup("bob.learn.tensorflow")

import logging
logger = logging.getLogger("bob.learn")

24

25 26 27 28 29 30
class Trainer(object):
    """
    One graph trainer.
    Use this trainer when your CNN is composed by one graph

    **Parameters**
31

Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
32 33 34 35 36 37 38 39 40 41 42 43 44 45 46
    architecture:
      The architecture that you want to run. Should be a :py:class`bob.learn.tensorflow.network.SequenceNetwork`

    optimizer:
      One of the tensorflow optimizers https://www.tensorflow.org/versions/r0.10/api_docs/python/train.html

    use_gpu: bool
      Use GPUs in the training

    loss: :py:class:`bob.learn.tensorflow.loss.BaseLoss`
      Loss function

    temp_dir: str
      The output directory

47
    learning_rate: `bob.learn.tensorflow.trainers.learning_rate`
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67
      Initial learning rate

    convergence_threshold:

    iterations: int
      Maximum number of iterations

    snapshot: int
      Will take a snapshot of the network at every `n` iterations

    prefetch: bool
      Use extra Threads to deal with the I/O

    model_from_file: str
      If you want to use a pretrained model

    analizer:
      Neural network analizer :py:mod:`bob.learn.tensorflow.analyzers`

    verbosity_level:
68 69

    """
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
70

71
    def __init__(self,
                 architecture,
                 optimizer=None,
                 use_gpu=False,
                 loss=None,
                 temp_dir="cnn",

                 # Learning rate
                 learning_rate=None,

                 ###### training options ##########
                 convergence_threshold=0.01,
                 iterations=5000,
                 snapshot=500,
                 validation_snapshot=100,
                 prefetch=False,

                 ## Analizer
                 # NOTE(review): this default instance is created once, when the class is
                 # defined, and is shared by all trainers. It is kept as is because `None`
                 # is a meaningful value here (`train` skips the analyzer when it is None).
                 analizer=SoftmaxAnalizer(),

                 ### Pretrained model
                 model_from_file="",

                 verbosity_level=2):
        """
        Stores the training configuration and resets the training state.

        ** Parameters **
            architecture: The network to train; must be a `SequenceNetwork`
            optimizer: Tensorflow optimizer. If None, a new `tf.train.AdamOptimizer()`
                       is created for this trainer
            use_gpu: Use GPUs in the training
            loss: Loss function
            temp_dir: The output directory
            learning_rate: Learning rate. If None and no pretrained model is given,
                           `constant()` is used
            convergence_threshold: Stored as is
            iterations: Maximum number of iterations
            snapshot: A snapshot of the network is taken at every `snapshot` iterations
            validation_snapshot: The validation is computed at every `validation_snapshot` iterations
            prefetch: Use extra threads to deal with the I/O
            analizer: Neural network analyzer (None disables it)
            model_from_file: Path of a pretrained model ("" means no pretrained model)
            verbosity_level: Verbosity level set on the logger of this module

        ** Raises **
            ValueError: if `architecture` is not an instance of `SequenceNetwork`
        """

        if not isinstance(architecture, SequenceNetwork):
            raise ValueError("`architecture` should be instance of `SequenceNetwork`")

        # A default optimizer is built per trainer. A default created in the signature
        # would be one single object shared by all the trainers, and `train` changes
        # its `_learning_rate`.
        if optimizer is None:
            optimizer = tf.train.AdamOptimizer()

        self.architecture = architecture
        self.optimizer_class = optimizer
        self.use_gpu = use_gpu
        self.loss = loss
        self.temp_dir = temp_dir

        # With a pretrained model the learning rate is fetched from the loaded graph
        # (see `bootstrap_graphs_fromfile`), so no default is created in that case
        if learning_rate is None and model_from_file == "":
            self.learning_rate = constant()
        else:
            self.learning_rate = learning_rate

        self.iterations = iterations
        self.snapshot = snapshot
        self.validation_snapshot = validation_snapshot
        self.convergence_threshold = convergence_threshold
        self.prefetch = prefetch

        # Training variables used in the fit
        self.optimizer = None
        self.training_graph = None
        self.train_data_shuffler = None
        self.summaries_train = None
        self.train_summary_writter = None

        # Validation data
        self.validation_graph = None
        self.validation_summary_writter = None

        # Analizer
        self.analizer = analizer

        # Prefetching (coordinator of the threads and the enqueue operation)
        self.thread_pool = None
        self.enqueue_op = None
        self.global_step = None

        self.model_from_file = model_from_file
        self.session = None

        bob.core.log.set_verbosity_level(logger, verbosity_level)

Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
140 141 142
    def __del__(self):
        """
        Resets the default tensorflow graph when the trainer is destroyed
        """
        tf.reset_default_graph()

143
    def compute_graph(self, data_shuffler, prefetch=False, name="", training=True):
        """
        Builds the loss graph of the trainer on top of the network.

        ** Parameters **

            data_shuffler: Data shuffler that provides the placeholders
            prefetch: If True, the network is fed from a queue (and `self.enqueue_op` is set)
                      instead of being fed directly from the placeholders
            name: Name given to the placeholders
            training: Is it a training graph?

        ** Returns **
            The output of `self.loss` applied to the network output and to the labels
        """

        if not prefetch:
            # The network is plugged directly in the placeholders
            features, labels = data_shuffler.get_placeholders(name=name)
        else:
            data_ph, labels_ph = data_shuffler.get_placeholders_forprefetch(name=name)

            # Queue that sits between the placeholders and the network.
            # The first dimension of the data placeholder is dropped from the element shape.
            element_shape = data_ph.get_shape().as_list()[1:]
            prefetch_queue = tf.FIFOQueue(capacity=10,
                                          dtypes=[tf.float32, tf.int64],
                                          shapes=[element_shape, []])

            # The placeholders fill the queue; the network reads batches from it
            self.enqueue_op = prefetch_queue.enqueue_many([data_ph, labels_ph])
            features, labels = prefetch_queue.dequeue_many(data_shuffler.batch_size)

            if not isinstance(self.architecture, SequenceNetwork):
                raise ValueError("The variable `architecture` must be an instance of "
                                 "`bob.learn.tensorflow.network.SequenceNetwork`")

        # Network first, loss on top of it
        network_output = self.architecture.compute_graph(features, training=training)
        return self.loss(network_output, labels)

    def get_feed_dict(self, data_shuffler):
        """
        Fetches one batch from the data shuffler and maps it to its placeholders

        ** Parameters **
            data_shuffler: Data shuffler that provides the batch and the placeholders

        ** Returns **
            Dictionary {placeholder: value} to be injected in the graph
        """
        data, labels = data_shuffler.get_batch()
        data_placeholder, label_placeholder = data_shuffler.get_placeholders()

        return {data_placeholder: data,
                label_placeholder: labels}

196
    def fit(self, step):
        """
        Run one iteration (`forward` and `backward`), log the loss and write the train summary

        ** Parameters **
            step: Iteration number

        """

        # With prefetch the data comes from the queue, so nothing is fed
        run_arguments = {}
        if not self.prefetch:
            run_arguments["feed_dict"] = self.get_feed_dict(self.train_data_shuffler)

        fetches = [self.optimizer, self.training_graph,
                   self.learning_rate, self.summaries_train]
        _, loss_value, _, summary = self.session.run(fetches, **run_arguments)

        logger.info("Loss training set step={0} = {1}".format(step, loss_value))
        self.train_summary_writter.add_summary(summary, step)
216

217
    def compute_validation(self, data_shuffler, step):
        """
        Computes the loss of one batch of the validation set, writes it in the
        validation summary and logs it

        ** Parameters **
            data_shuffler: The data shuffler to be used
            step: Iteration number

        """
        feed_dict = self.get_feed_dict(data_shuffler)
        validation_loss = self.session.run(self.validation_graph, feed_dict=feed_dict)

        # The writer is created only at the first validation
        if self.validation_summary_writter is None:
            validation_dir = os.path.join(self.temp_dir, 'validation')
            self.validation_summary_writter = tf.summary.FileWriter(validation_dir, self.session.graph)

        loss_value = summary_pb2.Summary.Value(tag="loss", simple_value=float(validation_loss))
        self.validation_summary_writter.add_summary(summary_pb2.Summary(value=[loss_value]), step)
        logger.info("Loss VALIDATION set step={0} = {1}".format(step, validation_loss))

238 239 240 241
    def create_general_summary(self):
        """
        Registers the scalar summaries `loss` and `lr` and returns the result of
        `tf.summary.merge_all()`
        """
        scalars = (('loss', self.training_graph),
                   ('lr', self.learning_rate))
        for tag, tensor in scalars:
            tf.summary.scalar(tag, tensor)

        return tf.summary.merge_all()
246

247
    def start_thread(self):
        """
        Start the pool of 3 threads for pre-fetching. Each one runs `load_and_enqueue`.

        ** Returns **
            List with the started threads
        """

        def launch_worker():
            worker = threading.Thread(target=self.load_and_enqueue, args=())
            worker.daemon = True  # thread will close when parent quits
            worker.start()
            return worker

        return [launch_worker() for _ in range(3)]
262

263
    def load_and_enqueue(self):
        """
        Injects batches of the training data shuffler in the placeholder queue
        until the coordinator (`self.thread_pool`) asks to stop
        """

        while True:
            if self.thread_pool.should_stop():
                break

            batch_data, batch_labels = self.train_data_shuffler.get_batch()
            data_placeholder, label_placeholder = self.train_data_shuffler.get_placeholders()

            self.session.run(self.enqueue_op,
                             feed_dict={data_placeholder: batch_data,
                                        label_placeholder: batch_labels})
279

Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
280
    def bootstrap_graphs(self, train_data_shuffler, validation_data_shuffler):
        """
        Creates the training, the inference and (if requested) the validation graphs,
        and registers each one of them in a tensorflow collection

        ** Parameters **
            train_data_shuffler: Data shuffler for training
            validation_data_shuffler: Data shuffler for validation (None for no validation graph)
        """
        # Train graph
        self.training_graph = self.compute_graph(train_data_shuffler, prefetch=self.prefetch, name="train")
        tf.add_to_collection("training_graph", self.training_graph)

        # Inference graph
        network = self.architecture
        network.compute_inference_placeholder(train_data_shuffler.deployment_shape)
        network.compute_inference_graph()
        tf.add_to_collection("inference_placeholder", network.inference_placeholder)
        tf.add_to_collection("inference_graph", network.inference_graph)

        # Validation graph
        if validation_data_shuffler is not None:
            self.validation_graph = self.compute_graph(validation_data_shuffler, name="validation", training=False)
            tf.add_to_collection("validation_graph", self.validation_graph)

        self.bootstrap_placeholders(train_data_shuffler, validation_data_shuffler)

    def bootstrap_placeholders(self, train_data_shuffler, validation_data_shuffler):
        """
        Registers the placeholders of the data shufflers in tensorflow collections

         ** Parameters **
           train_data_shuffler: Data shuffler for training
           validation_data_shuffler: Data shuffler for validation (None to skip it)

        """

        # Training placeholders (the prefetch ones when the queue is used)
        if self.prefetch:
            fetch_placeholders = train_data_shuffler.get_placeholders_forprefetch
        else:
            fetch_placeholders = train_data_shuffler.get_placeholders
        data_placeholder, label_placeholder = fetch_placeholders()

        tf.add_to_collection("train_placeholder_data", data_placeholder)
        tf.add_to_collection("train_placeholder_label", label_placeholder)

        if validation_data_shuffler is None:
            return

        # Validation placeholders
        data_placeholder, label_placeholder = validation_data_shuffler.get_placeholders()
        tf.add_to_collection("validation_placeholder_data", data_placeholder)
        tf.add_to_collection("validation_placeholder_label", label_placeholder)

326
    def bootstrap_graphs_fromfile(self, train_data_shuffler, validation_data_shuffler):
        """
        Loads the model in `self.model_from_file` and fetches the training elements
        (graphs, optimizer, learning rate, summaries, global step and placeholders)
        from the tensorflow collections

         ** Parameters **
           train_data_shuffler: Data shuffler for training
           validation_data_shuffler: Data shuffler for validation (None for no validation graph)

         ** Returns **
           The saver returned by `self.architecture.load`

        """

        def first_in_collection(key):
            # Each element is stored alone in its collection
            return tf.get_collection(key)[0]

        saver = self.architecture.load(self.model_from_file, clear_devices=False)

        # Training graph
        self.training_graph = first_in_collection("training_graph")

        # Other elements
        self.optimizer = first_in_collection("optimizer")
        self.learning_rate = first_in_collection("learning_rate")
        self.summaries_train = first_in_collection("summaries_train")
        self.global_step = first_in_collection("global_step")

        if validation_data_shuffler is not None:
            self.validation_graph = first_in_collection("validation_graph")

        self.bootstrap_placeholders_fromfile(train_data_shuffler, validation_data_shuffler)

        return saver

    def bootstrap_placeholders_fromfile(self, train_data_shuffler, validation_data_shuffler):
        """
        Load the placeholders from the tensorflow collections and set them in the data shufflers

         ** Parameters **

           train_data_shuffler: Data shuffler for training
           validation_data_shuffler: Data shuffler for validation (None to skip it)

        """

        train_data_shuffler.set_placeholders(tf.get_collection("train_placeholder_data")[0],
                                             tf.get_collection("train_placeholder_label")[0])

        if validation_data_shuffler is not None:
            # The validation placeholders belong to the validation data shuffler.
            # Setting them in the training one would overwrite the training placeholders
            # set just above.
            validation_data_shuffler.set_placeholders(tf.get_collection("validation_placeholder_data")[0],
                                                      tf.get_collection("validation_placeholder_label")[0])

373 374
    def train(self, train_data_shuffler, validation_data_shuffler=None):
        """
        Train the network:

        Creates (or loads from `self.model_from_file`) the graphs, runs `fit` for every
        step until `self.iterations`, computes the validation at every
        `self.validation_snapshot` steps, saves a snapshot at every `self.snapshot`
        steps and saves the final model in `self.temp_dir`.

         ** Parameters **

           train_data_shuffler: Data shuffler for training
           validation_data_shuffler: Data shuffler for validation (None for no validation)
        """

        # Creating directory
        bob.io.base.create_directories_safe(self.temp_dir)
        self.train_data_shuffler = train_data_shuffler

        logger.info("Initializing !!")

        # Pickle the architecture to save
        self.architecture.pickle_net(train_data_shuffler.deployment_shape)

        self.session = Session.instance(new=True).session

        # Loading a pretrained model
        if self.model_from_file != "":
            logger.info("Loading pretrained model from {0}".format(self.model_from_file))
            saver = self.bootstrap_graphs_fromfile(train_data_shuffler, validation_data_shuffler)

            # The training continues from the step stored in the loaded model
            start_step = self.global_step.eval(session=self.session)

        else:
            start_step = 0
            # Bootstraping all the graphs
            self.bootstrap_graphs(train_data_shuffler, validation_data_shuffler)

            # TODO: find an elegant way to provide this as a parameter of the trainer
            self.global_step = tf.Variable(0, trainable=False, name="global_step")
            tf.add_to_collection("global_step", self.global_step)

            # Preparing the optimizer.
            # The learning rate of the trainer replaces the one of the optimizer object
            # before the minimize operation is created.
            self.optimizer_class._learning_rate = self.learning_rate
            self.optimizer = self.optimizer_class.minimize(self.training_graph, global_step=self.global_step)
            tf.add_to_collection("optimizer", self.optimizer)
            tf.add_to_collection("learning_rate", self.learning_rate)

            # Train summary (added only once to its collection)
            self.summaries_train = self.create_general_summary()
            tf.add_to_collection("summaries_train", self.summaries_train)

            tf.global_variables_initializer().run(session=self.session)

            # Original tensorflow saver object
            saver = tf.train.Saver(var_list=tf.global_variables())

        if isinstance(train_data_shuffler, OnlineSampling):
            train_data_shuffler.set_feature_extractor(self.architecture, session=self.session)

        # Start a thread to enqueue data asynchronously, and hide I/O latency.
        if self.prefetch:
            self.thread_pool = tf.train.Coordinator()
            tf.train.start_queue_runners(coord=self.thread_pool, sess=self.session)
            threads = self.start_thread()

        # TENSOR BOARD SUMMARY
        self.train_summary_writter = tf.summary.FileWriter(os.path.join(self.temp_dir, 'train'), self.session.graph)
        for step in range(start_step, self.iterations):
            start = time.time()
            self.fit(step)
            end = time.time()
            summary = summary_pb2.Summary.Value(tag="elapsed_time", simple_value=float(end-start))
            self.train_summary_writter.add_summary(summary_pb2.Summary(value=[summary]), step)

            # Running validation
            if validation_data_shuffler is not None and step % self.validation_snapshot == 0:
                self.compute_validation(validation_data_shuffler, step)

                if self.analizer is not None:
                    self.validation_summary_writter.add_summary(self.analizer(
                         validation_data_shuffler, self.architecture, self.session), step)

            # Taking snapshot
            if step % self.snapshot == 0:
                logger.info("Taking snapshot")
                path = os.path.join(self.temp_dir, 'model_snapshot{0}.ckp'.format(step))
                self.architecture.save(saver, path)

        logger.info("Training finally finished")

        self.train_summary_writter.close()
        # The validation writer is created by `compute_validation` only. It is still
        # None when no validation step was executed (e.g. when the training is resumed
        # and no remaining step is a multiple of `validation_snapshot`).
        if self.validation_summary_writter is not None:
            self.validation_summary_writter.close()

        # Saving the final network
        path = os.path.join(self.temp_dir, 'model.ckp')
        self.architecture.save(saver, path)

        if self.prefetch:
            # now they should definetely stop
            self.thread_pool.request_stop()
            self.thread_pool.join(threads)