Trainer.py 16.4 KB
Newer Older
1 2 3 4 5 6 7
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Tue 09 Aug 2016 15:25:22 CEST

import tensorflow as tf
from ..network import SequenceNetwork
8 9 10
import threading
import os
import bob.io.base
11
import bob.core
12
from ..analyzers import SoftmaxAnalizer
13
from tensorflow.core.framework import summary_pb2
14
import time
15
from bob.learn.tensorflow.datashuffler.OnlineSampling import OnLineSampling
16
from .learning_rate import constant
17

18
# Logger registered under the "bob.learn.tensorflow" name; its verbosity is set in Trainer.__init__
logger = bob.core.log.setup("bob.learn.tensorflow")
19

20

21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43
class Trainer(object):
    """
    One graph trainer.
    Use this trainer when your CNN is composed by one graph

    **Parameters**
      architecture: The architecture that you want to run. Should be a :py:class:`bob.learn.tensorflow.network.SequenceNetwork`
      optimizer: One of the tensorflow optimizers https://www.tensorflow.org/versions/r0.10/api_docs/python/train.html
      use_gpu: Use GPUs in the training
      loss: Loss
      temp_dir: The output directory

      learning_rate: Learning rate policy (defaults to `constant()`)
      convergence_threshold:

      iterations: Maximum number of iterations
      snapshot: Will take a snapshot of the network at every `n` iterations
      validation_snapshot: Will run the validation at every `n` iterations
      prefetch: Use extra Threads to deal with the I/O
      analizer: Neural network analizer :py:mod:`bob.learn.tensorflow.analyzers`
      model_from_file: Path of a pretrained model to load before training (empty string means none)
      verbosity_level:

    """
44
    def __init__(self,
                 architecture,
                 optimizer=None,
                 use_gpu=False,
                 loss=None,
                 temp_dir="cnn",

                 # Learning rate
                 learning_rate=None,

                 ###### training options ##########
                 convergence_threshold=0.01,
                 iterations=5000,
                 snapshot=500,
                 validation_snapshot=100,
                 prefetch=False,

                 ## Analizer
                 # NOTE(review): this default is still built once at definition time and
                 # shared between trainers; `None` cannot be used as a sentinel here
                 # because `train` treats `analizer=None` as "no analizer".
                 analizer=SoftmaxAnalizer(),

                 ### Pretrained model
                 model_from_file="",

                 verbosity_level=2):
        """
        Stores the training configuration; no graph is built here.

        **Parameters**
          architecture: Network to train; must be a `SequenceNetwork`
          optimizer: Tensorflow optimizer instance. When `None`, a fresh
                     `tf.train.AdamOptimizer()` is created for this trainer
          use_gpu: Stored as `self.use_gpu`
          loss: Callable applied as `loss(network_graph, label_batch)` in `compute_graph`
          temp_dir: Output directory for summaries and model snapshots
          learning_rate: Learning rate policy. When `None`, `constant()` is
                         called for this trainer
          convergence_threshold: Stored as `self.convergence_threshold`
          iterations: Number of training steps run by `train`
          snapshot: A model snapshot is saved every `snapshot` steps
          validation_snapshot: Validation is run every `validation_snapshot` steps
          prefetch: Feed the training data through a queue filled by extra threads
          analizer: Callable evaluated on the validation data, or `None` to disable it
          model_from_file: Path of a pretrained model; an empty string disables loading
          verbosity_level: Verbosity passed to `bob.core.log.set_verbosity_level`

        **Raises**
          ValueError: if `architecture` is not a `SequenceNetwork`
        """

        if not isinstance(architecture, SequenceNetwork):
            raise ValueError("`architecture` should be instance of `SequenceNetwork`")

        # Built per instance instead of once at function-definition time, so the
        # objects are neither created at import time nor shared between trainers
        # (and do not outlive the `tf.reset_default_graph()` done in `__del__`).
        if optimizer is None:
            optimizer = tf.train.AdamOptimizer()
        if learning_rate is None:
            learning_rate = constant()

        self.architecture = architecture
        self.optimizer_class = optimizer
        self.use_gpu = use_gpu
        self.loss = loss
        self.temp_dir = temp_dir

        self.learning_rate = learning_rate

        self.iterations = iterations
        self.snapshot = snapshot
        self.validation_snapshot = validation_snapshot
        self.convergence_threshold = convergence_threshold
        self.prefetch = prefetch

        # Training variables used in the fit
        self.optimizer = None
        self.training_graph = None
        self.train_data_shuffler = None
        self.summaries_train = None
        self.train_summary_writter = None

        # Validation data
        self.validation_graph = None
        self.validation_summary_writter = None

        # Analizer
        self.analizer = analizer

        # Prefetching machinery, filled in by `compute_graph` and `train`
        self.thread_pool = None
        self.enqueue_op = None
        self.global_step = None

        self.model_from_file = model_from_file

        bob.core.log.set_verbosity_level(logger, verbosity_level)

Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
109 110 111
    def __del__(self):
        """
        Resets the default tensorflow graph when the trainer is destroyed.
        """
        # NOTE(review): this acts on the process-wide default graph, so it presumably
        # also affects any other object still using that graph - confirm this is intended.
        tf.reset_default_graph()

112
    def compute_graph(self, data_shuffler, prefetch=False, name="", training=True):
        """
        Builds the loss graph of the trainer on top of the data shuffler inputs.

        ** Parameters **

            data_shuffler: Data shuffler that provides the placeholders
            prefetch: If True, the placeholders are routed through a FIFO queue and
                      `self.enqueue_op` is set to the operation that fills it
            name: Name of the graph
            training: Is it a training graph?

        ** Returns **

            The output of `self.loss` applied to the network output and the labels
        """

        if not prefetch:
            # The placeholders are consumed directly by the network
            feature_batch, label_batch = data_shuffler.get_placeholders(name=name)
        else:
            placeholder_data, placeholder_labels = data_shuffler.get_placeholders_forprefetch(name=name)

            # Queue that sits between the placeholders and the network;
            # each element has the shape of one sample (batch dimension dropped)
            sample_shape = placeholder_data.get_shape().as_list()[1:]
            queue = tf.FIFOQueue(capacity=10,
                                 dtypes=[tf.float32, tf.int64],
                                 shapes=[sample_shape, []])

            # The placeholders feed the queue and the network reads from it
            self.enqueue_op = queue.enqueue_many([placeholder_data, placeholder_labels])
            feature_batch, label_batch = queue.dequeue_many(data_shuffler.batch_size)

            if not isinstance(self.architecture, SequenceNetwork):
                raise ValueError("The variable `architecture` must be an instance of "
                                 "`bob.learn.tensorflow.network.SequenceNetwork`")

        # Network output followed by the loss
        network_graph = self.architecture.compute_graph(feature_batch, training=training)
        return self.loss(network_graph, label_batch)

    def get_feed_dict(self, data_shuffler):
        """
        Builds the dictionary that maps the placeholders of a data shuffler
        to a freshly drawn batch.

        ** Parameters **
            data_shuffler: Data shuffler providing `get_batch` and `get_placeholders`

        ** Returns **
            Dictionary `{data placeholder: data, label placeholder: labels}`

        """
        batch_data, batch_labels = data_shuffler.get_batch()
        data_key, label_key = data_shuffler.get_placeholders()

        return {data_key: batch_data, label_key: batch_labels}

165 166 167 168 169 170 171 172 173 174
    def fit(self, session, step):
        """
        Runs a single training iteration (`forward` and `backward`), logs the
        loss and writes the training summary.

        ** Parameters **
            session: Tensorflow session
            step: Iteration number

        """

        fetches = [self.optimizer, self.training_graph,
                   self.learning_rate, self.summaries_train]

        if self.prefetch:
            # Nothing is fed here; the graph was built on top of the prefetch queue
            outputs = session.run(fetches)
        else:
            outputs = session.run(fetches, feed_dict=self.get_feed_dict(self.train_data_shuffler))

        _, loss_value, _, summary = outputs

        logger.info("Loss training set step={0} = {1}".format(step, loss_value))
        self.train_summary_writter.add_summary(summary, step)
185

186
    def compute_validation(self,  session, data_shuffler, step):
        """
        Evaluates the validation graph on one batch, writes the loss to the
        validation summary and logs it.

        ** Parameters **
            session: Tensorflow session
            data_shuffler: The data shuffler to be used
            step: Iteration number

        """
        validation_loss = session.run(self.validation_graph,
                                      feed_dict=self.get_feed_dict(data_shuffler))

        # The writer is created on first use
        if self.validation_summary_writter is None:
            log_dir = os.path.join(self.temp_dir, 'validation')
            self.validation_summary_writter = tf.train.SummaryWriter(log_dir, session.graph)

        loss_entry = summary_pb2.Summary.Value(tag="loss", simple_value=float(validation_loss))
        self.validation_summary_writter.add_summary(summary_pb2.Summary(value=[loss_entry]), step)
        logger.info("Loss VALIDATION set step={0} = {1}".format(step, validation_loss))

207 208 209 210
    def create_general_summary(self):
        """
        Registers scalar tensorboard summaries for the training loss and the
        learning rate and returns the merged summary operation.
        """
        tracked_scalars = (('loss', self.training_graph),
                           ('lr', self.learning_rate))

        for tag, tensor in tracked_scalars:
            tf.scalar_summary(tag, tensor, name="train")

        return tf.merge_all_summaries()

216
    def start_thread(self, session):
217 218 219 220 221 222 223
        """
        Start pool of threads for pre-fetching

        **Parameters**
          session: Tensorflow session
        """

224
        threads = []
225 226
        for n in range(3):
            t = threading.Thread(target=self.load_and_enqueue, args=(session,))
227 228 229 230
            t.daemon = True  # thread will close when parent quits
            t.start()
            threads.append(t)
        return threads
231

232 233
    def load_and_enqueue(self, session):
        """
234
        Injecting data in the place holder queue
235 236 237

        **Parameters**
          session: Tensorflow session
238
        """
239

240
        while not self.thread_pool.should_stop():
241 242
            [train_data, train_labels] = self.train_data_shuffler.get_batch()
            [train_placeholder_data, train_placeholder_labels] = self.train_data_shuffler.get_placeholders()
243

244 245 246
            feed_dict = {train_placeholder_data: train_data,
                         train_placeholder_labels: train_labels}

247
            session.run(self.enqueue_op, feed_dict=feed_dict)
248

Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
249
    def bootstrap_graphs(self, train_data_shuffler, validation_data_shuffler):
        """
        Builds the training, inference and (optionally) validation graphs and
        registers each of them in a tensorflow collection.

         ** Parameters **
           train_data_shuffler: Data shuffler for training
           validation_data_shuffler: Data shuffler for validation, or None
        """

        # Training graph
        self.training_graph = self.compute_graph(train_data_shuffler, prefetch=self.prefetch, name="train")
        tf.add_to_collection("training_graph", self.training_graph)

        # Inference graph
        network = self.architecture
        network.compute_inference_placeholder(train_data_shuffler.deployment_shape)
        network.compute_inference_graph()
        for key, tensor in (("inference_placeholder", network.inference_placeholder),
                            ("inference_graph", network.inference_graph)):
            tf.add_to_collection(key, tensor)

        # Validation graph, only when validation data was provided
        if validation_data_shuffler is not None:
            self.validation_graph = self.compute_graph(validation_data_shuffler, name="validation", training=False)
            tf.add_to_collection("validation_graph", self.validation_graph)

        self.bootstrap_placeholders(train_data_shuffler, validation_data_shuffler)

    def bootstrap_placeholders(self, train_data_shuffler, validation_data_shuffler):
        """
        Registers the placeholders of the data shufflers in tensorflow collections

         ** Parameters **
           train_data_shuffler: Data shuffler for training
           validation_data_shuffler: Data shuffler for validation, or None

        """

        # The training placeholders come from a different getter when prefetching
        if self.prefetch:
            fetch_train_placeholders = train_data_shuffler.get_placeholders_forprefetch
        else:
            fetch_train_placeholders = train_data_shuffler.get_placeholders

        train_batch, train_label = fetch_train_placeholders()
        tf.add_to_collection("train_placeholder_data", train_batch)
        tf.add_to_collection("train_placeholder_label", train_label)

        if validation_data_shuffler is None:
            return

        validation_batch, validation_label = validation_data_shuffler.get_placeholders()
        tf.add_to_collection("validation_placeholder_data", validation_batch)
        tf.add_to_collection("validation_placeholder_label", validation_label)

    def bootstrap_graphs_fromfile(self, session, train_data_shuffler, validation_data_shuffler):
        """
        Loads the model stored in `self.model_from_file` and recovers the
        trainer elements from the tensorflow collections

         ** Parameters **
           session: Tensorflow session
           train_data_shuffler: Data shuffler for training
           validation_data_shuffler: Data shuffler for validation, or None

         ** Returns **
           The object returned by `self.architecture.load`

        """
        saver = self.architecture.load(session, self.model_from_file)

        def first_of(collection_name):
            # Each element is read from the first entry of its collection
            return tf.get_collection(collection_name)[0]

        # Training graph
        self.training_graph = first_of("training_graph")

        # Remaining training elements
        self.optimizer = first_of("optimizer")
        self.learning_rate = first_of("learning_rate")
        self.summaries_train = first_of("summaries_train")

        if validation_data_shuffler is not None:
            self.validation_graph = first_of("validation_graph")

        self.bootstrap_placeholders_fromfile(train_data_shuffler, validation_data_shuffler)

        return saver

    def bootstrap_placeholders_fromfile(self, train_data_shuffler, validation_data_shuffler):
        """
        Recovers the placeholders from the tensorflow collections and hands
        them to the corresponding data shufflers

         ** Parameters **

           train_data_shuffler: Data shuffler for training
           validation_data_shuffler: Data shuffler for validation, or None

        """

        train_data_shuffler.set_placeholders(tf.get_collection("train_placeholder_data")[0],
                                             tf.get_collection("train_placeholder_label")[0])

        if validation_data_shuffler is not None:
            # The validation placeholders belong to the validation shuffler; setting
            # them on the training shuffler would overwrite its own placeholders
            validation_data_shuffler.set_placeholders(tf.get_collection("validation_placeholder_data")[0],
                                                      tf.get_collection("validation_placeholder_label")[0])

342 343
    def train(self, train_data_shuffler, validation_data_shuffler=None):
        """
        Train the network:

        Builds the graphs (or loads them from `self.model_from_file`), runs
        `self.iterations` training steps, periodically runs the validation and
        saves snapshots, and finally saves the model as `model.ckp` in
        `self.temp_dir`.

         ** Parameters **

           train_data_shuffler: Data shuffler for training
           validation_data_shuffler: Data shuffler for validation; validation is
                                     skipped when it is None
        """

        # Creating directory
        bob.io.base.create_directories_safe(self.temp_dir)
        self.train_data_shuffler = train_data_shuffler

        logger.info("Initializing !!")

        config = tf.ConfigProto(log_device_placement=True,
                                gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=0.333))
        config.gpu_options.allow_growth = True

        # Pickle the architecture to save
        self.architecture.pickle_net(train_data_shuffler.deployment_shape)

        with tf.Session(config=config) as session:

            # Loading a pretrained model
            if self.model_from_file != "":
                logger.info("Loading pretrained model from {0}".format(self.model_from_file))
                saver = self.bootstrap_graphs_fromfile(session, train_data_shuffler, validation_data_shuffler)
            else:
                # Bootstraping all the graphs
                self.bootstrap_graphs(train_data_shuffler, validation_data_shuffler)

                # TODO: find an elegant way to provide this as a parameter of the trainer
                self.global_step = tf.Variable(0, trainable=False)

                # Preparing the optimizer
                self.optimizer_class._learning_rate = self.learning_rate
                self.optimizer = self.optimizer_class.minimize(self.training_graph, global_step=self.global_step)
                tf.add_to_collection("optimizer", self.optimizer)
                tf.add_to_collection("learning_rate", self.learning_rate)

                # Train summary
                self.summaries_train = self.create_general_summary()
                tf.add_to_collection("summaries_train", self.summaries_train)

                tf.initialize_all_variables().run()

                # Original tensorflow saver object
                saver = tf.train.Saver(var_list=tf.all_variables())

            if isinstance(train_data_shuffler, OnLineSampling):
                train_data_shuffler.set_feature_extractor(self.architecture, session=session)

            # Start a thread to enqueue data asynchronously, and hide I/O latency.
            if self.prefetch:
                self.thread_pool = tf.train.Coordinator()
                tf.train.start_queue_runners(coord=self.thread_pool)
                threads = self.start_thread(session)

            # TENSOR BOARD SUMMARY
            self.train_summary_writter = tf.train.SummaryWriter(os.path.join(self.temp_dir, 'train'), session.graph)
            for step in range(self.iterations):

                # The wall-clock time of each step is stored next to the training summaries
                start = time.time()
                self.fit(session, step)
                end = time.time()
                summary = summary_pb2.Summary.Value(tag="elapsed_time", simple_value=float(end-start))
                self.train_summary_writter.add_summary(summary_pb2.Summary(value=[summary]), step)

                # Running validation
                if validation_data_shuffler is not None and step % self.validation_snapshot == 0:
                    self.compute_validation(session, validation_data_shuffler, step)

                    if self.analizer is not None:
                        self.validation_summary_writter.add_summary(self.analizer(
                             validation_data_shuffler, self.architecture, session), step)

                # Taking snapshot
                if step % self.snapshot == 0:
                    logger.info("Taking snapshot")
                    path = os.path.join(self.temp_dir, 'model_snapshot{0}.ckp'.format(step))
                    self.architecture.save(session, saver, path)

            logger.info("Training finally finished")

            self.train_summary_writter.close()
            # The validation writer is only created by `compute_validation`; it is
            # still None when no validation step ran (e.g. `iterations` is 0)
            if validation_data_shuffler is not None and self.validation_summary_writter is not None:
                self.validation_summary_writter.close()

            # Saving the final network
            path = os.path.join(self.temp_dir, 'model.ckp')
            self.architecture.save(session, saver, path)

            if self.prefetch:
                # now they should definetely stop
                self.thread_pool.request_stop()
                self.thread_pool.join(threads)