Trainer.py 19.1 KB
Newer Older
1 2 3 4 5 6
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Tue 09 Aug 2016 15:25:22 CEST

import tensorflow as tf
7 8 9
import threading
import os
import bob.io.base
10
import bob.core
11
from ..analyzers import SoftmaxAnalizer
12
from tensorflow.core.framework import summary_pb2
13
import time
14
from bob.learn.tensorflow.datashuffler import OnlineSampling, TFRecord
15
from bob.learn.tensorflow.utils.session import Session
16
from bob.learn.tensorflow.utils import compute_embedding_accuracy
17
from .learning_rate import constant
18
import time
19

20 21 22 23 24
#logger = bob.core.log.setup("bob.learn.tensorflow")

import logging
logger = logging.getLogger("bob.learn")

25

26 27 28 29 30 31
class Trainer(object):
    """
    One graph trainer.
    Use this trainer when your CNN is composed by one graph

    **Parameters**

    train_data_shuffler:
      The data shuffler used for batching data for training

    validation_data_shuffler:
      Optional data shuffler used for validation. When it is `None`
      (the default) no validation is run.

    validate_with_embeddings:
      If `True`, validation is done with `compute_validation_embeddings`
      instead of `compute_validation`.

    iterations:
      Maximum number of iterations

    snapshot:
      Will take a snapshot of the network at every `n` iterations

    validation_snapshot:
      Test with validation each `n` iterations

    keep_checkpoint_every_n_hours:
      Value forwarded to `tf.train.Saver` when the saver is created

    analizer:
      Neural network analizer :py:mod:`bob.learn.tensorflow.analyzers`

    temp_dir: str
      The output directory

    verbosity_level:
      Verbosity level passed to `bob.core.log.set_verbosity_level`

    """
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
54

55
    def __init__(self,
                 train_data_shuffler,
                 validation_data_shuffler=None,
                 validate_with_embeddings=False,

                 # Training options
                 iterations=5000,
                 snapshot=1000,
                 validation_snapshot=2000,
                 keep_checkpoint_every_n_hours=2,

                 # Analyzer
                 analizer=SoftmaxAnalizer(),

                 # Output directory
                 temp_dir="cnn",

                 verbosity_level=2):
        """
        Store the training configuration and open a new tensorflow session.

        Everything that is only known once a network is attached (graph,
        loss, optimizer, placeholders, summaries, saver, ...) is initialised
        to `None` here and filled in later by `create_network_from_scratch`
        or `create_network_from_file`.
        """

        # Configuration received from the caller
        self.train_data_shuffler = train_data_shuffler
        self.validation_data_shuffler = validation_data_shuffler
        self.validate_with_embeddings = validate_with_embeddings
        self.analizer = analizer
        self.temp_dir = temp_dir

        # Schedule of the training loop
        self.iterations = iterations
        self.snapshot = snapshot
        self.validation_snapshot = validation_snapshot
        self.keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours

        # Summaries and their writers (train / validation)
        self.summaries_train = self.summaries_validation = None
        self.train_summary_writter = self.validation_summary_writter = None

        # Coordinator of the prefetching threads
        self.thread_pool = None

        # Network related tensors
        self.graph = self.validation_graph = None
        self.loss = self.validation_loss = None
        self.prelogits = None
        self.centers = None
        self.global_step = None

        # Solver
        self.optimizer_class = None
        self.optimizer = None
        self.learning_rate = None

        # Input placeholders (train / validation)
        self.data_ph = self.label_ph = None
        self.validation_data_ph = self.validation_label_ph = None

        self.saver = None
        self.session = None

        bob.core.log.set_verbosity_level(logger, verbosity_level)

        # A brand new session is requested for every trainer
        self.session = Session.instance(new=True).session
        self.from_scratch = True
128

129 130
    def train(self):
        """
        Train the network.

        This is basically the loop that takes your graph and does a sequence
        of `session.run` calls: one `fit` per step, a validation pass every
        `validation_snapshot` steps (when a validation data shuffler was
        given) and a checkpoint every `snapshot` steps. Once the loop is over
        a last validation is run and the final model is saved as `model.ckp`
        inside `temp_dir`.

        The summary writers are closed and the prefetching coordinator is
        asked to stop even if an exception interrupts the training.
        """

        # Creating directories
        bob.io.base.create_directories_safe(self.temp_dir)
        logger.info("Initializing !!")

        # Resuming from the stored global step when a model was loaded
        if self.from_scratch:
            start_step = 0
        else:
            start_step = self.global_step.eval(session=self.session)

        # TODO: Put the OnlineSampling feature extractor hook back as soon as possible
        #if isinstance(train_data_shuffler, OnlineSampling):
        #    train_data_shuffler.set_feature_extractor(self.architecture, session=self.session)

        # Start a thread to enqueue data asynchronously, and hide I/O latency.
        if self.train_data_shuffler.prefetch:
            self.thread_pool = tf.train.Coordinator()
            tf.train.start_queue_runners(coord=self.thread_pool, sess=self.session)
            # In case you have your own queue
            if not isinstance(self.train_data_shuffler, TFRecord):
                self.start_thread()

        has_validation = self.validation_data_shuffler is not None
        if self.validate_with_embeddings:
            validate = self.compute_validation_embeddings
        else:
            validate = self.compute_validation

        # Bootstrapping the summary writters
        self.train_summary_writter = tf.summary.FileWriter(os.path.join(self.temp_dir, 'train'), self.session.graph)
        if has_validation:
            self.validation_summary_writter = tf.summary.FileWriter(os.path.join(self.temp_dir, 'validation'),
                                                                    self.session.graph)

        try:
            # Keeps `step` defined for the last validation even when the loop
            # below does not run a single iteration (`iterations == 0`)
            step = start_step
            for step in range(start_step, start_step + self.iterations):
                # Run fit in the graph
                start = time.time()
                self.fit(step)
                end = time.time()

                summary = summary_pb2.Summary.Value(tag="elapsed_time", simple_value=float(end-start))
                self.train_summary_writter.add_summary(summary_pb2.Summary(value=[summary]), step)

                # Running validation
                if has_validation and step % self.validation_snapshot == 0:
                    validate(step)

                # Taking snapshot
                if step % self.snapshot == 0:
                    logger.info("Taking snapshot")
                    path = os.path.join(self.temp_dir, 'model_snapshot{0}.ckp'.format(step))
                    self.saver.save(self.session, path, global_step=step)

            # Running validation for the last time
            if has_validation:
                validate(step)

            logger.info("Training finally finished")

            # Saving the final network
            path = os.path.join(self.temp_dir, 'model.ckp')
            self.saver.save(self.session, path)
        finally:
            self.train_summary_writter.close()
            if has_validation:
                self.validation_summary_writter.close()

            # The coordinator only exists when prefetching was started above,
            # so it must be checked before asking it to stop
            uses_queue = self.train_data_shuffler.prefetch or isinstance(self.train_data_shuffler, TFRecord)
            if uses_queue and self.thread_pool is not None:
                # now they should definetely stop
                self.thread_pool.request_stop()
209

Tiago Pereira's avatar
Tiago Pereira committed
210 211
    def create_network_from_scratch(self,
                                    graph,
                                    validation_graph=None,
                                    optimizer=None,
                                    loss=None,
                                    validation_loss=None,

                                    # Learning rate
                                    learning_rate=None,
                                    prelogits=None
                                    ):

        """
        Prepare all the tensorflow variables before training.

        **Parameters**

            graph: Input graph for training

            validation_graph: Graph used for validation. Only used when a
                validation data shuffler was given to the constructor.

            optimizer: Solver. When `None` (the default) a new
                `tf.train.AdamOptimizer` is created for this call, so that no
                optimizer instance is shared between trainers.

            loss: Loss function. When `prelogits` is given it is indexed with
                the keys `'loss'` and `'centers'` (center loss).

            validation_loss: Loss computed on the validation graph. Ignored
                when validating with embeddings (the validation graph itself
                is used in that case).

            learning_rate: Learning rate. It is written in the optimizer and
                in the summaries.

            prelogits: Optional prelogits tensor; enables the center loss
                handling described in `loss`.
        """
        if optimizer is None:
            optimizer = tf.train.AdamOptimizer()

        # Getting the pointer to the placeholders
        self.data_ph = self.train_data_shuffler("data", from_queue=True)
        self.label_ph = self.train_data_shuffler("label", from_queue=True)

        self.graph = graph
        self.loss = loss

        # TODO: SPECIFIC HACK FOR THE CENTER LOSS. I NEED TO FIND A CLEAN SOLUTION FOR THAT
        self.centers = None
        if prelogits is not None:
            self.loss = loss['loss']
            self.centers = loss['centers']
            tf.add_to_collection("centers", self.centers)
            tf.add_to_collection("prelogits", prelogits)
            self.prelogits = prelogits

        self.optimizer_class = optimizer
        self.learning_rate = learning_rate
        self.global_step = tf.train.get_or_create_global_step()

        # Preparing the optimizer
        self.optimizer_class._learning_rate = self.learning_rate
        self.optimizer = self.optimizer_class.minimize(self.loss, global_step=self.global_step)

        # Saving all the variables
        self.saver = tf.train.Saver(var_list=tf.global_variables() + tf.local_variables(),
                                    keep_checkpoint_every_n_hours=self.keep_checkpoint_every_n_hours)

        self.summaries_train = self.create_general_summary(self.loss, self.graph, self.label_ph)

        # Saving some variables; `create_network_from_file` reads them back
        # from these collections
        tf.add_to_collection("global_step", self.global_step)

        tf.add_to_collection("loss", self.loss)
        tf.add_to_collection("graph", self.graph)
        tf.add_to_collection("data_ph", self.data_ph)
        tf.add_to_collection("label_ph", self.label_ph)

        tf.add_to_collection("optimizer", self.optimizer)
        tf.add_to_collection("learning_rate", self.learning_rate)

        tf.add_to_collection("summaries_train", self.summaries_train)

        # Same business with the validation
        if self.validation_data_shuffler is not None:
            self.validation_data_ph = self.validation_data_shuffler("data", from_queue=True)
            self.validation_label_ph = self.validation_data_shuffler("label", from_queue=True)

            self.validation_graph = validation_graph

            if self.validate_with_embeddings:
                # The embeddings themselves are fetched during validation
                self.validation_loss = self.validation_graph
            else:
                self.validation_loss = validation_loss

            self.summaries_validation = self.create_general_summary(self.validation_loss, self.validation_graph, self.validation_label_ph)

            tf.add_to_collection("validation_graph", self.validation_graph)
            tf.add_to_collection("validation_data_ph", self.validation_data_ph)
            tf.add_to_collection("validation_label_ph", self.validation_label_ph)

            tf.add_to_collection("validation_loss", self.validation_loss)
            tf.add_to_collection("summaries_validation", self.summaries_validation)

        # Creating the variables
        tf.local_variables_initializer().run(session=self.session)
        tf.global_variables_initializer().run(session=self.session)

306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323
    def load_checkpoint(self, file_name, clear_devices=True):
        """
        Load a checkpoint

        The meta graph is imported, the resulting saver is kept in
        `self.saver` and the latest checkpoint of the directory is restored
        in `self.session`.

        ** Parameters **

           file_name:
                Name of the metafile to be loaded.
                If a directory is passed, the last checkpoint will be loaded

           clear_devices:
                Forwarded to `tf.train.import_meta_graph`

        ** Raises **

           ValueError: if `file_name` is a directory that has no checkpoint
                state

        """
        if os.path.isdir(file_name):
            checkpoint_dir = file_name
            state = tf.train.get_checkpoint_state(checkpoint_dir)
            if state is None:
                raise ValueError("No checkpoint found in the directory `{0}`".format(file_name))
            meta_file = state.model_checkpoint_path + ".meta"
        else:
            checkpoint_dir = os.path.dirname(file_name)
            meta_file = file_name

        self.saver = tf.train.import_meta_graph(meta_file, clear_devices=clear_devices)
        self.saver.restore(self.session, tf.train.latest_checkpoint(checkpoint_dir))
324

325
    def load_variables_from_external_model(self, checkpoint_path, var_list):
        """
        Load a set of variables from a given model and update them in the current one

        ** Parameters **

          checkpoint_path:
            Tensorflow model to be loaded. If it is a directory, its latest
            checkpoint is used; otherwise the value is handed to the saver
            as it is (checkpoint prefix).
          var_list:
            List of variables to be loaded; each entry is used as the `scope`
            filter over the global variables of the current graph.
            A tensorflow exception will be raised in case the variable does not exists

        ** Raises **

          ValueError: if `var_list` is empty

        """

        # Explicit check instead of `assert`, which is stripped with `python -O`
        if not var_list:
            raise ValueError("`var_list` must contain at least one variable")

        tf_varlist = []
        for scope in var_list:
            tf_varlist.extend(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope))

        if os.path.isdir(checkpoint_path):
            save_path = tf.train.latest_checkpoint(checkpoint_path)
        else:
            save_path = checkpoint_path

        saver = tf.train.Saver(tf_varlist)
        saver.restore(self.session, save_path)
346

347
    def create_network_from_file(self, file_name, clear_devices=True):
        """
        Bootstrap a graph from a checkpoint

        The checkpoint is loaded with `load_checkpoint` and the tensors and
        operations needed for training are fetched back from the graph
        collections. After this call `train` resumes from the stored global
        step.

         ** Parameters **

           file_name: Name of the checkpoint (meta file or directory)

           clear_devices: Forwarded to `load_checkpoint`
        """

        logger.info("Loading last checkpoint !!")
        self.load_checkpoint(file_name, clear_devices=clear_devices)

        # Loading training graph
        self.data_ph = tf.get_collection("data_ph")[0]
        self.label_ph = tf.get_collection("label_ph")[0]

        self.graph = tf.get_collection("graph")[0]
        self.loss = tf.get_collection("loss")[0]

        # Loading other elements
        self.optimizer = tf.get_collection("optimizer")[0]
        self.learning_rate = tf.get_collection("learning_rate")[0]
        self.summaries_train = tf.get_collection("summaries_train")[0]
        self.global_step = tf.get_collection("global_step")[0]
        self.from_scratch = False

        # The center loss elements are only present in some models
        centers = tf.get_collection("centers")
        if centers:
            self.centers = centers[0]
            self.prelogits = tf.get_collection("prelogits")[0]

        # Loading the validation bits
        if self.validation_data_shuffler is not None:
            self.validation_graph = tf.get_collection("validation_graph")[0]
            self.validation_data_ph = tf.get_collection("validation_data_ph")[0]
            self.validation_label_ph = tf.get_collection("validation_label_ph")[0]

            self.validation_loss = tf.get_collection("validation_loss")[0]
            self.summaries_validation = tf.get_collection("summaries_validation")[0]

Tiago Pereira's avatar
Tiago Pereira committed
388 389
    def __del__(self):
        """
        Reset the tensorflow default graph when the trainer is destroyed
        """
        tf.reset_default_graph()
390 391 392

    def get_feed_dict(self, data_shuffler):
        """
        Given a data shuffler prepared the dictionary to be injected in the graph

        One batch is taken from `data_shuffler` and mapped to the training
        placeholders (`self.data_ph`, `self.label_ph`). When `data_shuffler`
        is the validation data shuffler of this trainer, the same batch is
        also mapped to the validation placeholders that are set, because
        `compute_validation_embeddings` fetches `self.validation_label_ph`.

        ** Parameters **

            data_shuffler: Data shuffler :py:class:`bob.learn.tensorflow.datashuffler.Base`

        ** Returns **

            The dictionary `{placeholder: batch}`

        """
        data, labels = data_shuffler.get_batch()

        feed_dict = {self.data_ph: data,
                     self.label_ph: labels}

        # The training placeholders are kept for backward compatibility;
        # feeding the validation ones as well is harmless when they are
        # not needed by the fetched tensors
        if data_shuffler is getattr(self, "validation_data_shuffler", None):
            validation_feeds = ((getattr(self, "validation_data_ph", None), data),
                                (getattr(self, "validation_label_ph", None), labels))
            for placeholder, batch in validation_feeds:
                if placeholder is not None:
                    feed_dict[placeholder] = batch

        return feed_dict

406
    def fit(self, step):
        """
        Run one iteration (`forward` and `backward`)

        The optimizer, the loss, the learning rate and the training summaries
        are evaluated in one `session.run`. The loss is logged and the
        summary is written with the training summary writer.

        ** Parameters **
            step: Iteration number

        """

        fetches = [self.optimizer, self.loss, self.learning_rate, self.summaries_train]

        # TODO: SPECIFIC HACK FOR THE CENTER LOSS. I NEED TO FIND A CLEAN SOLUTION FOR THAT
        # The centers are fetched whatever the way the data is provided
        # (queue or feed dict)
        if self.centers is not None:
            fetches.append(self.centers)

        if self.train_data_shuffler.prefetch:
            # The data comes from the queue, nothing to feed
            feed_dict = None
        else:
            feed_dict = self.get_feed_dict(self.train_data_shuffler)

        outputs = self.session.run(fetches, feed_dict=feed_dict)
        l, summary = outputs[1], outputs[3]

        logger.info("Loss training set step={0} = {1}".format(step, l))
        self.train_summary_writter.add_summary(summary, step)
432

433
    def compute_validation(self, step):
        """
        Computes the loss in the validation set

        The validation loss, the learning rate and the validation summaries
        are evaluated in one `session.run`; the loss is logged and the
        summary is written with the validation summary writer.

        ** Parameters **
            step: Iteration number

        """

        fetches = [self.validation_loss,
                   self.learning_rate, self.summaries_validation]

        if not self.validation_data_shuffler.prefetch:
            # No queue: the batch has to be fed by hand
            batch = self.get_feed_dict(self.validation_data_shuffler)
            outputs = self.session.run(fetches, feed_dict=batch)
        else:
            outputs = self.session.run(fetches)

        loss_value, _, merged_summary = outputs

        logger.info("Loss VALIDATION set step={0} = {1}".format(step, loss_value))
        self.validation_summary_writter.add_summary(merged_summary, step)
Tiago Pereira's avatar
Tiago Pereira committed
455

456 457 458 459 460 461 462 463 464 465
    def compute_validation_embeddings(self, step):
        """
        Computes the accuracy in the validation set with embeddings

        `self.validation_loss` and the validation labels are evaluated, the
        accuracy is computed with `compute_embedding_accuracy`, logged, and
        written as the `accuracy` summary of the validation summary writer.

        ** Parameters **
            step: Iteration number

        """

        fetches = [self.validation_loss, self.validation_label_ph]

        if not self.validation_data_shuffler.prefetch:
            # No queue: the batch has to be fed by hand
            batch = self.get_feed_dict(self.validation_data_shuffler)
            outputs = self.session.run(fetches,
                                       feed_dict=batch)
        else:
            outputs = self.session.run(fetches)

        embedding, labels = outputs
        accuracy = compute_embedding_accuracy(embedding, labels)

        accuracy_value = summary_pb2.Summary.Value(tag="accuracy", simple_value=accuracy)
        logger.info("VALIDATION Accuracy set step={0} = {1}".format(step, accuracy))

        accuracy_summary = summary_pb2.Summary(value=[accuracy_value])
        self.validation_summary_writter.add_summary(accuracy_summary, step)
479 480


481
    def create_general_summary(self, average_loss, output, label):
        """
        Creates a simple tensorboard summary with the value of the loss,
        the learning rate and the accuracy

        ** Parameters **
            average_loss: Tensor reported as `loss`
            output: Tensor whose argmax over axis 1 is compared with `label`
            label: Tensor with the expected labels

        ** Returns **
            The output of `tf.summary.merge_all()`
        """

        # Loss and learning rate
        tf.summary.scalar('loss', average_loss)
        tf.summary.scalar('lr', self.learning_rate)

        # Accuracy: ratio of samples whose most activated output is the label
        predicted = tf.argmax(output, 1)
        hits = tf.cast(tf.equal(predicted, label), tf.float32)
        tf.summary.scalar('accuracy', tf.reduce_mean(hits))

        # NOTE(review): `merge_all` is not limited to the three summaries
        # created above - confirm that merging summaries registered
        # elsewhere in the graph is intended
        return tf.summary.merge_all()
501

502
    def start_thread(self):
        """
        Start pool of threads for pre-fetching

        One daemon thread running `load_and_enqueue` is started for each of
        the `prefetch_threads` of the train data shuffler.

        **Returns**
          The list with the started threads
        """

        how_many = self.train_data_shuffler.prefetch_threads
        workers = [threading.Thread(target=self.load_and_enqueue, args=())
                   for _ in range(how_many)]

        for worker in workers:
            # Daemon threads close when the parent quits
            worker.daemon = True
            worker.start()

        return workers
517

518
    def load_and_enqueue(self):
        """
        Injecting data in the place holder queue

        Until the coordinator (`self.thread_pool`) asks to stop, a batch is
        taken from the train data shuffler and pushed with its `enqueue_op`.
        """
        while True:
            if self.thread_pool.should_stop():
                return

            batch_data, batch_labels = self.train_data_shuffler.get_batch()

            # Placeholders that feed the queue (not the ones dequeued from it)
            feed_dict = {}
            feed_dict[self.train_data_shuffler("data", from_queue=False)] = batch_data
            feed_dict[self.train_data_shuffler("label", from_queue=False)] = batch_labels

            self.session.run(self.train_data_shuffler.enqueue_op, feed_dict=feed_dict)
536

537