Trainer.py 19.5 KB
Newer Older
1 2 3 4 5 6
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Tue 09 Aug 2016 15:25:22 CEST

import tensorflow as tf
7 8 9
import threading
import os
import bob.io.base
10
import bob.core
11
from ..analyzers import SoftmaxAnalizer
12
from tensorflow.core.framework import summary_pb2
13
import time
14
from bob.learn.tensorflow.datashuffler import OnlineSampling, TFRecord
15
from bob.learn.tensorflow.utils.session import Session
16
from bob.learn.tensorflow.utils import compute_embedding_accuracy
17
from .learning_rate import constant
18
import time
19

20 21 22 23 24
#logger = bob.core.log.setup("bob.learn.tensorflow")

import logging
logger = logging.getLogger("bob.learn")

25

26 27 28 29 30 31
class Trainer(object):
    """
    One graph trainer.
    Use this trainer when your CNN is composed by one graph

    **Parameters**

    train_data_shuffler:
      The data shuffler used for batching data for training

    validation_data_shuffler:
      Optional data shuffler used for batching data for validation.
      Validation is skipped when it is ``None``.

    validate_with_embeddings:
      If ``True``, validation scores the output of the validation graph with
      ``compute_embedding_accuracy`` instead of evaluating a loss.

    iterations:
      Maximum number of iterations

    snapshot:
      Will take a snapshot of the network at every `n` iterations

    validation_snapshot:
      Test with validation each `n` iterations

    keep_checkpoint_every_n_hours:
      Forwarded to the ``tf.train.Saver`` that writes the checkpoints

    analizer:
      Neural network analizer :py:mod:`bob.learn.tensorflow.analyzers`

    temp_dir: str
      The output directory (summaries and checkpoints are written there)

    verbosity_level:
      Verbosity level applied to the ``bob.learn`` logger through
      ``bob.core.log.set_verbosity_level``

    """

    def __init__(self,
                 train_data_shuffler,
                 validation_data_shuffler=None,
                 validate_with_embeddings=False,

                 ###### training options ##########
                 iterations=5000,
                 snapshot=1000,
                 validation_snapshot=2000,
                 keep_checkpoint_every_n_hours=2,

                 ## Analizer
                 analizer=SoftmaxAnalizer(),

                 # Temporary dir
                 temp_dir="cnn",

                 verbosity_level=2):
        """
        Store the training configuration and open a fresh TensorFlow session.

        Every graph-related handle starts as ``None``. The handles are filled in
        later by :py:meth:`create_network_from_scratch` or
        :py:meth:`create_network_from_file` (graph, loss, optimizer,
        placeholders, saver, ...) and by :py:meth:`train` (summary writers,
        thread coordinator). See the class docstring for the parameters.
        """
        # NOTE(review): the ``SoftmaxAnalizer()`` default is evaluated once, when
        # the class body runs. All trainers created without an explicit
        # ``analizer`` therefore share that single instance.

        # Configuration given by the caller
        self.train_data_shuffler = train_data_shuffler
        self.validation_data_shuffler = validation_data_shuffler
        self.validate_with_embeddings = validate_with_embeddings
        self.temp_dir = temp_dir
        self.iterations = iterations
        self.snapshot = snapshot
        self.validation_snapshot = validation_snapshot
        self.keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours
        self.analizer = analizer

        # Training-side graph handles
        self.graph = None
        self.prelogits = None
        self.loss = None
        self.centers = None
        self.global_step = None
        self.optimizer_class = None
        self.learning_rate = None
        self.optimizer = None
        self.data_ph = None
        self.label_ph = None
        self.summaries_train = None
        self.saver = None

        # Validation-side graph handles
        self.validation_graph = None
        self.validation_predictor = None
        self.validation_data_ph = None
        self.validation_label_ph = None
        self.summaries_validation = None

        # Objects created while training runs
        self.train_summary_writter = None
        self.validation_summary_writter = None
        self.thread_pool = None

        # Placeholder value; replaced below once logging is configured
        self.session = None

        bob.core.log.set_verbosity_level(logger, verbosity_level)

        # Creating the session
        self.session = Session.instance(new=True).session
        self.from_scratch = True

    def train(self):
        """
        Train the network.

        This is the main loop. Each step calls :py:meth:`fit` (one
        ``session.run``) and records the elapsed time of that step as a
        TensorBoard summary. Validation runs every ``validation_snapshot`` steps
        and a checkpoint is written every ``snapshot`` steps. After the loop,
        validation runs one last time and the final model is saved to
        ``<temp_dir>/model.ckp``.

        The graph must already exist (see :py:meth:`create_network_from_scratch`
        and :py:meth:`create_network_from_file`).
        """

        # Creating directories
        bob.io.base.create_directories_safe(self.temp_dir)
        logger.info("Initializing !!")

        # A network restored from a checkpoint resumes from its stored global step
        if self.from_scratch:
            start_step = 0
        else:
            start_step = self.global_step.eval(session=self.session)

        # TODO: Put this back as soon as possible
        #if isinstance(train_data_shuffler, OnlineSampling):
        #    train_data_shuffler.set_feature_extractor(self.architecture, session=self.session)

        # Start a thread to enqueue data asynchronously, and hide I/O latency.
        if self.train_data_shuffler.prefetch:
            self.thread_pool = tf.train.Coordinator()
            tf.train.start_queue_runners(coord=self.thread_pool, sess=self.session)
            # In case you have your own queue
            if not isinstance(self.train_data_shuffler, TFRecord):
                # The feeder threads are daemons. They are told to stop through
                # the coordinator at the end of training and are not joined.
                self.start_thread()

        # Bootstrapping the summary writers
        self.train_summary_writter = tf.summary.FileWriter(os.path.join(self.temp_dir, 'train'), self.session.graph)
        if self.validation_data_shuffler is not None:
            self.validation_summary_writter = tf.summary.FileWriter(os.path.join(self.temp_dir, 'validation'),
                                                                    self.session.graph)

        # Defined up front so the final validation has a step to report even
        # when ``self.iterations`` is 0 and the loop body never runs.
        step = start_step
        try:
            ######################### Loop for #################
            for step in range(start_step, start_step + self.iterations):
                # Run fit in the graph
                start = time.time()
                self.fit(step)
                end = time.time()

                summary = summary_pb2.Summary.Value(tag="elapsed_time", simple_value=float(end - start))
                self.train_summary_writter.add_summary(summary_pb2.Summary(value=[summary]), step)

                # Running validation
                if self.validation_data_shuffler is not None and step % self.validation_snapshot == 0:
                    if self.validate_with_embeddings:
                        self.compute_validation_embeddings(step)
                    else:
                        self.compute_validation(step)

                # Taking snapshot
                if step % self.snapshot == 0:
                    logger.info("Taking snapshot")
                    path = os.path.join(self.temp_dir, 'model_snapshot{0}.ckp'.format(step))
                    self.saver.save(self.session, path, global_step=step)

            # Running validation for the last time
            if self.validation_data_shuffler is not None:
                if self.validate_with_embeddings:
                    self.compute_validation_embeddings(step)
                else:
                    self.compute_validation(step)

            logger.info("Training finally finished")
        finally:
            # Close the event files and stop the input threads even when
            # training is interrupted by an exception.
            self.train_summary_writter.close()
            if self.validation_summary_writter is not None:
                self.validation_summary_writter.close()
            # The coordinator only exists when the shuffler prefetches. The old
            # ``prefetch or isinstance(..., TFRecord)`` test called
            # ``request_stop()`` on ``None`` for a non-prefetching TFRecord shuffler.
            if self.thread_pool is not None:
                self.thread_pool.request_stop()

        # Saving the final network
        path = os.path.join(self.temp_dir, 'model.ckp')
        self.saver.save(self.session, path)

    def create_network_from_scratch(self,
                                    graph,
                                    validation_graph=None,
                                    optimizer=None,
                                    loss=None,

                                    # Learning rate
                                    learning_rate=None,
                                    prelogits=None
                                    ):

        """
        Prepare all the tensorflow variables before training.

        **Parameters**

            graph: Output tensor of the training network. It is the ``output``
                whose ``argmax`` feeds the accuracy summary.

            validation_graph: Output tensor of the validation network. Only used
                when a validation data shuffler was given to the constructor.

            optimizer: Solver (a ``tf.train`` optimizer instance). When ``None``
                (the default), a new ``tf.train.AdamOptimizer()`` is created
                for this call.

            loss: Loss tensor to minimize. When ``prelogits`` is given (center
                loss), a dictionary with the keys ``'loss'`` and ``'centers'``.

            learning_rate: Learning rate

            prelogits: Pre-logits tensor. Passing it switches ``loss`` to the
                dictionary form above.
        """
        # The default optimizer used to be built once, at import time, and was
        # then mutated below (``_learning_rate``), leaking settings between
        # trainers. Build a fresh one per call instead.
        if optimizer is None:
            optimizer = tf.train.AdamOptimizer()

        # Getting the pointer to the placeholders
        self.data_ph = self.train_data_shuffler("data", from_queue=True)
        self.label_ph = self.train_data_shuffler("label", from_queue=True)

        self.graph = graph
        self.loss = loss

        # TODO: SPECIFIC HACK FOR THE CENTER LOSS. I NEED TO FIND A CLEAN SOLUTION FOR THAT
        self.centers = None
        if prelogits is not None:
            self.loss = loss['loss']
            self.centers = loss['centers']
            tf.add_to_collection("centers", self.centers)
            tf.add_to_collection("prelogits", prelogits)
            self.prelogits = prelogits

        self.optimizer_class = optimizer
        self.learning_rate = learning_rate
        self.global_step = tf.contrib.framework.get_or_create_global_step()

        # Preparing the optimizer
        # NOTE(review): this sets a private optimizer attribute. Confirm that
        # the optimizer in use actually reads ``_learning_rate``.
        self.optimizer_class._learning_rate = self.learning_rate
        self.optimizer = self.optimizer_class.minimize(self.loss, global_step=self.global_step)

        # Saving all the variables
        self.saver = tf.train.Saver(var_list=tf.global_variables() + tf.local_variables(),
                                    keep_checkpoint_every_n_hours=self.keep_checkpoint_every_n_hours)

        self.summaries_train = self.create_general_summary(self.loss, self.graph, self.label_ph)

        # Register the handles in graph collections so that
        # create_network_from_file can recover them from a checkpoint.
        # Each one is added exactly once.
        tf.add_to_collection("global_step", self.global_step)
        tf.add_to_collection("loss", self.loss)
        tf.add_to_collection("graph", self.graph)
        tf.add_to_collection("data_ph", self.data_ph)
        tf.add_to_collection("label_ph", self.label_ph)
        tf.add_to_collection("optimizer", self.optimizer)
        tf.add_to_collection("learning_rate", self.learning_rate)
        tf.add_to_collection("summaries_train", self.summaries_train)

        # Same business with the validation
        if self.validation_data_shuffler is not None:
            self.validation_data_ph = self.validation_data_shuffler("data", from_queue=True)
            self.validation_label_ph = self.validation_data_shuffler("label", from_queue=True)

            self.validation_graph = validation_graph

            if self.validate_with_embeddings:
                self.validation_predictor = self.validation_graph
            else:
                # NOTE(review): ``self.loss`` is the tensor handed to ``minimize``
                # above, yet it is called here like a loss function. Confirm what
                # callers pass as ``loss`` when validating without embeddings.
                self.validation_predictor = self.loss(self.validation_graph, self.validation_label_ph)

            self.summaries_validation = self.create_general_summary(self.validation_predictor, self.validation_graph, self.validation_label_ph)

            tf.add_to_collection("summaries_validation", self.summaries_validation)
            tf.add_to_collection("validation_graph", self.validation_graph)
            tf.add_to_collection("validation_data_ph", self.validation_data_ph)
            tf.add_to_collection("validation_label_ph", self.validation_label_ph)
            tf.add_to_collection("validation_predictor", self.validation_predictor)

        # Creating the variables
        tf.local_variables_initializer().run(session=self.session)
        tf.global_variables_initializer().run(session=self.session)

    def load_checkpoint(self, file_name, clear_devices=True):
        """
        Import a saved meta graph and restore its variables into ``self.session``.

        ** Parameters **

           file_name:
                Either the path of a ``.meta`` file, or a directory containing
                checkpoints. For a directory, the meta graph of the checkpoint
                recorded in its checkpoint state is imported. In both cases the
                variables come from ``tf.train.latest_checkpoint`` of that
                directory.

           clear_devices:
                Forwarded to ``tf.train.import_meta_graph``.

        NOTE(review): when a specific ``.meta`` file is given, the variables
        still come from the *latest* checkpoint of its directory. That may not
        be the checkpoint this ``.meta`` file belongs to.
        """
        if os.path.isdir(file_name):
            checkpoint_dir = file_name
            state = tf.train.get_checkpoint_state(checkpoint_dir)
            meta_graph_path = state.model_checkpoint_path + ".meta"
        else:
            checkpoint_dir = os.path.dirname(file_name)
            meta_graph_path = file_name

        self.saver = tf.train.import_meta_graph(meta_graph_path, clear_devices=clear_devices)
        self.saver.restore(self.session, tf.train.latest_checkpoint(checkpoint_dir))

    def load_variables_from_external_model(self, checkpoint_path, var_list):
        """
        Load a set of variables from a given model and update them in the current one

        ** Parameters **

          checkpoint_path:
            Directory of the tensorflow model to be loaded. Its latest checkpoint
            (``tf.train.latest_checkpoint``) is restored.
          var_list:
            Variable scope name(s) to be loaded: a list of strings, or a single
            string. Every global variable matching one of the scopes is restored.
            A tensorflow exception will be raised in case the variable does not exist.

        ** Raises **

          ValueError: if ``var_list`` is empty
        """
        # A single scope given as a plain string would otherwise be iterated
        # character by character.
        if isinstance(var_list, str):
            var_list = [var_list]

        # Explicit check instead of ``assert``, which is stripped under ``python -O``.
        if len(var_list) == 0:
            raise ValueError("var_list must name at least one variable scope")

        tf_varlist = []
        for scope in var_list:
            tf_varlist += tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope)

        saver = tf.train.Saver(tf_varlist)
        saver.restore(self.session, tf.train.latest_checkpoint(checkpoint_path))

    def create_network_from_file(self, file_name, clear_devices=True):
        """
        Bootstrap a graph from a checkpoint.

        The checkpoint is loaded with :py:meth:`load_checkpoint`. The tensors
        and ops are then looked up in the graph collections written by
        :py:meth:`create_network_from_scratch`. ``from_scratch`` is set to
        ``False``, so :py:meth:`train` resumes from the stored global step.

        ** Parameters **

           file_name: Name of the checkpoint (a ``.meta`` file or a directory,
               see :py:meth:`load_checkpoint`)

           clear_devices: Forwarded to :py:meth:`load_checkpoint`
        """

        logger.info("Loading last checkpoint !!")
        # Forward the caller's choice; it used to be hard-coded to True.
        self.load_checkpoint(file_name, clear_devices=clear_devices)

        # Loading training graph
        self.data_ph = tf.get_collection("data_ph")[0]
        self.label_ph = tf.get_collection("label_ph")[0]

        self.graph = tf.get_collection("graph")[0]
        self.loss = tf.get_collection("loss")[0]

        # Loading other elements
        self.optimizer = tf.get_collection("optimizer")[0]
        self.learning_rate = tf.get_collection("learning_rate")[0]
        self.summaries_train = tf.get_collection("summaries_train")[0]
        self.global_step = tf.get_collection("global_step")[0]
        self.from_scratch = False

        # Center-loss handles only exist if the model was built with prelogits
        if len(tf.get_collection("centers")) > 0:
            self.centers = tf.get_collection("centers")[0]
            self.prelogits = tf.get_collection("prelogits")[0]

        # Loading the validation bits
        if self.validation_data_shuffler is not None:
            self.summaries_validation = tf.get_collection("summaries_validation")[0]

            self.validation_graph = tf.get_collection("validation_graph")[0]
            self.validation_data_ph = tf.get_collection("validation_data_ph")[0]
            self.validation_label_ph = tf.get_collection("validation_label_ph")[0]

            self.validation_predictor = tf.get_collection("validation_predictor")[0]

    def __del__(self):
        """
        Reset TensorFlow's default graph when the trainer is garbage collected.

        NOTE(review): ``tf.reset_default_graph`` acts on the process-wide default
        graph, so this also affects any other code using it. Confirm this is
        intended.
        """
        tf.reset_default_graph()

    def get_feed_dict(self, data_shuffler):
        """
391
        Given a data shuffler prepared the dictionary to be injected in the graph
392 393

        ** Parameters **
394 395

            data_shuffler: Data shuffler :py:class:`bob.learn.tensorflow.datashuffler.Base`
396

397
        """
398
        [data, labels] = data_shuffler.get_batch()
399

Tiago Pereira's avatar
Tiago Pereira committed
400 401
        feed_dict = {self.data_ph: data,
                     self.label_ph: labels}
402 403
        return feed_dict

404
    def fit(self, step):
        """
        Run one iteration (`forward` and `backward`).

        The optimizer op, the loss, the learning rate and the training summary
        are evaluated together (plus the center-loss ``centers`` op when one
        exists). The loss is logged and the summary written at ``step``.

        ** Parameters **
            step: Iteration number, used for logging and as the summary step
        """
        fetches = [self.optimizer, self.loss, self.learning_rate, self.summaries_train]

        # TODO: SPECIFIC HACK FOR THE CENTER LOSS. I NEED TO FIND A CLEAN SOLUTION FOR THAT
        # The centers op is now run in both input modes; it used to be skipped
        # when batches were fed through placeholders.
        if self.centers is not None:
            fetches.append(self.centers)

        if self.train_data_shuffler.prefetch:
            # Batches come from the input queue, nothing needs to be fed
            results = self.session.run(fetches)
        else:
            feed_dict = self.get_feed_dict(self.train_data_shuffler)
            results = self.session.run(fetches, feed_dict=feed_dict)

        # results = [optimizer, loss, learning_rate, summary(, centers)]
        l, summary = results[1], results[3]

        logger.info("Loss training set step={0} = {1}".format(step, l))
        self.train_summary_writter.add_summary(summary, step)

    def compute_validation(self, step):
        """
        Evaluate ``validation_predictor`` on one validation batch and log it.

        The learning rate is fetched alongside but not used. The validation
        summary is written at ``step``.

        ** Parameters **
            step: Iteration number, used for logging and as the summary step

        NOTE(review): in the non-prefetch branch the batch goes through
        ``get_feed_dict``, which fills the *training* placeholders. Confirm the
        validation tensors are built on those.
        """
        fetches = [self.validation_predictor, self.learning_rate, self.summaries_validation]

        if not self.validation_data_shuffler.prefetch:
            outputs = self.session.run(fetches,
                                       feed_dict=self.get_feed_dict(self.validation_data_shuffler))
        else:
            # Batches come from the input queue, nothing needs to be fed
            outputs = self.session.run(fetches)

        loss_value, _, summary = outputs

        logger.info("Loss VALIDATION set step={0} = {1}".format(step, loss_value))
        self.validation_summary_writter.add_summary(summary, step)

    def compute_validation_embeddings(self, step):
        """
        Computes the accuracy of the validation embeddings and logs it.

        One validation batch is pushed through ``validation_predictor``. The
        embeddings are scored against the batch labels with
        ``compute_embedding_accuracy``, and the result is written as the
        ``accuracy`` summary at ``step``.

        ** Parameters **
            step: Iteration number, used for logging and as the summary step
        """
        fetches = [self.validation_predictor, self.validation_label_ph]

        if self.validation_data_shuffler.prefetch:
            embedding, labels = self.session.run(fetches)
        else:
            # Feed the validation placeholders. ``get_feed_dict`` fills the
            # *training* ones, but ``validation_label_ph`` is fetched directly
            # here, so it must be fed. The predictor is presumably built on
            # ``validation_data_ph`` as well.
            batch_data, batch_labels = self.validation_data_shuffler.get_batch()
            feed_dict = {self.validation_data_ph: batch_data,
                         self.validation_label_ph: batch_labels}
            embedding, labels = self.session.run(fetches, feed_dict=feed_dict)

        accuracy = compute_embedding_accuracy(embedding, labels)

        summary = summary_pb2.Summary.Value(tag="accuracy", simple_value=accuracy)
        logger.info("VALIDATION Accuracy set step={0} = {1}".format(step, accuracy))
        self.validation_summary_writter.add_summary(summary_pb2.Summary(value=[summary]), step)

    def create_general_summary(self, average_loss, output, label):
        """
        Creates a simple tensorboard summary with the loss, the learning rate
        and the batch accuracy.

        ``accuracy`` is the mean of ``argmax(output, 1) == label`` over the batch.

        **Returns**
          The op returned by ``tf.summary.merge_all()``.

        NOTE(review): ``merge_all`` merges every summary registered in the graph
        so far, not only the three created here. When this is called a second
        time (for validation), the result also contains the training summaries.
        """
        # Scalars for the loss and the learning rate
        tf.summary.scalar('loss', average_loss)
        tf.summary.scalar('lr', self.learning_rate)

        # Computing accuracy
        predicted_class = tf.argmax(output, 1)
        is_correct = tf.cast(tf.equal(predicted_class, label), tf.float32)
        tf.summary.scalar('accuracy', tf.reduce_mean(is_correct))

        return tf.summary.merge_all()

    def start_thread(self):
Tiago Pereira's avatar
Tiago Pereira committed
501
        """
502 503 504 505
        Start pool of threads for pre-fetching

        **Parameters**
          session: Tensorflow session
Tiago Pereira's avatar
Tiago Pereira committed
506
        """
507

508
        threads = []
509
        for n in range(self.train_data_shuffler.prefetch_threads):
510
            t = threading.Thread(target=self.load_and_enqueue, args=())
511 512 513 514
            t.daemon = True  # thread will close when parent quits
            t.start()
            threads.append(t)
        return threads
515

516
    def load_and_enqueue(self):
Tiago Pereira's avatar
Tiago Pereira committed
517
        """
518
        Injecting data in the place holder queue
519 520 521

        **Parameters**
          session: Tensorflow session
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
522

Tiago Pereira's avatar
Tiago Pereira committed
523
        """
524
        while not self.thread_pool.should_stop():
525
            [train_data, train_labels] = self.train_data_shuffler.get_batch()
526

527 528
            data_ph = self.train_data_shuffler("data", from_queue=False)
            label_ph = self.train_data_shuffler("label", from_queue=False)
529

530 531 532 533
            feed_dict = {data_ph: train_data,
                         label_ph: train_labels}

            self.session.run(self.train_data_shuffler.enqueue_op, feed_dict=feed_dict)
534

535