#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Tue 09 Aug 2016 15:25:22 CEST

import tensorflow as tf
from ..network import SequenceNetwork
import threading
import os
import bob.io.base
import bob.core
from ..analyzers import SoftmaxAnalizer
from tensorflow.core.framework import summary_pb2
import time
from bob.learn.tensorflow.datashuffler.OnlineSampling import OnLineSampling
from bob.learn.tensorflow.utils.session import Session
from .learning_rate import constant

# Package-wide logger; its verbosity is set per trainer in `Trainer.__init__`
logger = bob.core.log.setup("bob.learn.tensorflow")


class Trainer(object):
    """
    One graph trainer.
    Use this trainer when your CNN is composed by one graph

    **Parameters**

    architecture:
      The architecture that you want to run. Should be a :py:class:`bob.learn.tensorflow.network.SequenceNetwork`

    optimizer:
      One of the tensorflow optimizers https://www.tensorflow.org/versions/r0.10/api_docs/python/train.html
      If ``None`` (the default) a new ``tf.train.AdamOptimizer()`` is created for this trainer.
      Its ``_learning_rate`` attribute is overwritten with ``learning_rate`` when training starts.

    use_gpu: bool
      Use GPUs in the training (stored, but not read by this class)

    loss: :py:class:`bob.learn.tensorflow.loss.BaseLoss`
      Loss function, called as ``loss(network_output, labels)``

    temp_dir: str
      The output directory (checkpoints and tensorboard summaries)

    learning_rate: :py:class:`bob.learn.tensorflow.trainers.learningrate`
      Initial learning rate. If ``None`` (the default), ``constant()`` is used.

    convergence_threshold:
      Stored, but not read by this class

    iterations: int
      Maximum number of iterations

    snapshot: int
      Will take a snapshot of the network at every `n` iterations

    validation_snapshot: int
      Computes the validation loss (and runs the analizer) at every `n` iterations

    prefetch: bool
      Use extra Threads to deal with the I/O

    analizer:
      Neural network analizer :py:mod:`bob.learn.tensorflow.analyzers`.
      Pass ``None`` to disable it.

    model_from_file: str
      If you want to use a pretrained model

    verbosity_level: int
      Verbosity passed to ``bob.core.log.set_verbosity_level``

    """

    def __init__(self,
                 architecture,
                 optimizer=None,
                 use_gpu=False,
                 loss=None,
                 temp_dir="cnn",

                 # Learning rate
                 learning_rate=None,

                 ###### training options ##########
                 convergence_threshold=0.01,
                 iterations=5000,
                 snapshot=500,
                 validation_snapshot=100,
                 prefetch=False,

                 ## Analizer
                 # NOTE: this default instance is created once, at class definition,
                 # and is shared by every trainer that does not pass its own analizer.
                 analizer=SoftmaxAnalizer(),

                 ### Pretrained model
                 model_from_file="",

                 verbosity_level=2):

        if not isinstance(architecture, SequenceNetwork):
            raise ValueError("`architecture` should be instance of `SequenceNetwork`")

        # Per-instance defaults. A default written in the signature is evaluated once
        # at import time: the optimizer would be shared (and mutated in `train`) by
        # every trainer, and `constant()` (presumably a graph op) would belong to the
        # import-time default graph, which `__del__` resets.
        if optimizer is None:
            optimizer = tf.train.AdamOptimizer()
        if learning_rate is None:
            learning_rate = constant()

        self.architecture = architecture
        self.optimizer_class = optimizer
        self.use_gpu = use_gpu
        self.loss = loss
        self.temp_dir = temp_dir

        self.learning_rate = learning_rate

        self.iterations = iterations
        self.snapshot = snapshot
        self.validation_snapshot = validation_snapshot
        self.convergence_threshold = convergence_threshold
        self.prefetch = prefetch

        # Training variables used in the fit
        self.optimizer = None
        self.training_graph = None
        self.train_data_shuffler = None
        self.summaries_train = None
        self.train_summary_writter = None

        # Validation data
        self.validation_graph = None
        self.validation_summary_writter = None

        # Analizer
        self.analizer = analizer

        # Prefetch machinery (coordinator and queue feeding op), set up in `train`
        self.thread_pool = None
        self.enqueue_op = None

        self.global_step = None

        self.model_from_file = model_from_file
        self.session = None

        bob.core.log.set_verbosity_level(logger, verbosity_level)

    def __del__(self):
        tf.reset_default_graph()

    def compute_graph(self, data_shuffler, prefetch=False, name="", training=True):
        """
        Computes the graph for the trainer.

        ** Parameters **

            data_shuffler: Data shuffler
            prefetch: Uses prefetch
            name: Name of the graph
            training: Is it a training graph?

        ** Returns **

            The loss graph, i.e. ``self.loss`` applied to the network output and the labels
        """

        # Defining place holders
        if prefetch:
            [placeholder_data, placeholder_labels] = data_shuffler.get_placeholders_forprefetch(name=name)

            # Defining a placeholder queue for prefetching: `load_and_enqueue` threads
            # push batches in through `enqueue_op`, the graph pops them below
            queue = tf.FIFOQueue(capacity=10,
                                 dtypes=[tf.float32, tf.int64],
                                 shapes=[placeholder_data.get_shape().as_list()[1:], []])

            # Fetching the place holders from the queue
            self.enqueue_op = queue.enqueue_many([placeholder_data, placeholder_labels])
            feature_batch, label_batch = queue.dequeue_many(data_shuffler.batch_size)
        else:
            [feature_batch, label_batch] = data_shuffler.get_placeholders(name=name)

        # Creating graphs and defining the loss
        network_graph = self.architecture.compute_graph(feature_batch, training=training)
        graph = self.loss(network_graph, label_batch)

        return graph

    def get_feed_dict(self, data_shuffler):
        """
        Given a data shuffler prepared the dictionary to be injected in the graph

        ** Parameters **
            data_shuffler: Data shuffler providing one batch and its placeholders

        ** Returns **
            A ``{data_placeholder: data, label_placeholder: labels}`` feed dictionary
        """
        [data, labels] = data_shuffler.get_batch()
        [data_placeholder, label_placeholder] = data_shuffler.get_placeholders()

        feed_dict = {data_placeholder: data,
                     label_placeholder: labels}
        return feed_dict

    def fit(self, step):
        """
        Run one iteration (`forward` and `backward`) on ``self.session`` and log the
        training loss and summaries for this step.

        ** Parameters **
            step: Iteration number

        """

        if self.prefetch:
            # Batches come from the prefetch queue, so no feed dictionary is needed
            _, l, lr, summary = self.session.run([self.optimizer, self.training_graph,
                                                  self.learning_rate, self.summaries_train])
        else:
            feed_dict = self.get_feed_dict(self.train_data_shuffler)
            _, l, lr, summary = self.session.run([self.optimizer, self.training_graph,
                                                  self.learning_rate, self.summaries_train], feed_dict=feed_dict)

        logger.info("Loss training set step={0} = {1}".format(step, l))
        self.train_summary_writter.add_summary(summary, step)

    def compute_validation(self, data_shuffler, step):
        """
        Computes the loss in the validation set on ``self.session`` and writes it to
        the validation tensorboard summary (creating the writer on first use).

        ** Parameters **
            data_shuffler: The data shuffler to be used
            step: Iteration number

        """
        # Evaluating one validation batch in the training session
        feed_dict = self.get_feed_dict(data_shuffler)
        l = self.session.run(self.validation_graph, feed_dict=feed_dict)

        if self.validation_summary_writter is None:
            self.validation_summary_writter = tf.train.SummaryWriter(os.path.join(self.temp_dir, 'validation'), self.session.graph)

        summaries = [summary_pb2.Summary.Value(tag="loss", simple_value=float(l))]
        self.validation_summary_writter.add_summary(summary_pb2.Summary(value=summaries), step)
        logger.info("Loss VALIDATION set step={0} = {1}".format(step, l))

    def create_general_summary(self):
        """
        Creates a simple tensorboard summary with the value of the loss and learning rate
        """
        # Train summary
        tf.scalar_summary('loss', self.training_graph, name="train")
        tf.scalar_summary('lr', self.learning_rate, name="train")
        return tf.merge_all_summaries()

    def start_thread(self):
        """
        Start pool of threads for pre-fetching.

        ** Returns **
          The list of started (daemon) threads, each running ``load_and_enqueue``
        """

        threads = []
        for n in range(3):
            t = threading.Thread(target=self.load_and_enqueue, args=())
            t.daemon = True  # thread will close when parent quits
            t.start()
            threads.append(t)
        return threads

    def load_and_enqueue(self):
        """
        Injecting data in the place holder queue, until ``self.thread_pool``
        (a ``tf.train.Coordinator``) asks to stop.
        """

        while not self.thread_pool.should_stop():
            [train_data, train_labels] = self.train_data_shuffler.get_batch()
            [train_placeholder_data, train_placeholder_labels] = self.train_data_shuffler.get_placeholders()

            feed_dict = {train_placeholder_data: train_data,
                         train_placeholder_labels: train_labels}

            self.session.run(self.enqueue_op, feed_dict=feed_dict)

    def bootstrap_graphs(self, train_data_shuffler, validation_data_shuffler):
        """
        Create all the necessary graphs for training, validation and inference graphs,
        and register them in tensorflow collections so they can be reloaded from file.
        """

        # Creating train graph
        self.training_graph = self.compute_graph(train_data_shuffler, prefetch=self.prefetch, name="train")
        tf.add_to_collection("training_graph", self.training_graph)

        # Creating inference graph
        self.architecture.compute_inference_placeholder(train_data_shuffler.deployment_shape)
        self.architecture.compute_inference_graph()
        tf.add_to_collection("inference_placeholder", self.architecture.inference_placeholder)
        tf.add_to_collection("inference_graph", self.architecture.inference_graph)

        # Creating validation graph
        if validation_data_shuffler is not None:
            self.validation_graph = self.compute_graph(validation_data_shuffler, name="validation", training=False)
            tf.add_to_collection("validation_graph", self.validation_graph)

        self.bootstrap_placeholders(train_data_shuffler, validation_data_shuffler)

    def bootstrap_placeholders(self, train_data_shuffler, validation_data_shuffler):
        """
        Persist the placeholders

         ** Parameters **
           train_data_shuffler: Data shuffler for training
           validation_data_shuffler: Data shuffler for validation

        """

        # Persisting the placeholders
        if self.prefetch:
            batch, label = train_data_shuffler.get_placeholders_forprefetch()
        else:
            batch, label = train_data_shuffler.get_placeholders()

        tf.add_to_collection("train_placeholder_data", batch)
        tf.add_to_collection("train_placeholder_label", label)

        # Creating validation graph
        if validation_data_shuffler is not None:
            batch, label = validation_data_shuffler.get_placeholders()
            tf.add_to_collection("validation_placeholder_data", batch)
            tf.add_to_collection("validation_placeholder_label", label)

    def bootstrap_graphs_fromfile(self, train_data_shuffler, validation_data_shuffler):
        """
        Bootstrap all the necessary data from file

         ** Parameters **
           train_data_shuffler: Data shuffler for training
           validation_data_shuffler: Data shuffler for validation

         ** Returns **
           The saver returned by ``self.architecture.load``
        """
        saver = self.architecture.load(self.model_from_file, clear_devices=True)

        # Loading training graph
        self.training_graph = tf.get_collection("training_graph")[0]

        # Loding other elements
        self.optimizer = tf.get_collection("optimizer")[0]
        self.learning_rate = tf.get_collection("learning_rate")[0]
        self.summaries_train = tf.get_collection("summaries_train")[0]
        self.global_step = tf.get_collection("global_step")[0]

        if validation_data_shuffler is not None:
            self.validation_graph = tf.get_collection("validation_graph")[0]

        self.bootstrap_placeholders_fromfile(train_data_shuffler, validation_data_shuffler)

        return saver

    def bootstrap_placeholders_fromfile(self, train_data_shuffler, validation_data_shuffler):
        """
        Load placeholders from file

         ** Parameters **

           train_data_shuffler: Data shuffler for training
           validation_data_shuffler: Data shuffler for validation

        """

        train_data_shuffler.set_placeholders(tf.get_collection("train_placeholder_data")[0],
                                             tf.get_collection("train_placeholder_label")[0])

        # The validation placeholders belong to the validation shuffler; setting them on
        # the train shuffler would clobber the training placeholders set just above
        if validation_data_shuffler is not None:
            validation_data_shuffler.set_placeholders(tf.get_collection("validation_placeholder_data")[0],
                                                      tf.get_collection("validation_placeholder_label")[0])

    def train(self, train_data_shuffler, validation_data_shuffler=None):
        """
        Train the network:

         ** Parameters **

           train_data_shuffler: Data shuffler for training
           validation_data_shuffler: Data shuffler for validation
        """

        # Creating directory
        bob.io.base.create_directories_safe(self.temp_dir)
        self.train_data_shuffler = train_data_shuffler

        logger.info("Initializing !!")

        # Pickle the architecture to save
        self.architecture.pickle_net(train_data_shuffler.deployment_shape)

        Session.create()
        self.session = Session.instance().session

        # Loading a pretrained model
        if self.model_from_file != "":
            logger.info("Loading pretrained model from {0}".format(self.model_from_file))
            saver = self.bootstrap_graphs_fromfile(train_data_shuffler, validation_data_shuffler)

            # `session.run` reads the value whether the collection handed back a
            # Variable or a Tensor (`Tensor.eval`'s first positional arg is feed_dict)
            start_step = self.session.run(self.global_step)

        else:
            start_step = 0
            # Bootstraping all the graphs
            self.bootstrap_graphs(train_data_shuffler, validation_data_shuffler)

            # TODO: find an elegant way to provide this as a parameter of the trainer
            self.global_step = tf.Variable(0, trainable=False, name="global_step")
            # Persisted so that `bootstrap_graphs_fromfile` can resume from it
            tf.add_to_collection("global_step", self.global_step)

            # Preparing the optimizer
            self.optimizer_class._learning_rate = self.learning_rate
            self.optimizer = self.optimizer_class.minimize(self.training_graph, global_step=self.global_step)
            tf.add_to_collection("optimizer", self.optimizer)
            tf.add_to_collection("learning_rate", self.learning_rate)

            # Train summary
            self.summaries_train = self.create_general_summary()
            tf.add_to_collection("summaries_train", self.summaries_train)

            tf.initialize_all_variables().run(session=self.session)

            # Original tensorflow saver object
            saver = tf.train.Saver(var_list=tf.all_variables())

        if isinstance(train_data_shuffler, OnLineSampling):
            train_data_shuffler.set_feature_extractor(self.architecture, session=self.session)

        # Start a thread to enqueue data asynchronously, and hide I/O latency.
        if self.prefetch:
            self.thread_pool = tf.train.Coordinator()
            tf.train.start_queue_runners(sess=self.session, coord=self.thread_pool)
            threads = self.start_thread()

        # TENSOR BOARD SUMMARY
        self.train_summary_writter = tf.train.SummaryWriter(os.path.join(self.temp_dir, 'train'), self.session.graph)
        for step in range(start_step, self.iterations):

            start = time.time()
            self.fit(step)
            end = time.time()
            summary = summary_pb2.Summary.Value(tag="elapsed_time", simple_value=float(end-start))
            self.train_summary_writter.add_summary(summary_pb2.Summary(value=[summary]), step)

            # Running validation
            if validation_data_shuffler is not None and step % self.validation_snapshot == 0:
                self.compute_validation(validation_data_shuffler, step)

                if self.analizer is not None:
                    self.validation_summary_writter.add_summary(self.analizer(
                         validation_data_shuffler, self.architecture, self.session), step)

            # Taking snapshot
            if step % self.snapshot == 0:
                logger.info("Taking snapshot")
                path = os.path.join(self.temp_dir, 'model_snapshot{0}.ckp'.format(step))
                self.architecture.save(saver, path)

        logger.info("Training finally finished")

        self.train_summary_writter.close()
        # The validation writer is created lazily by `compute_validation`, so it does
        # not exist when no validation step ran (e.g. resuming past the last iteration)
        if self.validation_summary_writter is not None:
            self.validation_summary_writter.close()

        # Saving the final network
        path = os.path.join(self.temp_dir, 'model.ckp')
        self.architecture.save(saver, path)

        if self.prefetch:
            # now they should definetely stop
            # NOTE(review): `load_and_enqueue` only checks `should_stop` between
            # `enqueue_op` runs; a thread blocked on a full queue may keep `join`
            # waiting for the coordinator's grace period — confirm.
            self.thread_pool.request_stop()
            self.thread_pool.join(threads)